1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
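33//
34// For example (illustrative only): for a loop such as
35//
36//   for (int i = 0; i != n; ++i)
37//     A[3*i + 7] = 0;
38//
39// the induction variable i is represented by the add recurrence {0,+,1} on
40// that loop, and the index expression 3*i + 7 canonicalizes to {7,+,3}.
41// Because SCEVs are uniqued, two syntactically different but equivalent
42// expressions map to the same SCEV object, e.g. (hypothetical client code,
43// assuming a ScalarEvolution &SE and Values V1 = 3*i + 7, V2 = 7 + i*3):
44//
45//   const SCEV *S1 = SE.getSCEV(V1);
46//   const SCEV *S2 = SE.getSCEV(V2);
47//   assert(S1 == S2 && "canonicalized equal expressions are pointer-equal");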
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
60#include "llvm/Analysis/ScalarEvolution.h"
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
66#include "llvm/ADT/FoldingSet.h"
67#include "llvm/ADT/STLExtras.h"
68#include "llvm/ADT/ScopeExit.h"
69#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/SmallSet.h"
73#include "llvm/ADT/Statistic.h"
75#include "llvm/ADT/StringRef.h"
84#include "llvm/Config/llvm-config.h"
85#include "llvm/IR/Argument.h"
86#include "llvm/IR/BasicBlock.h"
87#include "llvm/IR/CFG.h"
88#include "llvm/IR/Constant.h"
90#include "llvm/IR/Constants.h"
91#include "llvm/IR/DataLayout.h"
93#include "llvm/IR/Dominators.h"
94#include "llvm/IR/Function.h"
95#include "llvm/IR/GlobalAlias.h"
96#include "llvm/IR/GlobalValue.h"
98#include "llvm/IR/InstrTypes.h"
99#include "llvm/IR/Instruction.h"
100#include "llvm/IR/Instructions.h"
102#include "llvm/IR/Intrinsics.h"
103#include "llvm/IR/LLVMContext.h"
104#include "llvm/IR/Operator.h"
105#include "llvm/IR/PatternMatch.h"
106#include "llvm/IR/Type.h"
107#include "llvm/IR/Use.h"
108#include "llvm/IR/User.h"
109#include "llvm/IR/Value.h"
110#include "llvm/IR/Verifier.h"
112#include "llvm/Pass.h"
113#include "llvm/Support/Casting.h"
116#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136
137#define DEBUG_TYPE "scalar-evolution"
138
139STATISTIC(NumExitCountsComputed,
140 "Number of loop exits with predictable exit counts");
141STATISTIC(NumExitCountsNotComputed,
142 "Number of loop exits without predictable exit counts");
143STATISTIC(NumBruteForceTripCountsComputed,
144 "Number of loops with trip counts computed by force");
145
146#ifdef EXPENSIVE_CHECKS
147bool llvm::VerifySCEV = true;
148#else
149bool llvm::VerifySCEV = false;
150#endif
151
152static cl::opt<unsigned>
153 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154 cl::desc("Maximum number of iterations SCEV will "
155 "symbolically execute a constant "
156 "derived loop"),
157 cl::init(100));
158
159static cl::opt<bool, true> VerifySCEVOpt(
160 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
162static cl::opt<bool> VerifySCEVStrict(
163 "verify-scev-strict", cl::Hidden,
164 cl::desc("Enable stricter verification when -verify-scev is passed"));
165
166static cl::opt<bool> VerifyIR(
167 "scev-verify-ir", cl::Hidden,
168 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
169 cl::init(false));
170
171static cl::opt<unsigned> MulOpsInlineThreshold(
172 "scev-mulops-inline-threshold", cl::Hidden,
173 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
174 cl::init(32));
175
176static cl::opt<unsigned> AddOpsInlineThreshold(
177 "scev-addops-inline-threshold", cl::Hidden,
178 cl::desc("Threshold for inlining addition operands into a SCEV"),
179 cl::init(500));
180
181static cl::opt<unsigned> MaxSCEVCompareDepth(
182 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
183 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
184 cl::init(32));
185
186static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
187 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
188 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
189 cl::init(2));
190
191static cl::opt<unsigned> MaxValueCompareDepth(
192 "scalar-evolution-max-value-compare-depth", cl::Hidden,
193 cl::desc("Maximum depth of recursive value complexity comparisons"),
194 cl::init(2));
195
196static cl::opt<unsigned>
197 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
198 cl::desc("Maximum depth of recursive arithmetics"),
199 cl::init(32));
200
201static cl::opt<unsigned> MaxConstantEvolvingDepth(
202 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
203 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
204
205static cl::opt<unsigned>
206 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
207 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
208 cl::init(8));
209
210static cl::opt<unsigned>
211 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
212 cl::desc("Max coefficients in AddRec during evolving"),
213 cl::init(8));
214
215static cl::opt<unsigned>
216 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
217 cl::desc("Size of the expression which is considered huge"),
218 cl::init(4096));
219
220static cl::opt<unsigned> RangeIterThreshold(
221 "scev-range-iter-threshold", cl::Hidden,
222 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
223 cl::init(32));
224
225static cl::opt<bool>
226ClassifyExpressions("scalar-evolution-classify-expressions",
227 cl::Hidden, cl::init(true),
228 cl::desc("When printing analysis, include information on every instruction"));
229
230static cl::opt<bool> UseExpensiveRangeSharpening(
231 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
232 cl::init(false),
233 cl::desc("Use more powerful methods of sharpening expression ranges. May "
234 "be costly in terms of compile time"));
235
236static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
237 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
238 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
239 "Phi strongly connected components"),
240 cl::init(8));
241
242static cl::opt<bool>
243 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
244 cl::desc("Handle <= and >= in finite loops"),
245 cl::init(true));
246
247static cl::opt<bool> UseContextForNoWrapFlagInference(
248 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
249 cl::desc("Infer nuw/nsw flags using context where suitable"),
250 cl::init(true));
251
252//===----------------------------------------------------------------------===//
253// SCEV class definitions
254//===----------------------------------------------------------------------===//
255
256//===----------------------------------------------------------------------===//
257// Implementation of the SCEV class.
258//
259
260#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
262 print(dbgs());
263 dbgs() << '\n';
264}
265#endif
266
268 switch (getSCEVType()) {
269 case scConstant:
270 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
271 return;
272 case scVScale:
273 OS << "vscale";
274 return;
275 case scPtrToInt: {
276 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
277 const SCEV *Op = PtrToInt->getOperand();
278 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
279 << *PtrToInt->getType() << ")";
280 return;
281 }
282 case scTruncate: {
283 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
284 const SCEV *Op = Trunc->getOperand();
285 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
286 << *Trunc->getType() << ")";
287 return;
288 }
289 case scZeroExtend: {
290 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
291 const SCEV *Op = ZExt->getOperand();
292 OS << "(zext " << *Op->getType() << " " << *Op << " to "
293 << *ZExt->getType() << ")";
294 return;
295 }
296 case scSignExtend: {
297 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
298 const SCEV *Op = SExt->getOperand();
299 OS << "(sext " << *Op->getType() << " " << *Op << " to "
300 << *SExt->getType() << ")";
301 return;
302 }
303 case scAddRecExpr: {
304 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
305 OS << "{" << *AR->getOperand(0);
306 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
307 OS << ",+," << *AR->getOperand(i);
308 OS << "}<";
309 if (AR->hasNoUnsignedWrap())
310 OS << "nuw><";
311 if (AR->hasNoSignedWrap())
312 OS << "nsw><";
313 if (AR->hasNoSelfWrap() &&
314 !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
315 OS << "nw><";
316 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
317 OS << ">";
318 return;
319 }
320 case scAddExpr:
321 case scMulExpr:
322 case scUMaxExpr:
323 case scSMaxExpr:
324 case scUMinExpr:
325 case scSMinExpr:
326 case scSequentialUMinExpr: {
327 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
328 const char *OpStr = nullptr;
329 switch (NAry->getSCEVType()) {
330 case scAddExpr: OpStr = " + "; break;
331 case scMulExpr: OpStr = " * "; break;
332 case scUMaxExpr: OpStr = " umax "; break;
333 case scSMaxExpr: OpStr = " smax "; break;
334 case scUMinExpr:
335 OpStr = " umin ";
336 break;
337 case scSMinExpr:
338 OpStr = " smin ";
339 break;
340 case scSequentialUMinExpr:
341 OpStr = " umin_seq ";
342 break;
343 default:
344 llvm_unreachable("There are no other nary expression types.");
345 }
346 OS << "(";
347 ListSeparator LS(OpStr);
348 for (const SCEV *Op : NAry->operands())
349 OS << LS << *Op;
350 OS << ")";
351 switch (NAry->getSCEVType()) {
352 case scAddExpr:
353 case scMulExpr:
354 if (NAry->hasNoUnsignedWrap())
355 OS << "<nuw>";
356 if (NAry->hasNoSignedWrap())
357 OS << "<nsw>";
358 break;
359 default:
360 // Nothing to print for other nary expressions.
361 break;
362 }
363 return;
364 }
365 case scUDivExpr: {
366 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
367 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
368 return;
369 }
370 case scUnknown:
371 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
372 return;
373 case scCouldNotCompute:
374 OS << "***COULDNOTCOMPUTE***";
375 return;
376 }
377 llvm_unreachable("Unknown SCEV kind!");
378}
379
381 switch (getSCEVType()) {
382 case scConstant:
383 return cast<SCEVConstant>(this)->getType();
384 case scVScale:
385 return cast<SCEVVScale>(this)->getType();
386 case scPtrToInt:
387 case scTruncate:
388 case scZeroExtend:
389 case scSignExtend:
390 return cast<SCEVCastExpr>(this)->getType();
391 case scAddRecExpr:
392 return cast<SCEVAddRecExpr>(this)->getType();
393 case scMulExpr:
394 return cast<SCEVMulExpr>(this)->getType();
395 case scUMaxExpr:
396 case scSMaxExpr:
397 case scUMinExpr:
398 case scSMinExpr:
399 return cast<SCEVMinMaxExpr>(this)->getType();
400 case scSequentialUMinExpr:
401 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
402 case scAddExpr:
403 return cast<SCEVAddExpr>(this)->getType();
404 case scUDivExpr:
405 return cast<SCEVUDivExpr>(this)->getType();
406 case scUnknown:
407 return cast<SCEVUnknown>(this)->getType();
408 case scCouldNotCompute:
409 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
410 }
411 llvm_unreachable("Unknown SCEV kind!");
412}
413
415 switch (getSCEVType()) {
416 case scConstant:
417 case scVScale:
418 case scUnknown:
419 return {};
420 case scPtrToInt:
421 case scTruncate:
422 case scZeroExtend:
423 case scSignExtend:
424 return cast<SCEVCastExpr>(this)->operands();
425 case scAddRecExpr:
426 case scAddExpr:
427 case scMulExpr:
428 case scUMaxExpr:
429 case scSMaxExpr:
430 case scUMinExpr:
431 case scSMinExpr:
432 case scSequentialUMinExpr:
433 return cast<SCEVNAryExpr>(this)->operands();
434 case scUDivExpr:
435 return cast<SCEVUDivExpr>(this)->operands();
436 case scCouldNotCompute:
437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
438 }
439 llvm_unreachable("Unknown SCEV kind!");
440}
441
442bool SCEV::isZero() const {
443 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
444 return SC->getValue()->isZero();
445 return false;
446}
447
448bool SCEV::isOne() const {
449 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
450 return SC->getValue()->isOne();
451 return false;
452}
453
455 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
456 return SC->getValue()->isMinusOne();
457 return false;
458}
459
461 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
462 if (!Mul) return false;
463
464 // If there is a constant factor, it will be first.
465 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
466 if (!SC) return false;
467
468 // Return true if the value is negative, this matches things like (-42 * V).
469 return SC->getAPInt().isNegative();
470}
471
472SCEVCouldNotCompute::SCEVCouldNotCompute() :
473 SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
474
475bool SCEVCouldNotCompute::classof(const SCEV *S) {
476 return S->getSCEVType() == scCouldNotCompute;
477}
478
481 ID.AddInteger(scConstant);
482 ID.AddPointer(V);
483 void *IP = nullptr;
484 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
485 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
486 UniqueSCEVs.InsertNode(S, IP);
487 return S;
488}
489
491 return getConstant(ConstantInt::get(getContext(), Val));
492}
493
494const SCEV *
496 IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
497 return getConstant(ConstantInt::get(ITy, V, isSigned));
498}
499
502 ID.AddInteger(scVScale);
503 ID.AddPointer(Ty);
504 void *IP = nullptr;
505 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
506 return S;
507 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
508 UniqueSCEVs.InsertNode(S, IP);
509 return S;
510}
511
513 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
514 if (EC.isScalable())
515 Res = getMulExpr(Res, getVScale(Ty));
516 return Res;
517}
518
520 const SCEV *op, Type *ty)
521 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
522
523SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
524 Type *ITy)
525 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
527 "Must be a non-bit-width-changing pointer-to-integer cast!");
528}
529
531 SCEVTypes SCEVTy, const SCEV *op,
532 Type *ty)
533 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
534
535SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
536 Type *ty)
538 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
539 "Cannot truncate non-integer value!");
540}
541
542SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
543 const SCEV *op, Type *ty)
545 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
546 "Cannot zero extend non-integer value!");
547}
548
549SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
550 const SCEV *op, Type *ty)
552 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
553 "Cannot sign extend non-integer value!");
554}
555
556void SCEVUnknown::deleted() {
557 // Clear this SCEVUnknown from various maps.
558 SE->forgetMemoizedResults(this);
559
560 // Remove this SCEVUnknown from the uniquing map.
561 SE->UniqueSCEVs.RemoveNode(this);
562
563 // Release the value.
564 setValPtr(nullptr);
565}
566
567void SCEVUnknown::allUsesReplacedWith(Value *New) {
568 // Clear this SCEVUnknown from various maps.
569 SE->forgetMemoizedResults(this);
570
571 // Remove this SCEVUnknown from the uniquing map.
572 SE->UniqueSCEVs.RemoveNode(this);
573
574 // Replace the value pointer in case someone is still using this SCEVUnknown.
575 setValPtr(New);
576}
577
578//===----------------------------------------------------------------------===//
579// SCEV Utilities
580//===----------------------------------------------------------------------===//
581
582/// Compare the two values \p LV and \p RV in terms of their "complexity" where
583/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
584/// operands in SCEV expressions. \p EqCache is a set of pairs of values that
585/// have been previously deemed to be "equally complex" by this routine. It is
586/// intended to avoid exponential time complexity in cases like:
587///
588/// %a = f(%x, %y)
589/// %b = f(%a, %a)
590/// %c = f(%b, %b)
591///
592/// %d = f(%x, %y)
593/// %e = f(%d, %d)
594/// %f = f(%e, %e)
595///
596/// CompareValueComplexity(%f, %c)
597///
598/// Since we do not continue running this routine on expression trees once we
599/// have seen unequal values, there is no need to track them in the cache.
600static int
601CompareValueComplexity(EquivalenceClasses<Value *> &EqCacheValue,
602 const LoopInfo *const LI, Value *LV, Value *RV,
603 unsigned Depth) {
604 if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
605 return 0;
606
607 // Order pointer values after integer values. This helps SCEVExpander form
608 // GEPs.
609 bool LIsPointer = LV->getType()->isPointerTy(),
610 RIsPointer = RV->getType()->isPointerTy();
611 if (LIsPointer != RIsPointer)
612 return (int)LIsPointer - (int)RIsPointer;
613
614 // Compare getValueID values.
615 unsigned LID = LV->getValueID(), RID = RV->getValueID();
616 if (LID != RID)
617 return (int)LID - (int)RID;
618
619 // Sort arguments by their position.
620 if (const auto *LA = dyn_cast<Argument>(LV)) {
621 const auto *RA = cast<Argument>(RV);
622 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
623 return (int)LArgNo - (int)RArgNo;
624 }
625
626 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
627 const auto *RGV = cast<GlobalValue>(RV);
628
629 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
630 auto LT = GV->getLinkage();
631 return !(GlobalValue::isPrivateLinkage(LT) ||
633 };
634
635 // Use the names to distinguish the two values, but only if the
636 // names are semantically important.
637 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
638 return LGV->getName().compare(RGV->getName());
639 }
640
641 // For instructions, compare their loop depth, and their operand count. This
642 // is pretty loose.
643 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
644 const auto *RInst = cast<Instruction>(RV);
645
646 // Compare loop depths.
647 const BasicBlock *LParent = LInst->getParent(),
648 *RParent = RInst->getParent();
649 if (LParent != RParent) {
650 unsigned LDepth = LI->getLoopDepth(LParent),
651 RDepth = LI->getLoopDepth(RParent);
652 if (LDepth != RDepth)
653 return (int)LDepth - (int)RDepth;
654 }
655
656 // Compare the number of operands.
657 unsigned LNumOps = LInst->getNumOperands(),
658 RNumOps = RInst->getNumOperands();
659 if (LNumOps != RNumOps)
660 return (int)LNumOps - (int)RNumOps;
661
662 for (unsigned Idx : seq(LNumOps)) {
663 int Result =
664 CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
665 RInst->getOperand(Idx), Depth + 1);
666 if (Result != 0)
667 return Result;
668 }
669 }
670
671 EqCacheValue.unionSets(LV, RV);
672 return 0;
673}
674
675// Return negative, zero, or positive, if LHS is less than, equal to, or greater
676// than RHS, respectively. A three-way result allows recursive comparisons to be
677// more efficient.
678// If the max analysis depth was reached, return std::nullopt, assuming we do
679// not know if they are equivalent for sure.
680static std::optional<int>
681CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
682 EquivalenceClasses<Value *> &EqCacheValue,
683 const LoopInfo *const LI, const SCEV *LHS,
684 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
685 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
686 if (LHS == RHS)
687 return 0;
688
689 // Primarily, sort the SCEVs by their getSCEVType().
690 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
691 if (LType != RType)
692 return (int)LType - (int)RType;
693
694 if (EqCacheSCEV.isEquivalent(LHS, RHS))
695 return 0;
696
697 if (Depth > MaxSCEVCompareDepth)
698 return std::nullopt;
699
700 // Aside from the getSCEVType() ordering, the particular ordering
701 // isn't very important except that it's beneficial to be consistent,
702 // so that (a + b) and (b + a) don't end up as different expressions.
703 switch (LType) {
704 case scUnknown: {
705 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
706 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
707
708 int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
709 RU->getValue(), Depth + 1);
710 if (X == 0)
711 EqCacheSCEV.unionSets(LHS, RHS);
712 return X;
713 }
714
715 case scConstant: {
716 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
717 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
718
719 // Compare constant values.
720 const APInt &LA = LC->getAPInt();
721 const APInt &RA = RC->getAPInt();
722 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
723 if (LBitWidth != RBitWidth)
724 return (int)LBitWidth - (int)RBitWidth;
725 return LA.ult(RA) ? -1 : 1;
726 }
727
728 case scVScale: {
729 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
730 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
731 return LTy->getBitWidth() - RTy->getBitWidth();
732 }
733
734 case scAddRecExpr: {
735 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
736 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
737
738 // There is always a dominance between two recs that are used by one SCEV,
739 // so we can safely sort recs by loop header dominance. We require such
740 // order in getAddExpr.
741 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
742 if (LLoop != RLoop) {
743 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
744 assert(LHead != RHead && "Two loops share the same header?");
745 if (DT.dominates(LHead, RHead))
746 return 1;
747 assert(DT.dominates(RHead, LHead) &&
748 "No dominance between recurrences used by one SCEV?");
749 return -1;
750 }
751
752 [[fallthrough]];
753 }
754
755 case scTruncate:
756 case scZeroExtend:
757 case scSignExtend:
758 case scPtrToInt:
759 case scAddExpr:
760 case scMulExpr:
761 case scUDivExpr:
762 case scSMaxExpr:
763 case scUMaxExpr:
764 case scSMinExpr:
765 case scUMinExpr:
766 case scSequentialUMinExpr: {
767 ArrayRef<const SCEV *> LOps = LHS->operands();
768 ArrayRef<const SCEV *> ROps = RHS->operands();
769
770 // Lexicographically compare n-ary-like expressions.
771 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
772 if (LNumOps != RNumOps)
773 return (int)LNumOps - (int)RNumOps;
774
775 for (unsigned i = 0; i != LNumOps; ++i) {
776 auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LOps[i],
777 ROps[i], DT, Depth + 1);
778 if (X != 0)
779 return X;
780 }
781 EqCacheSCEV.unionSets(LHS, RHS);
782 return 0;
783 }
784
785 case scCouldNotCompute:
786 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
787 }
788 llvm_unreachable("Unknown SCEV kind!");
789}
790
791/// Given a list of SCEV objects, order them by their complexity, and group
792/// objects of the same complexity together by value. When this routine is
793/// finished, we know that any duplicates in the vector are consecutive and that
794/// complexity is monotonically increasing.
795///
796/// Note that we take special precautions to ensure that we get deterministic
797/// results from this routine. In other words, we don't want the results of
798/// this to depend on where the addresses of various SCEV objects happened to
799/// land in memory.
800static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
801 LoopInfo *LI, DominatorTree &DT) {
802 if (Ops.size() < 2) return; // Noop
803
806
807 // Whether LHS has provably less complexity than RHS.
808 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
809 auto Complexity =
810 CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
811 return Complexity && *Complexity < 0;
812 };
813 if (Ops.size() == 2) {
814 // This is the common case, which also happens to be trivially simple.
815 // Special case it.
816 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
817 if (IsLessComplex(RHS, LHS))
818 std::swap(LHS, RHS);
819 return;
820 }
821
822 // Do the rough sort by complexity.
823 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
824 return IsLessComplex(LHS, RHS);
825 });
826
827 // Now that we are sorted by complexity, group elements of the same
828 // complexity. Note that this is, at worst, N^2, but the vector is likely to
829 // be extremely short in practice. Note that we take this approach because we
830 // do not want to depend on the addresses of the objects we are grouping.
831 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
832 const SCEV *S = Ops[i];
833 unsigned Complexity = S->getSCEVType();
834
835 // If there are any objects of the same complexity and same value as this
836 // one, group them.
837 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
838 if (Ops[j] == S) { // Found a duplicate.
839 // Move it to immediately after i'th element.
840 std::swap(Ops[i+1], Ops[j]);
841 ++i; // no need to rescan it.
842 if (i == e-2) return; // Done!
843 }
844 }
845 }
846}
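// For example (illustrative only): given Ops = (%x + {1,+,1}<%L> + 2 + %x),
// GroupByComplexity above first orders the operands by getSCEVType() (the
// constant 2 first, then the addrec, then the SCEVUnknowns), and the grouping
// pass then makes the two occurrences of %x adjacent, yielding roughly
// (2, {1,+,1}<%L>, %x, %x), so callers such as getAddExpr can fold duplicates
// with a single linear scan over consecutive elements.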
847
848/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
849/// least HugeExprThreshold nodes).
850static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
851 return any_of(Ops, [](const SCEV *S) {
852 return S->getExpressionSize() >= HugeExprThreshold;
853 });
854}
855
856//===----------------------------------------------------------------------===//
857// Simple SCEV method implementations
858//===----------------------------------------------------------------------===//
859
860/// Compute BC(It, K). The result has width W. Assume, K > 0.
861static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
862 ScalarEvolution &SE,
863 Type *ResultTy) {
864 // Handle the simplest case efficiently.
865 if (K == 1)
866 return SE.getTruncateOrZeroExtend(It, ResultTy);
867
868 // We are using the following formula for BC(It, K):
869 //
870 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
871 //
872 // Suppose, W is the bitwidth of the return value. We must be prepared for
873 // overflow. Hence, we must assure that the result of our computation is
874 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
875 // safe in modular arithmetic.
876 //
877 // However, this code doesn't use exactly that formula; the formula it uses
878 // is something like the following, where T is the number of factors of 2 in
879 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
880 // exponentiation:
881 //
882 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
883 //
884 // This formula is trivially equivalent to the previous formula. However,
885 // this formula can be implemented much more efficiently. The trick is that
886 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
887 // arithmetic. To do exact division in modular arithmetic, all we have
888 // to do is multiply by the inverse. Therefore, this step can be done at
889 // width W.
890 //
891 // The next issue is how to safely do the division by 2^T. The way this
892 // is done is by doing the multiplication step at a width of at least W + T
893 // bits. This way, the bottom W+T bits of the product are accurate. Then,
894 // when we perform the division by 2^T (which is equivalent to a right shift
895 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
896 // truncated out after the division by 2^T.
897 //
898 // In comparison to just directly using the first formula, this technique
899 // is much more efficient; using the first formula requires W * K bits,
900// but this formula requires less than W + K bits. Also, the first formula requires
901 // a division step, whereas this formula only requires multiplies and shifts.
902 //
903 // It doesn't matter whether the subtraction step is done in the calculation
904 // width or the input iteration count's width; if the subtraction overflows,
905 // the result must be zero anyway. We prefer here to do it in the width of
906 // the induction variable because it helps a lot for certain cases; CodeGen
907 // isn't smart enough to ignore the overflow, which leads to much less
908 // efficient code if the width of the subtraction is wider than the native
909 // register width.
910 //
911 // (It's possible to not widen at all by pulling out factors of 2 before
912 // the multiplication; for example, K=2 can be calculated as
913 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
914 // extra arithmetic, so it's not an obvious win, and it gets
915 // much more complicated for K > 3.)
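 //
 // Worked example (illustrative only): for W = 32 and K = 3, K! = 6 = 2^1 * 3,
 // so T = 1 and the odd factor is 3. The product It*(It-1)*(It-2) is computed
 // at W+T = 33 bits, shifted right by T = 1, truncated back to 32 bits, and
 // multiplied by the multiplicative inverse of 3 modulo 2^32 (0xAAAAAAAB),
 // which performs the remaining exact division by 3.
 //
 // A minimal standalone sketch of the "exact division by an odd number via
 // its multiplicative inverse" step, using plain uint32_t instead of APInt
 // (hypothetical helper, not part of this file):
 //
 //   static uint32_t inverseModPow2_32(uint32_t OddD) {
 //     // Newton's iteration doubles the number of correct low bits per step,
 //     // so five steps suffice for 32 bits. Requires OddD to be odd.
 //     uint32_t X = OddD;
 //     for (int I = 0; I < 5; ++I)
 //       X *= 2 - OddD * X;
 //     return X; // OddD * X == 1 (mod 2^32)
 //   }
 //   // If OddD divides N exactly, then N / OddD == N * inverseModPow2_32(OddD)
 //   // in 32-bit modular arithmetic.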
916
917 // Protection from insane SCEVs; this bound is conservative,
918 // but it probably doesn't matter.
919 if (K > 1000)
920 return SE.getCouldNotCompute();
921
922 unsigned W = SE.getTypeSizeInBits(ResultTy);
923
924 // Calculate K! / 2^T and T; we divide out the factors of two before
925 // multiplying for calculating K! / 2^T to avoid overflow.
926 // Other overflow doesn't matter because we only care about the bottom
927 // W bits of the result.
928 APInt OddFactorial(W, 1);
929 unsigned T = 1;
930 for (unsigned i = 3; i <= K; ++i) {
931 APInt Mult(W, i);
932 unsigned TwoFactors = Mult.countr_zero();
933 T += TwoFactors;
934 Mult.lshrInPlace(TwoFactors);
935 OddFactorial *= Mult;
936 }
937
938 // We need at least W + T bits for the multiplication step
939 unsigned CalculationBits = W + T;
940
941 // Calculate 2^T, at width T+W.
942 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
943
944 // Calculate the multiplicative inverse of K! / 2^T;
945 // this multiplication factor will perform the exact division by
946 // K! / 2^T.
947 APInt Mod = APInt::getSignedMinValue(W+1);
948 APInt MultiplyFactor = OddFactorial.zext(W+1);
949 MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
950 MultiplyFactor = MultiplyFactor.trunc(W);
951
952 // Calculate the product, at width T+W
953 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
954 CalculationBits);
955 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
956 for (unsigned i = 1; i != K; ++i) {
957 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
958 Dividend = SE.getMulExpr(Dividend,
959 SE.getTruncateOrZeroExtend(S, CalculationTy));
960 }
961
962 // Divide by 2^T
963 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
964
965 // Truncate the result, and divide by K! / 2^T.
966
967 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
968 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
969}
970
971/// Return the value of this chain of recurrences at the specified iteration
972/// number. We can evaluate this recurrence by multiplying each element in the
973/// chain by the binomial coefficient corresponding to it. In other words, we
974/// can evaluate {A,+,B,+,C,+,D} as:
975///
976/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
977///
978/// where BC(It, k) stands for binomial coefficient.
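///
/// Worked example (illustrative only): the recurrence {0,+,1,+,1} (start 0,
/// first difference {1,+,1}) evaluates at iteration It to
///
///   0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2 = It*(It+1)/2,
///
/// i.e. the running sum 0 + 1 + 2 + ... + It, as expected for a value that is
/// incremented by a secondary induction variable starting at 1.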
979const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
980 ScalarEvolution &SE) const {
981 return evaluateAtIteration(operands(), It, SE);
982}
983
984const SCEV *
986 const SCEV *It, ScalarEvolution &SE) {
987 assert(Operands.size() > 0);
988 const SCEV *Result = Operands[0];
989 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
990 // The computation is correct in the face of overflow provided that the
991 // multiplication is performed _after_ the evaluation of the binomial
992 // coefficient.
993 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
994 if (isa<SCEVCouldNotCompute>(Coeff))
995 return Coeff;
996
997 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
998 }
999 return Result;
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// SCEV Expression folder implementations
1004//===----------------------------------------------------------------------===//
1005
1007 unsigned Depth) {
1008 assert(Depth <= 1 &&
1009 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1010
1011 // We could be called with an integer-typed operand during SCEV rewrites.
1012 // Since the operand is an integer already, just perform zext/trunc/self cast.
1013 if (!Op->getType()->isPointerTy())
1014 return Op;
1015
1016 // What would be an ID for such a SCEV cast expression?
1018 ID.AddInteger(scPtrToInt);
1019 ID.AddPointer(Op);
1020
1021 void *IP = nullptr;
1022
1023 // Is there already an expression for such a cast?
1024 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1025 return S;
1026
1027 // It isn't legal for optimizations to construct new ptrtoint expressions
1028 // for non-integral pointers.
1029 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1030 return getCouldNotCompute();
1031
1032 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1033
1034 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1035 // is sufficiently wide to represent all possible pointer values.
1036 // We could theoretically teach SCEV to truncate wider pointers, but
1037 // that isn't implemented for now.
1039 getDataLayout().getTypeSizeInBits(IntPtrTy))
1040 return getCouldNotCompute();
1041
1042 // If not, is this expression something we can't reduce any further?
1043 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1044 // Perform some basic constant folding. If the operand of the ptr2int cast
1045 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1046 // left as-is), but produce a zero constant.
1047 // NOTE: We could handle a more general case, but lack motivational cases.
1048 if (isa<ConstantPointerNull>(U->getValue()))
1049 return getZero(IntPtrTy);
1050
1051 // Create an explicit cast node.
1052 // We can reuse the existing insert position since if we get here,
1053 // we won't have made any changes which would invalidate it.
1054 SCEV *S = new (SCEVAllocator)
1055 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1056 UniqueSCEVs.InsertNode(S, IP);
1057 registerUser(S, Op);
1058 return S;
1059 }
1060
1061 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1062 "non-SCEVUnknown's.");
1063
1064 // Otherwise, we've got some expression that is more complex than just a
1065 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1066 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1067 // only, and the expressions must otherwise be integer-typed.
1068 // So sink the cast down to the SCEVUnknown's.
1069
1070 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1071 /// which computes a pointer-typed value, and rewrites the whole expression
1072 /// tree so that *all* the computations are done on integers, and the only
1073 /// pointer-typed operands in the expression are SCEVUnknown.
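 ///
 /// For example (illustrative only): given the pointer-typed expression
 ///   (%base + (4 * {0,+,1}<%loop>)),
 /// the rewriter produces
 ///   ((ptrtoint i8* %base to i64) + (4 * {0,+,1}<%loop>)),
 /// i.e. the ptrtoint is applied only to the SCEVUnknown leaf %base while the
 /// surrounding arithmetic is kept in integer form.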
1074 class SCEVPtrToIntSinkingRewriter
1075 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1077
1078 public:
1079 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1080
1081 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1082 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1083 return Rewriter.visit(Scev);
1084 }
1085
1086 const SCEV *visit(const SCEV *S) {
1087 Type *STy = S->getType();
1088 // If the expression is not pointer-typed, just keep it as-is.
1089 if (!STy->isPointerTy())
1090 return S;
1091 // Else, recursively sink the cast down into it.
1092 return Base::visit(S);
1093 }
1094
1095 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1097 bool Changed = false;
1098 for (const auto *Op : Expr->operands()) {
1099 Operands.push_back(visit(Op));
1100 Changed |= Op != Operands.back();
1101 }
1102 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1103 }
1104
1105 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1107 bool Changed = false;
1108 for (const auto *Op : Expr->operands()) {
1109 Operands.push_back(visit(Op));
1110 Changed |= Op != Operands.back();
1111 }
1112 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1113 }
1114
1115 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1116 assert(Expr->getType()->isPointerTy() &&
1117 "Should only reach pointer-typed SCEVUnknown's.");
1118 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1119 }
1120 };
1121
1122 // And actually perform the cast sinking.
1123 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1124 assert(IntOp->getType()->isIntegerTy() &&
1125 "We must have succeeded in sinking the cast, "
1126 "and ending up with an integer-typed expression!");
1127 return IntOp;
1128}
1129
1131 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1132
1133 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1134 if (isa<SCEVCouldNotCompute>(IntOp))
1135 return IntOp;
1136
1137 return getTruncateOrZeroExtend(IntOp, Ty);
1138}
1139
1141 unsigned Depth) {
1142 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1143 "This is not a truncating conversion!");
1144 assert(isSCEVable(Ty) &&
1145 "This is not a conversion to a SCEVable type!");
1146 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1147 Ty = getEffectiveSCEVType(Ty);
1148
1150 ID.AddInteger(scTruncate);
1151 ID.AddPointer(Op);
1152 ID.AddPointer(Ty);
1153 void *IP = nullptr;
1154 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1155
1156 // Fold if the operand is constant.
1157 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1158 return getConstant(
1159 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1160
1161 // trunc(trunc(x)) --> trunc(x)
1162 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1163 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1164
1165 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1166 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1167 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1168
1169 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1170 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1171 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1172
1173 if (Depth > MaxCastDepth) {
1174 SCEV *S =
1175 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1176 UniqueSCEVs.InsertNode(S, IP);
1177 registerUser(S, Op);
1178 return S;
1179 }
1180
1181 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1182 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1183 // if after transforming we have at most one truncate, not counting truncates
1184 // that replace other casts.
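 // For example (illustrative only): trunc i64 (zext i32 %a + zext i32 %b) to
 // i32 becomes (%a + %b), because both per-operand truncates merely cancel
 // zero-extends and no free-standing truncate survives. In contrast,
 // trunc i64 (%x + %y) to i32 is left alone, since distributing would create
 // two truncates instead of the single one we started with.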
1185 if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1186 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1188 unsigned numTruncs = 0;
1189 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1190 ++i) {
1191 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1192 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1193 isa<SCEVTruncateExpr>(S))
1194 numTruncs++;
1195 Operands.push_back(S);
1196 }
1197 if (numTruncs < 2) {
1198 if (isa<SCEVAddExpr>(Op))
1199 return getAddExpr(Operands);
1200 if (isa<SCEVMulExpr>(Op))
1201 return getMulExpr(Operands);
1202 llvm_unreachable("Unexpected SCEV type for Op.");
1203 }
1204 // Although we checked in the beginning that ID is not in the cache, it is
1205 // possible that ID was inserted into the cache during recursion and later
1206 // modifications. So if we find it, just return it.
1207 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1208 return S;
1209 }
1210
1211 // If the input value is a chrec scev, truncate the chrec's operands.
1212 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1214 for (const SCEV *Op : AddRec->operands())
1215 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1216 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1217 }
1218
1219 // Return zero if truncating to known zeros.
1220 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1221 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1222 return getZero(Ty);
1223
1224 // The cast wasn't folded; create an explicit cast node. We can reuse
1225 // the existing insert position since if we get here, we won't have
1226 // made any changes which would invalidate it.
1227 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1228 Op, Ty);
1229 UniqueSCEVs.InsertNode(S, IP);
1230 registerUser(S, Op);
1231 return S;
1232}
1233
1234// Get the limit of a recurrence such that incrementing by Step cannot cause
1235// signed overflow as long as the value of the recurrence within the
1236// loop does not exceed this limit before incrementing.
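//
// For example (illustrative only): for a 32-bit recurrence whose step is
// known to lie in [1, 10], this returns the limit (SINT_MAX - 10) together
// with the predicate ICMP_SLT: as long as the recurrence's value is
// signed-less-than (SINT_MAX - 10) before an increment, adding the step
// cannot overflow.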
1237static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1238 ICmpInst::Predicate *Pred,
1239 ScalarEvolution *SE) {
1240 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1241 if (SE->isKnownPositive(Step)) {
1242 *Pred = ICmpInst::ICMP_SLT;
1244 SE->getSignedRangeMax(Step));
1245 }
1246 if (SE->isKnownNegative(Step)) {
1247 *Pred = ICmpInst::ICMP_SGT;
1249 SE->getSignedRangeMin(Step));
1250 }
1251 return nullptr;
1252}
1253
1254// Get the limit of a recurrence such that incrementing by Step cannot cause
1255// unsigned overflow as long as the value of the recurrence within the loop does
1256// not exceed this limit before incrementing.
1258 ICmpInst::Predicate *Pred,
1259 ScalarEvolution *SE) {
1260 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1261 *Pred = ICmpInst::ICMP_ULT;
1262
1264 SE->getUnsignedRangeMax(Step));
1265}
1266
1267namespace {
1268
1269struct ExtendOpTraitsBase {
1270 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1271 unsigned);
1272};
1273
1274// Used to make code generic over signed and unsigned overflow.
1275template <typename ExtendOp> struct ExtendOpTraits {
1276 // Members present:
1277 //
1278 // static const SCEV::NoWrapFlags WrapType;
1279 //
1280 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1281 //
1282 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1283 // ICmpInst::Predicate *Pred,
1284 // ScalarEvolution *SE);
1285};
1286
1287template <>
1288struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1289 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1290
1291 static const GetExtendExprTy GetExtendExpr;
1292
1293 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1294 ICmpInst::Predicate *Pred,
1295 ScalarEvolution *SE) {
1296 return getSignedOverflowLimitForStep(Step, Pred, SE);
1297 }
1298};
1299
1300const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1302
1303template <>
1304struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1305 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1306
1307 static const GetExtendExprTy GetExtendExpr;
1308
1309 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1310 ICmpInst::Predicate *Pred,
1311 ScalarEvolution *SE) {
1312 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1313 }
1314};
1315
1316const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1318
1319} // end anonymous namespace
1320
1321// The recurrence AR has been shown to have no signed/unsigned wrap or something
1322// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1323// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1324// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1325// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1326// expression "Step + sext/zext(PreIncAR)" is congruent with
1327// "sext/zext(PostIncAR)"
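//
// For example (illustrative only): if AR = {(%pre + %step),+,%step}<nsw> on
// some loop L and the conditions below hold, then sext(AR's start) can be
// rewritten as sext(%step) + sext(%pre), so sext(AR) becomes
// {sext(%step) + sext(%pre),+,sext(%step)}, reusing the no-wrap facts proven
// for the pre-increment sibling {%pre,+,%step}.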
1328template <typename ExtendOpTy>
1329static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1330 ScalarEvolution *SE, unsigned Depth) {
1331 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1332 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1333
1334 const Loop *L = AR->getLoop();
1335 const SCEV *Start = AR->getStart();
1336 const SCEV *Step = AR->getStepRecurrence(*SE);
1337
1338 // Check for a simple looking step prior to loop entry.
1339 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1340 if (!SA)
1341 return nullptr;
1342
1343 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1344 // subtraction is expensive. For this purpose, perform a quick and dirty
1345 // difference, by checking for Step in the operand list. Note, that
1346 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1348 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1349 if (*It == Step) {
1350 DiffOps.erase(It);
1351 break;
1352 }
1353
1354 if (DiffOps.size() == SA->getNumOperands())
1355 return nullptr;
1356
1357 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1358 // `Step`:
1359
1360 // 1. NSW/NUW flags on the step increment.
1361 auto PreStartFlags =
1363 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1364 const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1365 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1366
1367 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1368 // "S+X does not sign/unsign-overflow".
1369 //
1370
1371 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1372 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1373 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1374 return PreStart;
1375
1376 // 2. Direct overflow check on the step operation's expression.
1377 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1378 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1379 const SCEV *OperandExtendedStart =
1380 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1381 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1382 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1383 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1384 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1385 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1386 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1387 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1388 }
1389 return PreStart;
1390 }
1391
1392 // 3. Loop precondition.
1394 const SCEV *OverflowLimit =
1395 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1396
1397 if (OverflowLimit &&
1398 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1399 return PreStart;
1400
1401 return nullptr;
1402}
1403
1404// Get the normalized zero or sign extended expression for this AddRec's Start.
1405template <typename ExtendOpTy>
1406static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1407 ScalarEvolution *SE,
1408 unsigned Depth) {
1409 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1410
1411 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1412 if (!PreStart)
1413 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1414
1415 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1416 Depth),
1417 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1418}
1419
1420// Try to prove away overflow by looking at "nearby" add recurrences. A
1421// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1422// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1423//
1424// Formally:
1425//
1426// {S,+,X} == {S-T,+,X} + T
1427// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1428//
1429// If ({S-T,+,X} + T) does not overflow ... (1)
1430//
1431// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1432//
1433// If {S-T,+,X} does not overflow ... (2)
1434//
1435// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1436// == {Ext(S-T)+Ext(T),+,Ext(X)}
1437//
1438// If (S-T)+T does not overflow ... (3)
1439//
1440// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1441// == {Ext(S),+,Ext(X)} == LHS
1442//
1443// Thus, if (1), (2) and (3) are true for some T, then
1444// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1445//
1446// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1447// does not overflow" restricted to the 0th iteration. Therefore we only need
1448// to check for (1) and (2).
1449//
1450// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1451// is `Delta` (defined below).
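//
// Concrete instance (illustrative only, mirroring the motivating example
// above): with Start = 1, Step = 4 and Delta = 1, PreStart is the constant 0,
// so the nearby recurrence is {0,+,4}. If {0,+,4} is already known <nuw>
// (condition (2)) and adding T = 1 to it provably cannot unsigned-wrap
// (condition (1), checked via getOverflowLimitForStep), then
// {1,+,4} == {0,+,4} + 1 cannot wrap either, and therefore
// zext({1,+,4}) == {zext(1),+,zext(4)}.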
1452template <typename ExtendOpTy>
1453bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1454 const SCEV *Step,
1455 const Loop *L) {
1456 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1457
1458 // We restrict `Start` to a constant to prevent SCEV from spending too much
1459 // time here. It is correct (but more expensive) to continue with a
1460 // non-constant `Start` and do a general SCEV subtraction to compute
1461 // `PreStart` below.
1462 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1463 if (!StartC)
1464 return false;
1465
1466 APInt StartAI = StartC->getAPInt();
1467
1468 for (unsigned Delta : {-2, -1, 1, 2}) {
1469 const SCEV *PreStart = getConstant(StartAI - Delta);
1470
1472 ID.AddInteger(scAddRecExpr);
1473 ID.AddPointer(PreStart);
1474 ID.AddPointer(Step);
1475 ID.AddPointer(L);
1476 void *IP = nullptr;
1477 const auto *PreAR =
1478 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1479
1480 // Give up if we don't already have the add recurrence we need because
1481 // actually constructing an add recurrence is relatively expensive.
1482 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1483 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1485 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1486 DeltaS, &Pred, this);
1487 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1488 return true;
1489 }
1490 }
1491
1492 return false;
1493}
1494
1495// Finds an integer D for an expression (C + x + y + ...) such that the top
1496// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1497// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1498// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1499// the (C + x + y + ...) expression is \p WholeAddExpr.
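//
// For example (illustrative only): if C = 13 (0b1101) and every other operand
// of \p WholeAddExpr has at least two trailing zero bits, then D = 1 (the low
// two bits of C), giving 13 + x + y == 1 + (12 + x + y). The inner sum keeps
// two trailing zeros, and re-adding D < 4 only fills those zero bits, so the
// top-level addition cannot wrap in either signedness.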
1500static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1501 const SCEVConstant *ConstantTerm,
1502 const SCEVAddExpr *WholeAddExpr) {
1503 const APInt &C = ConstantTerm->getAPInt();
1504 const unsigned BitWidth = C.getBitWidth();
1505 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1506 uint32_t TZ = BitWidth;
1507 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1508 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1509 if (TZ) {
1510 // Set D to be as many least significant bits of C as possible while still
1511 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1512 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1513 }
1514 return APInt(BitWidth, 0);
1515}
1516
1517// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1518// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1519// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1520// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1522 const APInt &ConstantStart,
1523 const SCEV *Step) {
1524 const unsigned BitWidth = ConstantStart.getBitWidth();
1525 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1526 if (TZ)
1527 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1528 : ConstantStart;
1529 return APInt(BitWidth, 0);
1530}
1531
1533 const ScalarEvolution::FoldID &ID, const SCEV *S,
1536 &FoldCacheUser) {
1537 auto I = FoldCache.insert({ID, S});
1538 if (!I.second) {
1539 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1540 // entry.
1541 auto &UserIDs = FoldCacheUser[I.first->second];
1542 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1543 for (unsigned I = 0; I != UserIDs.size(); ++I)
1544 if (UserIDs[I] == ID) {
1545 std::swap(UserIDs[I], UserIDs.back());
1546 break;
1547 }
1548 UserIDs.pop_back();
1549 I.first->second = S;
1550 }
1551 auto R = FoldCacheUser.insert({S, {}});
1552 R.first->second.push_back(ID);
1553}
1554
1555const SCEV *
1557 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1558 "This is not an extending conversion!");
1559 assert(isSCEVable(Ty) &&
1560 "This is not a conversion to a SCEVable type!");
1561 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1562 Ty = getEffectiveSCEVType(Ty);
1563
1564 FoldID ID(scZeroExtend, Op, Ty);
1565 auto Iter = FoldCache.find(ID);
1566 if (Iter != FoldCache.end())
1567 return Iter->second;
1568
1569 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1570 if (!isa<SCEVZeroExtendExpr>(S))
1571 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1572 return S;
1573}
1574
1576 unsigned Depth) {
1577 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1578 "This is not an extending conversion!");
1579 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1580 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1581
1582 // Fold if the operand is constant.
1583 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1584 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1585
1586 // zext(zext(x)) --> zext(x)
1587 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1588 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1589
1590 // Before doing any expensive analysis, check to see if we've already
1591 // computed a SCEV for this Op and Ty.
1593 ID.AddInteger(scZeroExtend);
1594 ID.AddPointer(Op);
1595 ID.AddPointer(Ty);
1596 void *IP = nullptr;
1597 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1598 if (Depth > MaxCastDepth) {
1599 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1600 Op, Ty);
1601 UniqueSCEVs.InsertNode(S, IP);
1602 registerUser(S, Op);
1603 return S;
1604 }
1605
1606 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1607 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1608 // It's possible the bits taken off by the truncate were all zero bits. If
1609 // so, we should be able to simplify this further.
1610 const SCEV *X = ST->getOperand();
1612 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1613 unsigned NewBits = getTypeSizeInBits(Ty);
1614 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1615 CR.zextOrTrunc(NewBits)))
1616 return getTruncateOrZeroExtend(X, Ty, Depth);
1617 }
1618
1619 // If the input value is a chrec scev, and we can prove that the value
1620 // did not overflow the old, smaller, value, we can zero extend all of the
1621 // operands (often constants). This allows analysis of something like
1622 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1623 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1624 if (AR->isAffine()) {
1625 const SCEV *Start = AR->getStart();
1626 const SCEV *Step = AR->getStepRecurrence(*this);
1627 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1628 const Loop *L = AR->getLoop();
1629
1630 // If we have special knowledge that this addrec won't overflow,
1631 // we don't need to do any further analysis.
1632 if (AR->hasNoUnsignedWrap()) {
1633 Start =
1634 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1635 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1636 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1637 }
1638
1639 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1640 // Note that this serves two purposes: It filters out loops that are
1641 // simply not analyzable, and it covers the case where this code is
1642 // being called from within backedge-taken count analysis, such that
1643 // attempting to ask for the backedge-taken count would likely result
1644 // in infinite recursion. In the latter case, the analysis code will
1645 // cope with a conservative value, and it will take care to purge
1646 // that value once it has finished.
1647 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1648 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1649 // Manually compute the final value for AR, checking for overflow.
1650
1651 // Check whether the backedge-taken count can be losslessly cast to
1652 // the addrec's type. The count is always unsigned.
1653 const SCEV *CastedMaxBECount =
1654 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1655 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1656 CastedMaxBECount, MaxBECount->getType(), Depth);
1657 if (MaxBECount == RecastedMaxBECount) {
1658 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1659 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1660 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1662 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1664 Depth + 1),
1665 WideTy, Depth + 1);
1666 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1667 const SCEV *WideMaxBECount =
1668 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1669 const SCEV *OperandExtendedAdd =
1670 getAddExpr(WideStart,
1671 getMulExpr(WideMaxBECount,
1672 getZeroExtendExpr(Step, WideTy, Depth + 1),
1675 if (ZAdd == OperandExtendedAdd) {
1676 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1677 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1678 // Return the expression with the addrec on the outside.
1679 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1680 Depth + 1);
1681 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1682 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1683 }
1684 // Similar to above, only this time treat the step value as signed.
1685 // This covers loops that count down.
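// Illustrative case: {8,+,-1} that runs for at most 8 iterations only takes
// values in [0,8], so it can be extended by zero-extending the start and
// sign-extending the negative step in the wider type.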
1686 OperandExtendedAdd =
1687 getAddExpr(WideStart,
1688 getMulExpr(WideMaxBECount,
1689 getSignExtendExpr(Step, WideTy, Depth + 1),
1692 if (ZAdd == OperandExtendedAdd) {
1693 // Cache knowledge of AR NW, which is propagated to this AddRec.
1694 // Negative step causes unsigned wrap, but it still can't self-wrap.
1695 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1696 // Return the expression with the addrec on the outside.
1697 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1698 Depth + 1);
1699 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1700 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1701 }
1702 }
1703 }
1704
1705 // Normally, in the cases we can prove no-overflow via a
1706 // backedge guarding condition, we can also compute a backedge
1707 // taken count for the loop. The exceptions are assumptions and
1708 // guards present in the loop -- SCEV is not great at exploiting
1709 // these to compute max backedge taken counts, but can still use
1710 // these to prove lack of overflow. Use this fact to avoid
1711 // doing extra work that may not pay off.
1712 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1713 !AC.assumptions().empty()) {
1714
1715 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1716 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1717 if (AR->hasNoUnsignedWrap()) {
1718 // Same as nuw case above - duplicated here to avoid a compile time
1719 // issue. It's not clear that the order of checks matters, but
1720 // it's one of two possible causes for a change which was
1721 // reverted. Be conservative for the moment.
1722 Start =
1723 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1724 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1725 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1726 }
1727
1728 // For a negative step, we can extend the operands iff doing so only
1729 // traverses values in the range zext([0,UINT_MAX]).
1730 if (isKnownNegative(Step)) {
1731 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1732 getSignedRangeMin(Step));
1733 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1734 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1735 // Cache knowledge of AR NW, which is propagated to this
1736 // AddRec. Negative step causes unsigned wrap, but it
1737 // still can't self-wrap.
1738 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1739 // Return the expression with the addrec on the outside.
1740 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1741 Depth + 1);
1742 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1743 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1744 }
1745 }
1746 }
1747
1748 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1749 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1750 // where D maximizes the number of trailing zeros of (C - D + Step * n)
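// Worked instance (illustrative): for {5,+,4} the step guarantees two
// trailing zero bits, so D = 1 and zext({5,+,4}) becomes
// zext(1) + zext({4,+,4}); since the residual keeps those trailing zero
// bits and D < 4, adding D back is carry-free and the outer add cannot wrap.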
1751 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1752 const APInt &C = SC->getAPInt();
1753 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1754 if (D != 0) {
1755 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1756 const SCEV *SResidual =
1757 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1758 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1759 return getAddExpr(SZExtD, SZExtR,
1761 Depth + 1);
1762 }
1763 }
1764
1765 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1766 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1767 Start =
1768 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1769 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1770 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1771 }
1772 }
1773
1774 // zext(A % B) --> zext(A) % zext(B)
1775 {
1776 const SCEV *LHS;
1777 const SCEV *RHS;
1778 if (matchURem(Op, LHS, RHS))
1779 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1780 getZeroExtendExpr(RHS, Ty, Depth + 1));
1781 }
1782
1783 // zext(A / B) --> zext(A) / zext(B).
1784 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1785 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1786 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1787
1788 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1789 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1790 if (SA->hasNoUnsignedWrap()) {
1791 // If the addition does not unsigned-overflow then we can, by definition,
1792 // commute the zero extension with the addition operation.
1793 SmallVector<const SCEV *, 4> Ops;
1794 for (const auto *Op : SA->operands())
1795 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1796 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1797 }
1798
1799 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1800 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1801 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1802 //
1803 // Often address arithmetics contain expressions like
1804 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1805 // This transformation is useful while proving that such expressions are
1806 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
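// Completing that instance (illustrative): 4 * X contributes two trailing
// zero bits, so D = 5 & 3 = 1 and
// zext(5 + 4*X) --> zext(1) + zext(4 + 4*X) = 1 + zext(4 + 4*X).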
1807 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1808 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1809 if (D != 0) {
1810 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1811 const SCEV *SResidual =
1813 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1814 return getAddExpr(SZExtD, SZExtR,
1816 Depth + 1);
1817 }
1818 }
1819 }
1820
1821 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1822 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1823 if (SM->hasNoUnsignedWrap()) {
1824 // If the multiply does not unsigned-overflow then we can, by definition,
1825 // commute the zero extension with the multiply operation.
1826 SmallVector<const SCEV *, 4> Ops;
1827 for (const auto *Op : SM->operands())
1828 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1829 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1830 }
1831
1832 // zext(2^K * (trunc X to iN)) to iM ->
1833 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1834 //
1835 // Proof:
1836 //
1837 // zext(2^K * (trunc X to iN)) to iM
1838 // = zext((trunc X to iN) << K) to iM
1839 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1840 // (because shl removes the top K bits)
1841 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1842 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1843 //
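// Concrete instance (illustrative), with K = 3, N = 16, M = 32:
// zext(8 * (trunc X to i16)) to i32
// = 8 * (zext(trunc X to i13) to i32)<nuw>.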
1844 if (SM->getNumOperands() == 2)
1845 if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1846 if (MulLHS->getAPInt().isPowerOf2())
1847 if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1848 int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1849 MulLHS->getAPInt().logBase2();
1850 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1851 return getMulExpr(
1852 getZeroExtendExpr(MulLHS, Ty),
1854 getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1855 SCEV::FlagNUW, Depth + 1);
1856 }
1857 }
1858
1859 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1860 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1861 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1862 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1863 SmallVector<const SCEV *, 4> Operands;
1864 for (auto *Operand : MinMax->operands())
1865 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1866 if (isa<SCEVUMinExpr>(MinMax))
1867 return getUMinExpr(Operands);
1868 return getUMaxExpr(Operands);
1869 }
1870
1871 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1872 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1873 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1874 SmallVector<const SCEV *, 4> Operands;
1875 for (auto *Operand : MinMax->operands())
1876 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1877 return getUMinExpr(Operands, /*Sequential*/ true);
1878 }
1879
1880 // The cast wasn't folded; create an explicit cast node.
1881 // Recompute the insert position, as it may have been invalidated.
1882 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1883 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1884 Op, Ty);
1885 UniqueSCEVs.InsertNode(S, IP);
1886 registerUser(S, Op);
1887 return S;
1888}
1889
1890const SCEV *
1891ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1892 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1893 "This is not an extending conversion!");
1894 assert(isSCEVable(Ty) &&
1895 "This is not a conversion to a SCEVable type!");
1896 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1897 Ty = getEffectiveSCEVType(Ty);
1898
1899 FoldID ID(scSignExtend, Op, Ty);
1900 auto Iter = FoldCache.find(ID);
1901 if (Iter != FoldCache.end())
1902 return Iter->second;
1903
1904 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1905 if (!isa<SCEVSignExtendExpr>(S))
1906 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1907 return S;
1908}
1909
1910const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1911 unsigned Depth) {
1912 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1913 "This is not an extending conversion!");
1914 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1915 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1916 Ty = getEffectiveSCEVType(Ty);
1917
1918 // Fold if the operand is constant.
1919 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1920 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1921
1922 // sext(sext(x)) --> sext(x)
1923 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1924 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1925
1926 // sext(zext(x)) --> zext(x)
1927 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1928 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1929
1930 // Before doing any expensive analysis, check to see if we've already
1931 // computed a SCEV for this Op and Ty.
1932 FoldingSetNodeID ID;
1933 ID.AddInteger(scSignExtend);
1934 ID.AddPointer(Op);
1935 ID.AddPointer(Ty);
1936 void *IP = nullptr;
1937 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1938 // Limit recursion depth.
1939 if (Depth > MaxCastDepth) {
1940 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1941 Op, Ty);
1942 UniqueSCEVs.InsertNode(S, IP);
1943 registerUser(S, Op);
1944 return S;
1945 }
1946
1947 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1948 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1949 // It's possible the bits taken off by the truncate were all sign bits. If
1950 // so, we should be able to simplify this further.
1951 const SCEV *X = ST->getOperand();
1952 ConstantRange CR = getSignedRange(X);
1953 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1954 unsigned NewBits = getTypeSizeInBits(Ty);
1955 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1956 CR.sextOrTrunc(NewBits)))
1957 return getTruncateOrSignExtend(X, Ty, Depth);
1958 }
1959
1960 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1961 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1962 if (SA->hasNoSignedWrap()) {
1963 // If the addition does not sign overflow then we can, by definition,
1964 // commute the sign extension with the addition operation.
1965 SmallVector<const SCEV *, 4> Ops;
1966 for (const auto *Op : SA->operands())
1967 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1968 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1969 }
1970
1971 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1972 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1973 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1974 //
1975 // For instance, this will bring two seemingly different expressions:
1976 // 1 + sext(5 + 20 * %x + 24 * %y) and
1977 // sext(6 + 20 * %x + 24 * %y)
1978 // to the same form:
1979 // 2 + sext(4 + 20 * %x + 24 * %y)
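// Here D is the part of the constant that the other operands' alignment can
// absorb (illustrative reading): 20 * %x and 24 * %y both have at least two
// trailing zero bits, so D = 6 & 3 = 2 and the residual constant is 4.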
1980 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1981 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1982 if (D != 0) {
1983 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1984 const SCEV *SResidual =
1986 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1987 return getAddExpr(SSExtD, SSExtR,
1989 Depth + 1);
1990 }
1991 }
1992 }
1993 // If the input value is a chrec scev, and we can prove that the value
1994 // did not overflow the old, smaller, value, we can sign extend all of the
1995 // operands (often constants). This allows analysis of something like
1996 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1997 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1998 if (AR->isAffine()) {
1999 const SCEV *Start = AR->getStart();
2000 const SCEV *Step = AR->getStepRecurrence(*this);
2001 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2002 const Loop *L = AR->getLoop();
2003
2004 // If we have special knowledge that this addrec won't overflow,
2005 // we don't need to do any further analysis.
2006 if (AR->hasNoSignedWrap()) {
2007 Start =
2008 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2009 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2010 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2011 }
2012
2013 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2014 // Note that this serves two purposes: It filters out loops that are
2015 // simply not analyzable, and it covers the case where this code is
2016 // being called from within backedge-taken count analysis, such that
2017 // attempting to ask for the backedge-taken count would likely result
2018 // in infinite recursion. In the latter case, the analysis code will
2019 // cope with a conservative value, and it will take care to purge
2020 // that value once it has finished.
2021 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2022 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2023 // Manually compute the final value for AR, checking for
2024 // overflow.
2025
2026 // Check whether the backedge-taken count can be losslessly cast to
2027 // the addrec's type. The count is always unsigned.
2028 const SCEV *CastedMaxBECount =
2029 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2030 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2031 CastedMaxBECount, MaxBECount->getType(), Depth);
2032 if (MaxBECount == RecastedMaxBECount) {
2033 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2034 // Check whether Start+Step*MaxBECount has no signed overflow.
2035 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2037 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2039 Depth + 1),
2040 WideTy, Depth + 1);
2041 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2042 const SCEV *WideMaxBECount =
2043 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2044 const SCEV *OperandExtendedAdd =
2045 getAddExpr(WideStart,
2046 getMulExpr(WideMaxBECount,
2047 getSignExtendExpr(Step, WideTy, Depth + 1),
2050 if (SAdd == OperandExtendedAdd) {
2051 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2052 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2053 // Return the expression with the addrec on the outside.
2054 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2055 Depth + 1);
2056 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2057 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2058 }
2059 // Similar to above, only this time treat the step value as unsigned.
2060 // This covers loops that count up with an unsigned step.
2061 OperandExtendedAdd =
2062 getAddExpr(WideStart,
2063 getMulExpr(WideMaxBECount,
2064 getZeroExtendExpr(Step, WideTy, Depth + 1),
2067 if (SAdd == OperandExtendedAdd) {
2068 // If AR wraps around then
2069 //
2070 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2071 // => SAdd != OperandExtendedAdd
2072 //
2073 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2074 // (SAdd == OperandExtendedAdd => AR is NW)
2075
2076 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2077
2078 // Return the expression with the addrec on the outside.
2079 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2080 Depth + 1);
2081 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2082 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2083 }
2084 }
2085 }
2086
2087 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2088 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2089 if (AR->hasNoSignedWrap()) {
2090 // Same as nsw case above - duplicated here to avoid a compile time
2091 // issue. It's not clear that the order of checks matters, but
2092 // it's one of two possible causes for a change which was
2093 // reverted. Be conservative for the moment.
2094 Start =
2095 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2096 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2097 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2098 }
2099
2100 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2101 // if D + (C - D + Step * n) could be proven to not signed wrap
2102 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2103 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2104 const APInt &C = SC->getAPInt();
2105 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2106 if (D != 0) {
2107 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2108 const SCEV *SResidual =
2109 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2110 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2111 return getAddExpr(SSExtD, SSExtR,
2113 Depth + 1);
2114 }
2115 }
2116
2117 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2118 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2119 Start =
2120 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2121 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2122 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2123 }
2124 }
2125
2126 // If the input value is provably positive and we could not simplify
2127 // away the sext build a zext instead.
2128 if (isKnownNonNegative(Op))
2129 return getZeroExtendExpr(Op, Ty, Depth + 1);
2130
2131 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2132 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2133 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2134 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2135 SmallVector<const SCEV *, 4> Operands;
2136 for (auto *Operand : MinMax->operands())
2137 Operands.push_back(getSignExtendExpr(Operand, Ty));
2138 if (isa<SCEVSMinExpr>(MinMax))
2139 return getSMinExpr(Operands);
2140 return getSMaxExpr(Operands);
2141 }
2142
2143 // The cast wasn't folded; create an explicit cast node.
2144 // Recompute the insert position, as it may have been invalidated.
2145 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2146 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2147 Op, Ty);
2148 UniqueSCEVs.InsertNode(S, IP);
2149 registerUser(S, { Op });
2150 return S;
2151}
2152
2153const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2154 Type *Ty) {
2155 switch (Kind) {
2156 case scTruncate:
2157 return getTruncateExpr(Op, Ty);
2158 case scZeroExtend:
2159 return getZeroExtendExpr(Op, Ty);
2160 case scSignExtend:
2161 return getSignExtendExpr(Op, Ty);
2162 case scPtrToInt:
2163 return getPtrToIntExpr(Op, Ty);
2164 default:
2165 llvm_unreachable("Not a SCEV cast expression!");
2166 }
2167}
2168
2169/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2170/// unspecified bits out to the given type.
2171const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2172 Type *Ty) {
2173 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2174 "This is not an extending conversion!");
2175 assert(isSCEVable(Ty) &&
2176 "This is not a conversion to a SCEVable type!");
2177 Ty = getEffectiveSCEVType(Ty);
2178
2179 // Sign-extend negative constants.
2180 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2181 if (SC->getAPInt().isNegative())
2182 return getSignExtendExpr(Op, Ty);
2183
2184 // Peel off a truncate cast.
2185 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2186 const SCEV *NewOp = T->getOperand();
2187 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2188 return getAnyExtendExpr(NewOp, Ty);
2189 return getTruncateOrNoop(NewOp, Ty);
2190 }
2191
2192 // Next try a zext cast. If the cast is folded, use it.
2193 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2194 if (!isa<SCEVZeroExtendExpr>(ZExt))
2195 return ZExt;
2196
2197 // Next try a sext cast. If the cast is folded, use it.
2198 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2199 if (!isa<SCEVSignExtendExpr>(SExt))
2200 return SExt;
2201
2202 // Force the cast to be folded into the operands of an addrec.
2203 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2204 SmallVector<const SCEV *, 4> Ops;
2205 for (const SCEV *Op : AR->operands())
2206 Ops.push_back(getAnyExtendExpr(Op, Ty));
2207 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2208 }
2209
2210 // If the expression is obviously signed, use the sext cast value.
2211 if (isa<SCEVSMaxExpr>(Op))
2212 return SExt;
2213
2214 // Absent any other information, use the zext cast value.
2215 return ZExt;
2216}
2217
2218/// Process the given Ops list, which is a list of operands to be added under
2219 /// the given scale, and update the given map. This is a helper function for
2220 /// getAddExpr. As an example of what it does, given a sequence of operands
2221/// that would form an add expression like this:
2222///
2223/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2224///
2225/// where A and B are constants, update the map with these values:
2226///
2227/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2228///
2229/// and add 13 + A*B*29 to AccumulatedConstant.
2230 /// This will allow getAddExpr to produce this:
2231///
2232/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2233///
2234/// This form often exposes folding opportunities that are hidden in
2235/// the original operand list.
2236///
2237/// Return true iff it appears that any interesting folding opportunities
2238 /// may be exposed. This helps getAddExpr short-circuit extra work in
2239/// the common case where no interesting opportunities are present, and
2240/// is also used as a check to avoid infinite recursion.
2241static bool
2242CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2243 SmallVectorImpl<const SCEV *> &NewOps,
2244 APInt &AccumulatedConstant,
2245 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2246 ScalarEvolution &SE) {
2247 bool Interesting = false;
2248
2249 // Iterate over the add operands. They are sorted, with constants first.
2250 unsigned i = 0;
2251 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2252 ++i;
2253 // Pull a buried constant out to the outside.
2254 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2255 Interesting = true;
2256 AccumulatedConstant += Scale * C->getAPInt();
2257 }
2258
2259 // Next comes everything else. We're especially interested in multiplies
2260 // here, but they're in the middle, so just visit the rest with one loop.
2261 for (; i != Ops.size(); ++i) {
2262 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2263 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2264 APInt NewScale =
2265 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2266 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2267 // A multiplication of a constant with another add; recurse.
2268 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2269 Interesting |=
2270 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2271 Add->operands(), NewScale, SE);
2272 } else {
2273 // A multiplication of a constant with some other value. Update
2274 // the map.
2275 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2276 const SCEV *Key = SE.getMulExpr(MulOps);
2277 auto Pair = M.insert({Key, NewScale});
2278 if (Pair.second) {
2279 NewOps.push_back(Pair.first->first);
2280 } else {
2281 Pair.first->second += NewScale;
2282 // The map already had an entry for this value, which may indicate
2283 // a folding opportunity.
2284 Interesting = true;
2285 }
2286 }
2287 } else {
2288 // An ordinary operand. Update the map.
2289 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2290 M.insert({Ops[i], Scale});
2291 if (Pair.second) {
2292 NewOps.push_back(Pair.first->first);
2293 } else {
2294 Pair.first->second += Scale;
2295 // The map already had an entry for this value, which may indicate
2296 // a folding opportunity.
2297 Interesting = true;
2298 }
2299 }
2300 }
2301
2302 return Interesting;
2303}
2304
2305bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2306 const SCEV *LHS, const SCEV *RHS,
2307 const Instruction *CtxI) {
2308 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2309 SCEV::NoWrapFlags, unsigned);
2310 switch (BinOp) {
2311 default:
2312 llvm_unreachable("Unsupported binary op");
2313 case Instruction::Add:
2314 Operation = &ScalarEvolution::getAddExpr;
2315 break;
2316 case Instruction::Sub:
2317 Operation = &ScalarEvolution::getMinusSCEV;
2318 break;
2319 case Instruction::Mul:
2320 Operation = &ScalarEvolution::getMulExpr;
2321 break;
2322 }
2323
2324 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2325 Signed ? &ScalarEvolution::getSignExtendExpr
2326 : &ScalarEvolution::getZeroExtendExpr;
2327
2328 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2329 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2330 auto *WideTy =
2331 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2332
2333 const SCEV *A = (this->*Extension)(
2334 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2335 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2336 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2337 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2338 if (A == B)
2339 return true;
2340 // Can we use context to prove the fact we need?
2341 if (!CtxI)
2342 return false;
2343 // TODO: Support mul.
2344 if (BinOp == Instruction::Mul)
2345 return false;
2346 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2347 // TODO: Lift this limitation.
2348 if (!RHSC)
2349 return false;
2350 APInt C = RHSC->getAPInt();
2351 unsigned NumBits = C.getBitWidth();
2352 bool IsSub = (BinOp == Instruction::Sub);
2353 bool IsNegativeConst = (Signed && C.isNegative());
2354 // Compute the direction and magnitude by which we need to check overflow.
2355 bool OverflowDown = IsSub ^ IsNegativeConst;
2356 APInt Magnitude = C;
2357 if (IsNegativeConst) {
2358 if (C == APInt::getSignedMinValue(NumBits))
2359 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2360 // want to deal with that.
2361 return false;
2362 Magnitude = -C;
2363 }
2364
2365 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
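// For example (illustrative): for an unsigned i8 add of LHS and 5,
// OverflowDown is false and the check below reduces to proving LHS <= 255 - 5;
// for an unsigned i8 sub of 5 it reduces to proving 0 + 5 <= LHS.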
2366 if (OverflowDown) {
2367 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2368 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2369 : APInt::getMinValue(NumBits);
2370 APInt Limit = Min + Magnitude;
2371 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2372 } else {
2373 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2374 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2375 : APInt::getMaxValue(NumBits);
2376 APInt Limit = Max - Magnitude;
2377 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2378 }
2379}
2380
2381std::optional<SCEV::NoWrapFlags>
2382ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2383 const OverflowingBinaryOperator *OBO) {
2384 // It cannot be done any better.
2385 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2386 return std::nullopt;
2387
2388 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2389
2390 if (OBO->hasNoUnsignedWrap())
2391 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2392 if (OBO->hasNoSignedWrap())
2393 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2394
2395 bool Deduced = false;
2396
2397 if (OBO->getOpcode() != Instruction::Add &&
2398 OBO->getOpcode() != Instruction::Sub &&
2399 OBO->getOpcode() != Instruction::Mul)
2400 return std::nullopt;
2401
2402 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2403 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2404
2405 const Instruction *CtxI =
2406 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2407 if (!OBO->hasNoUnsignedWrap() &&
2408 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2409 /* Signed */ false, LHS, RHS, CtxI)) {
2410 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2411 Deduced = true;
2412 }
2413
2414 if (!OBO->hasNoSignedWrap() &&
2415 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2416 /* Signed */ true, LHS, RHS, CtxI)) {
2417 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2418 Deduced = true;
2419 }
2420
2421 if (Deduced)
2422 return Flags;
2423 return std::nullopt;
2424}
2425
2426// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2427// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2428// can't-overflow flags for the operation if possible.
2429static SCEV::NoWrapFlags
2430StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2431 const ArrayRef<const SCEV *> Ops,
2432 SCEV::NoWrapFlags Flags) {
2433 using namespace std::placeholders;
2434
2435 using OBO = OverflowingBinaryOperator;
2436
2437 bool CanAnalyze =
2438 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2439 (void)CanAnalyze;
2440 assert(CanAnalyze && "don't call from other places!");
2441
2442 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2443 SCEV::NoWrapFlags SignOrUnsignWrap =
2444 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2445
2446 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2447 auto IsKnownNonNegative = [&](const SCEV *S) {
2448 return SE->isKnownNonNegative(S);
2449 };
2450
2451 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2452 Flags =
2453 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
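// Sketch of why this is sound: nsw together with all-non-negative operands
// keeps every sum within [0, SINT_MAX], and a sum that stays in that range
// cannot wrap in the unsigned sense either, so nuw may be added as well.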
2454
2455 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2456
2457 if (SignOrUnsignWrap != SignOrUnsignMask &&
2458 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2459 isa<SCEVConstant>(Ops[0])) {
2460
2461 auto Opcode = [&] {
2462 switch (Type) {
2463 case scAddExpr:
2464 return Instruction::Add;
2465 case scMulExpr:
2466 return Instruction::Mul;
2467 default:
2468 llvm_unreachable("Unexpected SCEV op.");
2469 }
2470 }();
2471
2472 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2473
2474 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2475 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2476 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2477 Opcode, C, OBO::NoSignedWrap);
2478 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2479 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2480 }
2481
2482 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2483 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2484 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2485 Opcode, C, OBO::NoUnsignedWrap);
2486 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2487 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2488 }
2489 }
2490
2491 // <0,+,nonnegative><nw> is also nuw
2492 // TODO: Add corresponding nsw case
2493 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2494 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2495 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2496 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2497
2498 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2499 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2500 Ops.size() == 2) {
2501 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2502 if (UDiv->getOperand(1) == Ops[1])
2503 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2504 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2505 if (UDiv->getOperand(1) == Ops[0])
2506 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2507 }
2508
2509 return Flags;
2510}
2511
2512bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2513 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2514}
2515
2516/// Get a canonical add expression, or something simpler if possible.
2517const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2518 SCEV::NoWrapFlags OrigFlags,
2519 unsigned Depth) {
2520 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2521 "only nuw or nsw allowed");
2522 assert(!Ops.empty() && "Cannot get empty add!");
2523 if (Ops.size() == 1) return Ops[0];
2524#ifndef NDEBUG
2525 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2526 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2527 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2528 "SCEVAddExpr operand types don't match!");
2529 unsigned NumPtrs = count_if(
2530 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2531 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2532#endif
2533
2534 // Sort by complexity, this groups all similar expression types together.
2535 GroupByComplexity(Ops, &LI, DT);
2536
2537 // If there are any constants, fold them together.
2538 unsigned Idx = 0;
2539 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2540 ++Idx;
2541 assert(Idx < Ops.size());
2542 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2543 // We found two constants, fold them together!
2544 Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2545 if (Ops.size() == 2) return Ops[0];
2546 Ops.erase(Ops.begin()+1); // Erase the folded element
2547 LHSC = cast<SCEVConstant>(Ops[0]);
2548 }
2549
2550 // If we are left with a constant zero being added, strip it off.
2551 if (LHSC->getValue()->isZero()) {
2552 Ops.erase(Ops.begin());
2553 --Idx;
2554 }
2555
2556 if (Ops.size() == 1) return Ops[0];
2557 }
2558
2559 // Delay expensive flag strengthening until necessary.
2560 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2561 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2562 };
2563
2564 // Limit recursion calls depth.
2565 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2566 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2567
2568 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2569 // Don't strengthen flags if we have no new information.
2570 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2571 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2572 Add->setNoWrapFlags(ComputeFlags(Ops));
2573 return S;
2574 }
2575
2576 // Okay, check to see if the same value occurs in the operand list more than
2577 // once. If so, merge them together into a multiply expression. Since we
2578 // sorted the list, these values are required to be adjacent.
2579 Type *Ty = Ops[0]->getType();
2580 bool FoundMatch = false;
2581 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2582 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2583 // Scan ahead to count how many equal operands there are.
2584 unsigned Count = 2;
2585 while (i+Count != e && Ops[i+Count] == Ops[i])
2586 ++Count;
2587 // Merge the values into a multiply.
2588 const SCEV *Scale = getConstant(Ty, Count);
2589 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2590 if (Ops.size() == Count)
2591 return Mul;
2592 Ops[i] = Mul;
2593 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2594 --i; e -= Count - 1;
2595 FoundMatch = true;
2596 }
2597 if (FoundMatch)
2598 return getAddExpr(Ops, OrigFlags, Depth + 1);
2599
2600 // Check for truncates. If all the operands are truncated from the same
2601 // type, see if factoring out the truncate would permit the result to be
2602 // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2603 // if the contents of the resulting outer trunc fold to something simple.
2604 auto FindTruncSrcType = [&]() -> Type * {
2605 // We're ultimately looking to fold an addrec of truncs and muls of only
2606 // constants and truncs, so if we find any other types of SCEV
2607 // as operands of the addrec then we bail and return nullptr here.
2608 // Otherwise, we return the type of the operand of a trunc that we find.
2609 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2610 return T->getOperand()->getType();
2611 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2612 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2613 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2614 return T->getOperand()->getType();
2615 }
2616 return nullptr;
2617 };
2618 if (auto *SrcType = FindTruncSrcType()) {
2619 SmallVector<const SCEV *, 8> LargeOps;
2620 bool Ok = true;
2621 // Check all the operands to see if they can be represented in the
2622 // source type of the truncate.
2623 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2624 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2625 if (T->getOperand()->getType() != SrcType) {
2626 Ok = false;
2627 break;
2628 }
2629 LargeOps.push_back(T->getOperand());
2630 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2631 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2632 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2633 SmallVector<const SCEV *, 8> LargeMulOps;
2634 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2635 if (const SCEVTruncateExpr *T =
2636 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2637 if (T->getOperand()->getType() != SrcType) {
2638 Ok = false;
2639 break;
2640 }
2641 LargeMulOps.push_back(T->getOperand());
2642 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2643 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2644 } else {
2645 Ok = false;
2646 break;
2647 }
2648 }
2649 if (Ok)
2650 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2651 } else {
2652 Ok = false;
2653 break;
2654 }
2655 }
2656 if (Ok) {
2657 // Evaluate the expression in the larger type.
2658 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2659 // If it folds to something simple, use it. Otherwise, don't.
2660 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2661 return getTruncateExpr(Fold, Ty);
2662 }
2663 }
2664
2665 if (Ops.size() == 2) {
2666 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2667 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2668 // C1).
2669 const SCEV *A = Ops[0];
2670 const SCEV *B = Ops[1];
2671 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2672 auto *C = dyn_cast<SCEVConstant>(A);
2673 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2674 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2675 auto C2 = C->getAPInt();
2676 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2677
2678 APInt ConstAdd = C1 + C2;
2679 auto AddFlags = AddExpr->getNoWrapFlags();
2680 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2681 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2682 ConstAdd.ule(C1)) {
2683 PreservedFlags =
2684 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2685 }
2686
2687 // Adding a constant with the same sign and small magnitude is NSW, if the
2688 // original AddExpr was NSW.
2689 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2690 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2691 ConstAdd.abs().ule(C1.abs())) {
2692 PreservedFlags =
2693 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2694 }
2695
2696 if (PreservedFlags != SCEV::FlagAnyWrap) {
2697 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2698 NewOps[0] = getConstant(ConstAdd);
2699 return getAddExpr(NewOps, PreservedFlags);
2700 }
2701 }
2702 }
2703
2704 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
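// For instance (illustrative): with X = 10, Y = 3,
// (-1 * (10 urem 3)) + 10 = 10 - 1 = 9 = 3 * (10 /u 3).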
2705 if (Ops.size() == 2) {
2706 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2707 if (Mul && Mul->getNumOperands() == 2 &&
2708 Mul->getOperand(0)->isAllOnesValue()) {
2709 const SCEV *X;
2710 const SCEV *Y;
2711 if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2712 return getMulExpr(Y, getUDivExpr(X, Y));
2713 }
2714 }
2715 }
2716
2717 // Skip past any other cast SCEVs.
2718 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2719 ++Idx;
2720
2721 // If there are add operands they would be next.
2722 if (Idx < Ops.size()) {
2723 bool DeletedAdd = false;
2724 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2725 // common NUW flag for expression after inlining. Other flags cannot be
2726 // preserved, because they may depend on the original order of operations.
2727 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2728 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2729 if (Ops.size() > AddOpsInlineThreshold ||
2730 Add->getNumOperands() > AddOpsInlineThreshold)
2731 break;
2732 // If we have an add, expand the add operands onto the end of the operands
2733 // list.
2734 Ops.erase(Ops.begin()+Idx);
2735 append_range(Ops, Add->operands());
2736 DeletedAdd = true;
2737 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2738 }
2739
2740 // If we deleted at least one add, we added operands to the end of the list,
2741 // and they are not necessarily sorted. Recurse to resort and resimplify
2742 // any operands we just acquired.
2743 if (DeletedAdd)
2744 return getAddExpr(Ops, CommonFlags, Depth + 1);
2745 }
2746
2747 // Skip over the add expression until we get to a multiply.
2748 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2749 ++Idx;
2750
2751 // Check to see if there are any folding opportunities present with
2752 // operands multiplied by constant values.
2753 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2754 uint64_t BitWidth = getTypeSizeInBits(Ty);
2755 DenseMap<const SCEV *, APInt> M;
2756 SmallVector<const SCEV *, 8> NewOps;
2757 APInt AccumulatedConstant(BitWidth, 0);
2758 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2759 Ops, APInt(BitWidth, 1), *this)) {
2760 struct APIntCompare {
2761 bool operator()(const APInt &LHS, const APInt &RHS) const {
2762 return LHS.ult(RHS);
2763 }
2764 };
2765
2766 // Some interesting folding opportunity is present, so it's worthwhile to
2767 // re-generate the operands list. Group the operands by constant scale,
2768 // to avoid multiplying by the same constant scale multiple times.
2769 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2770 for (const SCEV *NewOp : NewOps)
2771 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2772 // Re-generate the operands list.
2773 Ops.clear();
2774 if (AccumulatedConstant != 0)
2775 Ops.push_back(getConstant(AccumulatedConstant));
2776 for (auto &MulOp : MulOpLists) {
2777 if (MulOp.first == 1) {
2778 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2779 } else if (MulOp.first != 0) {
2780 Ops.push_back(getMulExpr(
2781 getConstant(MulOp.first),
2782 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2783 SCEV::FlagAnyWrap, Depth + 1));
2784 }
2785 }
2786 if (Ops.empty())
2787 return getZero(Ty);
2788 if (Ops.size() == 1)
2789 return Ops[0];
2790 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2791 }
2792 }
2793
2794 // If we are adding something to a multiply expression, make sure the
2795 // something is not already an operand of the multiply. If so, merge it into
2796 // the multiply.
2797 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2798 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2799 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2800 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2801 if (isa<SCEVConstant>(MulOpSCEV))
2802 continue;
2803 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2804 if (MulOpSCEV == Ops[AddOp]) {
2805 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2806 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2807 if (Mul->getNumOperands() != 2) {
2808 // If the multiply has more than two operands, we must get the
2809 // Y*Z term.
2810 SmallVector<const SCEV *, 4> MulOps(
2811 Mul->operands().take_front(MulOp));
2812 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2813 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2814 }
2815 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2816 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2817 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2819 if (Ops.size() == 2) return OuterMul;
2820 if (AddOp < Idx) {
2821 Ops.erase(Ops.begin()+AddOp);
2822 Ops.erase(Ops.begin()+Idx-1);
2823 } else {
2824 Ops.erase(Ops.begin()+Idx);
2825 Ops.erase(Ops.begin()+AddOp-1);
2826 }
2827 Ops.push_back(OuterMul);
2828 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2829 }
2830
2831 // Check this multiply against other multiplies being added together.
2832 for (unsigned OtherMulIdx = Idx+1;
2833 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2834 ++OtherMulIdx) {
2835 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2836 // If MulOp occurs in OtherMul, we can fold the two multiplies
2837 // together.
2838 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2839 OMulOp != e; ++OMulOp)
2840 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2841 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2842 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2843 if (Mul->getNumOperands() != 2) {
2844 SmallVector<const SCEV *, 4> MulOps(
2845 Mul->operands().take_front(MulOp));
2846 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2847 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2848 }
2849 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2850 if (OtherMul->getNumOperands() != 2) {
2851 SmallVector<const SCEV *, 4> MulOps(
2852 OtherMul->operands().take_front(OMulOp));
2853 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2854 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2855 }
2856 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2857 const SCEV *InnerMulSum =
2858 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2859 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2861 if (Ops.size() == 2) return OuterMul;
2862 Ops.erase(Ops.begin()+Idx);
2863 Ops.erase(Ops.begin()+OtherMulIdx-1);
2864 Ops.push_back(OuterMul);
2865 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2866 }
2867 }
2868 }
2869 }
2870
2871 // If there are any add recurrences in the operands list, see if any other
2872 // added values are loop invariant. If so, we can fold them into the
2873 // recurrence.
2874 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2875 ++Idx;
2876
2877 // Scan over all recurrences, trying to fold loop invariants into them.
2878 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2879 // Scan all of the other operands to this add and add them to the vector if
2880 // they are loop invariant w.r.t. the recurrence.
2881 SmallVector<const SCEV *, 8> LIOps;
2882 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2883 const Loop *AddRecLoop = AddRec->getLoop();
2884 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2885 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2886 LIOps.push_back(Ops[i]);
2887 Ops.erase(Ops.begin()+i);
2888 --i; --e;
2889 }
2890
2891 // If we found some loop invariants, fold them into the recurrence.
2892 if (!LIOps.empty()) {
2893 // Compute nowrap flags for the addition of the loop-invariant ops and
2894 // the addrec. Temporarily push it as an operand for that purpose. These
2895 // flags are valid in the scope of the addrec only.
2896 LIOps.push_back(AddRec);
2897 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2898 LIOps.pop_back();
2899
2900 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2901 LIOps.push_back(AddRec->getStart());
2902
2903 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2904
2905 // It is not in general safe to propagate flags valid on an add within
2906 // the addrec scope to one outside it. We must prove that the inner
2907 // scope is guaranteed to execute if the outer one does to be able to
2908 // safely propagate. We know the program is undefined if poison is
2909 // produced on the inner scoped addrec. We also know that *for this use*
2910 // the outer scoped add can't overflow (because of the flags we just
2911 // computed for the inner scoped add) without the program being undefined.
2912 // Proving that entry to the outer scope necessitates entry to the inner
2913 // scope, thus proves the program undefined if the flags would be violated
2914 // in the outer scope.
2915 SCEV::NoWrapFlags AddFlags = Flags;
2916 if (AddFlags != SCEV::FlagAnyWrap) {
2917 auto *DefI = getDefiningScopeBound(LIOps);
2918 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2919 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2920 AddFlags = SCEV::FlagAnyWrap;
2921 }
2922 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2923
2924 // Build the new addrec. Propagate the NUW and NSW flags if both the
2925 // outer add and the inner addrec are guaranteed to have no overflow.
2926 // Always propagate NW.
2927 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2928 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2929
2930 // If all of the other operands were loop invariant, we are done.
2931 if (Ops.size() == 1) return NewRec;
2932
2933 // Otherwise, add the folded AddRec by the non-invariant parts.
2934 for (unsigned i = 0;; ++i)
2935 if (Ops[i] == AddRec) {
2936 Ops[i] = NewRec;
2937 break;
2938 }
2939 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2940 }
2941
2942 // Okay, if there weren't any loop invariants to be folded, check to see if
2943 // there are multiple AddRec's with the same loop induction variable being
2944 // added together. If so, we can fold them.
2945 for (unsigned OtherIdx = Idx+1;
2946 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2947 ++OtherIdx) {
2948 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2949 // so that the 1st found AddRecExpr is dominated by all others.
2950 assert(DT.dominates(
2951 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2952 AddRec->getLoop()->getHeader()) &&
2953 "AddRecExprs are not sorted in reverse dominance order?");
2954 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2955 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2956 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2957 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2958 ++OtherIdx) {
2959 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2960 if (OtherAddRec->getLoop() == AddRecLoop) {
2961 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2962 i != e; ++i) {
2963 if (i >= AddRecOps.size()) {
2964 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2965 break;
2966 }
2967 SmallVector<const SCEV *, 2> TwoOps = {
2968 AddRecOps[i], OtherAddRec->getOperand(i)};
2969 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2970 }
2971 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2972 }
2973 }
2974 // Step size has changed, so we cannot guarantee no self-wraparound.
2975 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2976 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2977 }
2978 }
2979
2980 // Otherwise couldn't fold anything into this recurrence. Move onto the
2981 // next one.
2982 }
2983
2984 // Okay, it looks like we really DO need an add expr. Check to see if we
2985 // already have one, otherwise create a new one.
2986 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2987}
2988
2989const SCEV *
2990ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2991 SCEV::NoWrapFlags Flags) {
2992 FoldingSetNodeID ID;
2993 ID.AddInteger(scAddExpr);
2994 for (const SCEV *Op : Ops)
2995 ID.AddPointer(Op);
2996 void *IP = nullptr;
2997 SCEVAddExpr *S =
2998 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2999 if (!S) {
3000 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3001 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3002 S = new (SCEVAllocator)
3003 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3004 UniqueSCEVs.InsertNode(S, IP);
3005 registerUser(S, Ops);
3006 }
3007 S->setNoWrapFlags(Flags);
3008 return S;
3009}
3010
3011const SCEV *
3012ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3013 const Loop *L, SCEV::NoWrapFlags Flags) {
3014 FoldingSetNodeID ID;
3015 ID.AddInteger(scAddRecExpr);
3016 for (const SCEV *Op : Ops)
3017 ID.AddPointer(Op);
3018 ID.AddPointer(L);
3019 void *IP = nullptr;
3020 SCEVAddRecExpr *S =
3021 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3022 if (!S) {
3023 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3024 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3025 S = new (SCEVAllocator)
3026 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3027 UniqueSCEVs.InsertNode(S, IP);
3028 LoopUsers[L].push_back(S);
3029 registerUser(S, Ops);
3030 }
3031 setNoWrapFlags(S, Flags);
3032 return S;
3033}
3034
3035const SCEV *
3036ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3037 SCEV::NoWrapFlags Flags) {
3038 FoldingSetNodeID ID;
3039 ID.AddInteger(scMulExpr);
3040 for (const SCEV *Op : Ops)
3041 ID.AddPointer(Op);
3042 void *IP = nullptr;
3043 SCEVMulExpr *S =
3044 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3045 if (!S) {
3046 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3047 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3048 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3049 O, Ops.size());
3050 UniqueSCEVs.InsertNode(S, IP);
3051 registerUser(S, Ops);
3052 }
3053 S->setNoWrapFlags(Flags);
3054 return S;
3055}
3056
3057static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3058 uint64_t k = i*j;
3059 if (j > 1 && k / j != i) Overflow = true;
3060 return k;
3061}
3062
3063/// Compute the result of "n choose k", the binomial coefficient. If an
3064/// intermediate computation overflows, Overflow will be set and the return will
3065/// be garbage. Overflow is not cleared on absence of overflow.
3066static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3067 // We use the multiplicative formula:
3068 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3069 // At each iteration, we take the n-th term of the numerator and divide by the
3070 // (k-n)th term of the denominator. This division will always produce an
3071 // integral result, and helps reduce the chance of overflow in the
3072 // intermediate computations. However, we can still overflow even when the
3073 // final result would fit.
3074
3075 if (n == 0 || n == k) return 1;
3076 if (k > n) return 0;
3077
3078 if (k > n/2)
3079 k = n-k;
3080
3081 uint64_t r = 1;
3082 for (uint64_t i = 1; i <= k; ++i) {
3083 r = umul_ov(r, n-(i-1), Overflow);
3084 r /= i;
3085 }
3086 return r;
3087}
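// Illustrative trace: Choose(6, 2) keeps k = 2 and computes
// r = (1 * 6) / 1 = 6, then r = (6 * 5) / 2 = 15, the expected C(6, 2).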
3088
3089/// Determine if any of the operands in this SCEV are a constant or if
3090/// any of the add or multiply expressions in this SCEV contain a constant.
3091static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3092 struct FindConstantInAddMulChain {
3093 bool FoundConstant = false;
3094
3095 bool follow(const SCEV *S) {
3096 FoundConstant |= isa<SCEVConstant>(S);
3097 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3098 }
3099
3100 bool isDone() const {
3101 return FoundConstant;
3102 }
3103 };
3104
3105 FindConstantInAddMulChain F;
3106 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3107 ST.visitAll(StartExpr);
3108 return F.FoundConstant;
3109}
3110
3111/// Get a canonical multiply expression, or something simpler if possible.
3112const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3113 SCEV::NoWrapFlags OrigFlags,
3114 unsigned Depth) {
3115 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3116 "only nuw or nsw allowed");
3117 assert(!Ops.empty() && "Cannot get empty mul!");
3118 if (Ops.size() == 1) return Ops[0];
3119#ifndef NDEBUG
3120 Type *ETy = Ops[0]->getType();
3121 assert(!ETy->isPointerTy());
3122 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3123 assert(Ops[i]->getType() == ETy &&
3124 "SCEVMulExpr operand types don't match!");
3125#endif
3126
3127 // Sort by complexity, this groups all similar expression types together.
3128 GroupByComplexity(Ops, &LI, DT);
3129
3130 // If there are any constants, fold them together.
3131 unsigned Idx = 0;
3132 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3133 ++Idx;
3134 assert(Idx < Ops.size());
3135 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3136 // We found two constants, fold them together!
3137 Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3138 if (Ops.size() == 2) return Ops[0];
3139 Ops.erase(Ops.begin()+1); // Erase the folded element
3140 LHSC = cast<SCEVConstant>(Ops[0]);
3141 }
3142
3143 // If we have a multiply of zero, it will always be zero.
3144 if (LHSC->getValue()->isZero())
3145 return LHSC;
3146
3147 // If we are left with a constant one being multiplied, strip it off.
3148 if (LHSC->getValue()->isOne()) {
3149 Ops.erase(Ops.begin());
3150 --Idx;
3151 }
3152
3153 if (Ops.size() == 1)
3154 return Ops[0];
3155 }
3156
3157 // Delay expensive flag strengthening until necessary.
3158 auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3159 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3160 };
3161
3162 // Limit recursion depth.
3163 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3164 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3165
3166 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3167 // Don't strengthen flags if we have no new information.
3168 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3169 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3170 Mul->setNoWrapFlags(ComputeFlags(Ops));
3171 return S;
3172 }
3173
3174 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3175 if (Ops.size() == 2) {
3176 // C1*(C2+V) -> C1*C2 + C1*V
3177 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3178 // If any of Add's ops are Adds or Muls with a constant, apply this
3179 // transformation as well.
3180 //
3181 // TODO: There are some cases where this transformation is not
3182 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3183 // this transformation should be narrowed down.
3184 if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3185 const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3186 SCEV::FlagAnyWrap, Depth + 1);
3187 const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3188 SCEV::FlagAnyWrap, Depth + 1);
3189 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3190 }
3191
3192 if (Ops[0]->isAllOnesValue()) {
3193 // If we have a mul by -1 of an add, try distributing the -1 among the
3194 // add operands.
3195 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3196 SmallVector<const SCEV *, 4> NewOps;
3197 bool AnyFolded = false;
3198 for (const SCEV *AddOp : Add->operands()) {
3199 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3200 Depth + 1);
3201 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3202 NewOps.push_back(Mul);
3203 }
3204 if (AnyFolded)
3205 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3206 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3207 // Negation preserves a recurrence's no self-wrap property.
3208 SmallVector<const SCEV *, 4> Operands;
3209 for (const SCEV *AddRecOp : AddRec->operands())
3210 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3211 Depth + 1));
3212 // Let M be the minimum representable signed value. AddRec with nsw
3213 // multiplied by -1 can have signed overflow if and only if it takes a
3214 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3215 // maximum signed value. In all other cases signed overflow is
3216 // impossible.
3217 auto FlagsMask = SCEV::FlagNW;
3218 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3219 auto MinInt =
3220 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3221 if (getSignedRangeMin(AddRec) != MinInt)
3222 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3223 }
3224 return getAddRecExpr(Operands, AddRec->getLoop(),
3225 AddRec->getNoWrapFlags(FlagsMask));
3226 }
3227 }
3228 }
3229 }
3230
3231 // Skip over the add expression until we get to a multiply.
3232 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3233 ++Idx;
3234
3235 // If there are mul operands inline them all into this expression.
3236 if (Idx < Ops.size()) {
3237 bool DeletedMul = false;
3238 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3239 if (Ops.size() > MulOpsInlineThreshold)
3240 break;
3241 // If we have a mul, expand the mul operands onto the end of the
3242 // operands list.
3243 Ops.erase(Ops.begin()+Idx);
3244 append_range(Ops, Mul->operands());
3245 DeletedMul = true;
3246 }
3247
3248 // If we deleted at least one mul, we added operands to the end of the
3249 // list, and they are not necessarily sorted. Recurse to resort and
3250 // resimplify any operands we just acquired.
3251 if (DeletedMul)
3252 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3253 }
3254
3255 // If there are any add recurrences in the operands list, see if any other
3256 // multiplied values are loop invariant. If so, we can fold them into the
3257 // recurrence.
3258 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3259 ++Idx;
3260
3261 // Scan over all recurrences, trying to fold loop invariants into them.
3262 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3263 // Scan all of the other operands to this mul and add them to the vector
3264 // if they are loop invariant w.r.t. the recurrence.
3265 SmallVector<const SCEV *, 8> LIOps;
3266 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3267 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3268 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3269 LIOps.push_back(Ops[i]);
3270 Ops.erase(Ops.begin()+i);
3271 --i; --e;
3272 }
3273
3274 // If we found some loop invariants, fold them into the recurrence.
3275 if (!LIOps.empty()) {
3276 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3277 SmallVector<const SCEV *, 4> NewOps;
3278 NewOps.reserve(AddRec->getNumOperands());
3279 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3280
3281 // If both the mul and addrec are nuw, we can preserve nuw.
3282 // If both the mul and addrec are nsw, we can only preserve nsw if either
3283 // a) they are also nuw, or
3284 // b) all multiplications of addrec operands with scale are nsw.
3285 SCEV::NoWrapFlags Flags =
3286 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3287
3288 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3289 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3290 SCEV::FlagAnyWrap, Depth + 1));
3291
3292 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3293 ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3294 Instruction::Mul, getSignedRange(Scale),
3295 OverflowingBinaryOperator::NoSignedWrap);
3296 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3297 Flags = clearFlags(Flags, SCEV::FlagNSW);
3298 }
3299 }
3300
3301 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3302
3303 // If all of the other operands were loop invariant, we are done.
3304 if (Ops.size() == 1) return NewRec;
3305
3306 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3307 for (unsigned i = 0;; ++i)
3308 if (Ops[i] == AddRec) {
3309 Ops[i] = NewRec;
3310 break;
3311 }
3312 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3313 }
3314
3315 // Okay, if there weren't any loop invariants to be folded, check to see
3316 // if there are multiple AddRec's with the same loop induction variable
3317 // being multiplied together. If so, we can fold them.
3318
3319 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3320 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3321 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3322 // ]]],+,...up to x=2n}.
3323 // Note that the arguments to choose() are always integers with values
3324 // known at compile time, never SCEV objects.
3325 //
3326 // The implementation avoids pointless extra computations when the two
3327 // addrec's are of different length (mathematically, it's equivalent to
3328 // an infinite stream of zeros on the right).
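 // For example, {0,+,1}<L> * {0,+,1}<L> (i.e. n * n) folds to
 // {0,+,1,+,2}<L>, the chain-of-recurrences form of n^2.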
3329 bool OpsModified = false;
3330 for (unsigned OtherIdx = Idx+1;
3331 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3332 ++OtherIdx) {
3333 const SCEVAddRecExpr *OtherAddRec =
3334 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3335 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3336 continue;
3337
3338 // Limit max number of arguments to avoid creation of unreasonably big
3339 // SCEVAddRecs with very complex operands.
3340 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3341 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3342 continue;
3343
3344 bool Overflow = false;
3345 Type *Ty = AddRec->getType();
3346 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3347 SmallVector<const SCEV*, 7> AddRecOps;
3348 for (int x = 0, xe = AddRec->getNumOperands() +
3349 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3350 SmallVector <const SCEV *, 7> SumOps;
3351 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3352 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3353 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3354 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3355 z < ze && !Overflow; ++z) {
3356 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3357 uint64_t Coeff;
3358 if (LargerThan64Bits)
3359 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3360 else
3361 Coeff = Coeff1*Coeff2;
3362 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3363 const SCEV *Term1 = AddRec->getOperand(y-z);
3364 const SCEV *Term2 = OtherAddRec->getOperand(z);
3365 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3366 SCEV::FlagAnyWrap, Depth + 1));
3367 }
3368 }
3369 if (SumOps.empty())
3370 SumOps.push_back(getZero(Ty));
3371 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3372 }
3373 if (!Overflow) {
3374 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3375 SCEV::FlagAnyWrap);
3376 if (Ops.size() == 2) return NewAddRec;
3377 Ops[Idx] = NewAddRec;
3378 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3379 OpsModified = true;
3380 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3381 if (!AddRec)
3382 break;
3383 }
3384 }
3385 if (OpsModified)
3386 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3387
3388 // Otherwise couldn't fold anything into this recurrence. Move onto the
3389 // next one.
3390 }
3391
3392 // Okay, it looks like we really DO need a mul expr. Check to see if we
3393 // already have one, otherwise create a new one.
3394 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3395}
3396
3397/// Represents an unsigned remainder expression based on unsigned division.
3398const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3399 const SCEV *RHS) {
3400 assert(getEffectiveSCEVType(LHS->getType()) ==
3401 getEffectiveSCEVType(RHS->getType()) &&
3402 "SCEVURemExpr operand types don't match!");
3403
3404 // Short-circuit easy cases
3405 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3406 // If constant is one, the result is trivial
3407 if (RHSC->getValue()->isOne())
3408 return getZero(LHS->getType()); // X urem 1 --> 0
3409
3410 // If constant is a power of two, fold into a zext(trunc(LHS)).
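 // For instance, for an i32 value %x, "%x urem 8" becomes
 // "zext i3 (trunc i32 %x to i3) to i32", i.e. the low three bits of %x.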
3411 if (RHSC->getAPInt().isPowerOf2()) {
3412 Type *FullTy = LHS->getType();
3413 Type *TruncTy =
3414 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3415 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3416 }
3417 }
3418
3419 // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3420 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3421 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3422 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3423}
3424
3425/// Get a canonical unsigned division expression, or something simpler if
3426/// possible.
3427const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3428 const SCEV *RHS) {
3429 assert(!LHS->getType()->isPointerTy() &&
3430 "SCEVUDivExpr operand can't be pointer!");
3431 assert(LHS->getType() == RHS->getType() &&
3432 "SCEVUDivExpr operand types don't match!");
3433
3434 FoldingSetNodeID ID;
3435 ID.AddInteger(scUDivExpr);
3436 ID.AddPointer(LHS);
3437 ID.AddPointer(RHS);
3438 void *IP = nullptr;
3439 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3440 return S;
3441
3442 // 0 udiv Y == 0
3443 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3444 if (LHSC->getValue()->isZero())
3445 return LHS;
3446
3447 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3448 if (RHSC->getValue()->isOne())
3449 return LHS; // X udiv 1 --> x
3450 // If the denominator is zero, the result of the udiv is undefined. Don't
3451 // try to analyze it, because the resolution chosen here may differ from
3452 // the resolution chosen in other parts of the compiler.
3453 if (!RHSC->getValue()->isZero()) {
3454 // Determine if the division can be folded into the operands of
3455 // the LHS expression.
3456 // TODO: Generalize this to non-constants by using known-bits information.
3457 Type *Ty = LHS->getType();
3458 unsigned LZ = RHSC->getAPInt().countl_zero();
3459 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3460 // For non-power-of-two values, effectively round the value up to the
3461 // nearest power of two.
3462 if (!RHSC->getAPInt().isPowerOf2())
3463 ++MaxShiftAmt;
3464 IntegerType *ExtTy =
3465 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3466 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3467 if (const SCEVConstant *Step =
3468 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3469 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
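 // For instance, {8,+,4}<L> /u 2 becomes {4,+,2}<L>, provided the
 // zero-extended forms agree so no bits are lost in the division.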
3470 const APInt &StepInt = Step->getAPInt();
3471 const APInt &DivInt = RHSC->getAPInt();
3472 if (!StepInt.urem(DivInt) &&
3473 getZeroExtendExpr(AR, ExtTy) ==
3474 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3475 getZeroExtendExpr(Step, ExtTy),
3476 AR->getLoop(), SCEV::FlagAnyWrap)) {
3477 SmallVector<const SCEV *, 4> Operands;
3478 for (const SCEV *Op : AR->operands())
3479 Operands.push_back(getUDivExpr(Op, RHS));
3480 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3481 }
3482 /// Get a canonical UDivExpr for a recurrence.
3483 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3484 // We can currently only fold X%N if X is constant.
3485 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3486 if (StartC && !DivInt.urem(StepInt) &&
3487 getZeroExtendExpr(AR, ExtTy) ==
3488 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3489 getZeroExtendExpr(Step, ExtTy),
3490 AR->getLoop(), SCEV::FlagAnyWrap)) {
3491 const APInt &StartInt = StartC->getAPInt();
3492 const APInt &StartRem = StartInt.urem(StepInt);
3493 if (StartRem != 0) {
3494 const SCEV *NewLHS =
3495 getAddRecExpr(getConstant(StartInt - StartRem), Step,
3496 AR->getLoop(), SCEV::FlagNW);
3497 if (LHS != NewLHS) {
3498 LHS = NewLHS;
3499
3500 // Reset the ID to include the new LHS, and check if it is
3501 // already cached.
3502 ID.clear();
3503 ID.AddInteger(scUDivExpr);
3504 ID.AddPointer(LHS);
3505 ID.AddPointer(RHS);
3506 IP = nullptr;
3507 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3508 return S;
3509 }
3510 }
3511 }
3512 }
3513 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
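 // For instance, (%x * 8) /u 4 becomes %x * 2, since 8 /u 4 == 2 and
 // multiplying back (2 * 4 == 8) shows the per-operand division is exact.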
3514 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3515 SmallVector<const SCEV *, 4> Operands;
3516 for (const SCEV *Op : M->operands())
3517 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3518 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3519 // Find an operand that's safely divisible.
3520 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3521 const SCEV *Op = M->getOperand(i);
3522 const SCEV *Div = getUDivExpr(Op, RHSC);
3523 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3524 Operands = SmallVector<const SCEV *, 4>(M->operands());
3525 Operands[i] = Div;
3526 return getMulExpr(Operands);
3527 }
3528 }
3529 }
3530
3531 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3532 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3533 if (auto *DivisorConstant =
3534 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3535 bool Overflow = false;
3536 APInt NewRHS =
3537 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3538 if (Overflow) {
3539 return getConstant(RHSC->getType(), 0, false);
3540 }
3541 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3542 }
3543 }
3544
3545 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3546 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3547 SmallVector<const SCEV *, 4> Operands;
3548 for (const SCEV *Op : A->operands())
3549 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3550 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3551 Operands.clear();
3552 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3553 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3554 if (isa<SCEVUDivExpr>(Op) ||
3555 getMulExpr(Op, RHS) != A->getOperand(i))
3556 break;
3557 Operands.push_back(Op);
3558 }
3559 if (Operands.size() == A->getNumOperands())
3560 return getAddExpr(Operands);
3561 }
3562 }
3563
3564 // Fold if both operands are constant.
3565 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3566 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3567 }
3568 }
3569
3570 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3571 // changes). Make sure we get a new one.
3572 IP = nullptr;
3573 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3574 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3575 LHS, RHS);
3576 UniqueSCEVs.InsertNode(S, IP);
3577 registerUser(S, {LHS, RHS});
3578 return S;
3579}
3580
3581APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3582 APInt A = C1->getAPInt().abs();
3583 APInt B = C2->getAPInt().abs();
3584 uint32_t ABW = A.getBitWidth();
3585 uint32_t BBW = B.getBitWidth();
3586
3587 if (ABW > BBW)
3588 B = B.zext(ABW);
3589 else if (ABW < BBW)
3590 A = A.zext(BBW);
3591
3592 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3593}
3594
3595/// Get a canonical unsigned division expression, or something simpler if
3596/// possible. There is no representation for an exact udiv in SCEV IR, but we
3597/// can attempt to remove factors from the LHS and RHS. We can't do this when
3598/// it's not exact because the udiv may be clearing bits.
3599const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3600 const SCEV *RHS) {
3601 // TODO: we could try to find factors in all sorts of things, but for now we
3602 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3603 // end of this file for inspiration.
3604
3605 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3606 if (!Mul || !Mul->hasNoUnsignedWrap())
3607 return getUDivExpr(LHS, RHS);
3608
3609 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3610 // If the mulexpr multiplies by a constant, then that constant must be the
3611 // first element of the mulexpr.
3612 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3613 if (LHSCst == RHSCst) {
3614 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3615 return getMulExpr(Operands);
3616 }
3617
3618 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3619 // that there's a factor provided by one of the other terms. We need to
3620 // check.
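 // For instance, for (6 * %x) /u 4 the gcd of 6 and 4 is 2, so the
 // expression is rewritten as (3 * %x) /u 2 before recursing.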
3621 APInt Factor = gcd(LHSCst, RHSCst);
3622 if (!Factor.isIntN(1)) {
3623 LHSCst =
3624 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3625 RHSCst =
3626 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3627 SmallVector<const SCEV *, 2> Operands;
3628 Operands.push_back(LHSCst);
3629 append_range(Operands, Mul->operands().drop_front());
3630 LHS = getMulExpr(Operands);
3631 RHS = RHSCst;
3632 Mul = dyn_cast<SCEVMulExpr>(LHS);
3633 if (!Mul)
3634 return getUDivExactExpr(LHS, RHS);
3635 }
3636 }
3637 }
3638
3639 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3640 if (Mul->getOperand(i) == RHS) {
3641 SmallVector<const SCEV *, 2> Operands;
3642 append_range(Operands, Mul->operands().take_front(i));
3643 append_range(Operands, Mul->operands().drop_front(i + 1));
3644 return getMulExpr(Operands);
3645 }
3646 }
3647
3648 return getUDivExpr(LHS, RHS);
3649}
3650
3651/// Get an add recurrence expression for the specified loop. Simplify the
3652/// expression as much as possible.
3653const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3654 const Loop *L,
3655 SCEV::NoWrapFlags Flags) {
3656 SmallVector<const SCEV *, 4> Operands;
3657 Operands.push_back(Start);
3658 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3659 if (StepChrec->getLoop() == L) {
3660 append_range(Operands, StepChrec->operands());
3661 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3662 }
3663
3664 Operands.push_back(Step);
3665 return getAddRecExpr(Operands, L, Flags);
3666}
3667
3668/// Get an add recurrence expression for the specified loop. Simplify the
3669/// expression as much as possible.
3670const SCEV *
3671ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3672 const Loop *L, SCEV::NoWrapFlags Flags) {
3673 if (Operands.size() == 1) return Operands[0];
3674#ifndef NDEBUG
3675 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3676 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3677 assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3678 "SCEVAddRecExpr operand types don't match!");
3679 assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3680 }
3681 for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3682 assert(isAvailableAtLoopEntry(Operands[i], L) &&
3683 "SCEVAddRecExpr operand is not available at loop entry!");
3684#endif
3685
3686 if (Operands.back()->isZero()) {
3687 Operands.pop_back();
3688 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3689 }
3690
3691 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3692 // use that information to infer NUW and NSW flags. However, computing a
3693 // BE count requires calling getAddRecExpr, so we may not yet have a
3694 // meaningful BE count at this point (and if we don't, we'd be stuck
3695 // with a SCEVCouldNotCompute as the cached BE count).
3696
3697 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3698
3699 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3700 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3701 const Loop *NestedLoop = NestedAR->getLoop();
3702 if (L->contains(NestedLoop)
3703 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3704 : (!NestedLoop->contains(L) &&
3705 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3706 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3707 Operands[0] = NestedAR->getStart();
3708 // AddRecs require their operands be loop-invariant with respect to their
3709 // loops. Don't perform this transformation if it would break this
3710 // requirement.
3711 bool AllInvariant = all_of(
3712 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3713
3714 if (AllInvariant) {
3715 // Create a recurrence for the outer loop with the same step size.
3716 //
3717 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3718 // inner recurrence has the same property.
3719 SCEV::NoWrapFlags OuterFlags =
3720 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3721
3722 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3723 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3724 return isLoopInvariant(Op, NestedLoop);
3725 });
3726
3727 if (AllInvariant) {
3728 // Ok, both add recurrences are valid after the transformation.
3729 //
3730 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3731 // the outer recurrence has the same property.
3732 SCEV::NoWrapFlags InnerFlags =
3733 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3734 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3735 }
3736 }
3737 // Reset Operands to its original state.
3738 Operands[0] = NestedAR;
3739 }
3740 }
3741
3742 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3743 // already have one, otherwise create a new one.
3744 return getOrCreateAddRecExpr(Operands, L, Flags);
3745}
3746
3747const SCEV *
3748ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3749 const SmallVectorImpl<const SCEV *> &IndexExprs) {
3750 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3751 // getSCEV(Base)->getType() has the same address space as Base->getType()
3752 // because SCEV::getType() preserves the address space.
3753 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3754 const bool AssumeInBoundsFlags = [&]() {
3755 if (!GEP->isInBounds())
3756 return false;
3757
3758 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3759 // but to do that, we have to ensure that said flag is valid in the entire
3760 // defined scope of the SCEV.
3761 auto *GEPI = dyn_cast<Instruction>(GEP);
3762 // TODO: non-instructions have global scope. We might be able to prove
3763 // some global scope cases
3764 return GEPI && isSCEVExprNeverPoison(GEPI);
3765 }();
3766
3767 SCEV::NoWrapFlags OffsetWrap =
3768 AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3769
3770 Type *CurTy = GEP->getType();
3771 bool FirstIter = true;
3772 SmallVector<const SCEV *, 4> Offsets;
3773 for (const SCEV *IndexExpr : IndexExprs) {
3774 // Compute the (potentially symbolic) offset in bytes for this index.
3775 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3776 // For a struct, add the member offset.
3777 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3778 unsigned FieldNo = Index->getZExtValue();
3779 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3780 Offsets.push_back(FieldOffset);
3781
3782 // Update CurTy to the type of the field at Index.
3783 CurTy = STy->getTypeAtIndex(Index);
3784 } else {
3785 // Update CurTy to its element type.
3786 if (FirstIter) {
3787 assert(isa<PointerType>(CurTy) &&
3788 "The first index of a GEP indexes a pointer");
3789 CurTy = GEP->getSourceElementType();
3790 FirstIter = false;
3791 } else {
3792 CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3793 }
3794 // For an array, add the element offset, explicitly scaled.
3795 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3796 // Getelementptr indices are signed.
3797 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3798
3799 // Multiply the index by the element size to compute the element offset.
3800 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3801 Offsets.push_back(LocalOffset);
3802 }
3803 }
3804
3805 // Handle degenerate case of GEP without offsets.
3806 if (Offsets.empty())
3807 return BaseExpr;
3808
3809 // Add the offsets together, assuming nsw if inbounds.
3810 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3811 // Add the base address and the offset. We cannot use the nsw flag, as the
3812 // base address is unsigned. However, if we know that the offset is
3813 // non-negative, we can use nuw.
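 // For instance, for "getelementptr inbounds i32, ptr %p, i64 %i", when
 // AssumeInBoundsFlags holds the byte offset 4 * %i is built with nsw, and
 // if %i is also known non-negative the base-plus-offset add can use nuw.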
3814 SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3815 ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3816 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3817 assert(BaseExpr->getType() == GEPExpr->getType() &&
3818 "GEP should not change type mid-flight.");
3819 return GEPExpr;
3820}
3821
3822SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3823 ArrayRef<const SCEV *> Ops) {
3824 FoldingSetNodeID ID;
3825 ID.AddInteger(SCEVType);
3826 for (const SCEV *Op : Ops)
3827 ID.AddPointer(Op);
3828 void *IP = nullptr;
3829 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3830}
3831
3832const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3833 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3834 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3835}
3836
3837const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3838 SmallVectorImpl<const SCEV *> &Ops) {
3839 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3840 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3841 if (Ops.size() == 1) return Ops[0];
3842#ifndef NDEBUG
3843 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3844 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3845 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3846 "Operand types don't match!");
3847 assert(Ops[0]->getType()->isPointerTy() ==
3848 Ops[i]->getType()->isPointerTy() &&
3849 "min/max should be consistently pointerish");
3850 }
3851#endif
3852
3853 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3854 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3855
3856 // Sort by complexity, this groups all similar expression types together.
3857 GroupByComplexity(Ops, &LI, DT);
3858
3859 // Check if we have created the same expression before.
3860 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3861 return S;
3862 }
3863
3864 // If there are any constants, fold them together.
3865 unsigned Idx = 0;
3866 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3867 ++Idx;
3868 assert(Idx < Ops.size());
3869 auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3870 switch (Kind) {
3871 case scSMaxExpr:
3872 return APIntOps::smax(LHS, RHS);
3873 case scSMinExpr:
3874 return APIntOps::smin(LHS, RHS);
3875 case scUMaxExpr:
3876 return APIntOps::umax(LHS, RHS);
3877 case scUMinExpr:
3878 return APIntOps::umin(LHS, RHS);
3879 default:
3880 llvm_unreachable("Unknown SCEV min/max opcode");
3881 }
3882 };
3883
3884 while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3885 // We found two constants, fold them together!
3886 ConstantInt *Fold = ConstantInt::get(
3887 getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3888 Ops[0] = getConstant(Fold);
3889 Ops.erase(Ops.begin()+1); // Erase the folded element
3890 if (Ops.size() == 1) return Ops[0];
3891 LHSC = cast<SCEVConstant>(Ops[0]);
3892 }
3893
3894 bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3895 bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3896
3897 if (IsMax ? IsMinV : IsMaxV) {
3898 // If we are left with a constant minimum(/maximum)-int, strip it off.
3899 Ops.erase(Ops.begin());
3900 --Idx;
3901 } else if (IsMax ? IsMaxV : IsMinV) {
3902 // If we have a max(/min) with a constant maximum(/minimum)-int,
3903 // it will always be the extremum.
3904 return LHSC;
3905 }
3906
3907 if (Ops.size() == 1) return Ops[0];
3908 }
3909
3910 // Find the first operation of the same kind
3911 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3912 ++Idx;
3913
3914 // Check to see if one of the operands is of the same kind. If so, expand its
3915 // operands onto our operand list, and recurse to simplify.
3916 if (Idx < Ops.size()) {
3917 bool DeletedAny = false;
3918 while (Ops[Idx]->getSCEVType() == Kind) {
3919 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3920 Ops.erase(Ops.begin()+Idx);
3921 append_range(Ops, SMME->operands());
3922 DeletedAny = true;
3923 }
3924
3925 if (DeletedAny)
3926 return getMinMaxExpr(Kind, Ops);
3927 }
3928
3929 // Okay, check to see if the same value occurs in the operand list twice. If
3930 // so, delete one. Since we sorted the list, these values are required to
3931 // be adjacent.
3932 llvm::CmpInst::Predicate GEPred =
3933 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3934 llvm::CmpInst::Predicate LEPred =
3935 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3936 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3937 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3938 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3939 if (Ops[i] == Ops[i + 1] ||
3940 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3941 // X op Y op Y --> X op Y
3942 // X op Y --> X, if we know X, Y are ordered appropriately
3943 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3944 --i;
3945 --e;
3946 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3947 Ops[i + 1])) {
3948 // X op Y --> Y, if we know X, Y are ordered appropriately
3949 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3950 --i;
3951 --e;
3952 }
3953 }
3954
3955 if (Ops.size() == 1) return Ops[0];
3956
3957 assert(!Ops.empty() && "Reduced smax down to nothing!");
3958
3959 // Okay, it looks like we really DO need an expr. Check to see if we
3960 // already have one, otherwise create a new one.
3961 FoldingSetNodeID ID;
3962 ID.AddInteger(Kind);
3963 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3964 ID.AddPointer(Ops[i]);
3965 void *IP = nullptr;
3966 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3967 if (ExistingSCEV)
3968 return ExistingSCEV;
3969 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3970 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3971 SCEV *S = new (SCEVAllocator)
3972 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3973
3974 UniqueSCEVs.InsertNode(S, IP);
3975 registerUser(S, Ops);
3976 return S;
3977}
3978
3979namespace {
3980
3981class SCEVSequentialMinMaxDeduplicatingVisitor final
3982 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3983 std::optional<const SCEV *>> {
3984 using RetVal = std::optional<const SCEV *>;
3985 using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3986
3987 ScalarEvolution &SE;
3988 const SCEVTypes RootKind; // Must be a sequential min/max expression.
3989 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3990 SmallPtrSet<const SCEV *, 16> SeenOps;
3991
3992 bool canRecurseInto(SCEVTypes Kind) const {
3993 // We can only recurse into the SCEV expression of the same effective type
3994 // as the type of our root SCEV expression.
3995 return RootKind == Kind || NonSequentialRootKind == Kind;
3996 };
3997
3998 RetVal visitAnyMinMaxExpr(const SCEV *S) {
3999 assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
4000 "Only for min/max expressions.");
4001 SCEVTypes Kind = S->getSCEVType();
4002
4003 if (!canRecurseInto(Kind))
4004 return S;
4005
4006 auto *NAry = cast<SCEVNAryExpr>(S);
4007 SmallVector<const SCEV *> NewOps;
4008 bool Changed = visit(Kind, NAry->operands(), NewOps);
4009
4010 if (!Changed)
4011 return S;
4012 if (NewOps.empty())
4013 return std::nullopt;
4014
4015 return isa<SCEVSequentialMinMaxExpr>(S)
4016 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4017 : SE.getMinMaxExpr(Kind, NewOps);
4018 }
4019
4020 RetVal visit(const SCEV *S) {
4021 // Has the whole operand been seen already?
4022 if (!SeenOps.insert(S).second)
4023 return std::nullopt;
4024 return Base::visit(S);
4025 }
4026
4027public:
4028 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4029 SCEVTypes RootKind)
4030 : SE(SE), RootKind(RootKind),
4031 NonSequentialRootKind(
4032 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4033 RootKind)) {}
4034
4035 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4036 SmallVectorImpl<const SCEV *> &NewOps) {
4037 bool Changed = false;
4038 SmallVector<const SCEV *> Ops;
4039 Ops.reserve(OrigOps.size());
4040
4041 for (const SCEV *Op : OrigOps) {
4042 RetVal NewOp = visit(Op);
4043 if (NewOp != Op)
4044 Changed = true;
4045 if (NewOp)
4046 Ops.emplace_back(*NewOp);
4047 }
4048
4049 if (Changed)
4050 NewOps = std::move(Ops);
4051 return Changed;
4052 }
4053
4054 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4055
4056 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4057
4058 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4059
4060 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4061
4062 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4063
4064 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4065
4066 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4067
4068 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4069
4070 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4071
4072 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4073
4074 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4075 return visitAnyMinMaxExpr(Expr);
4076 }
4077
4078 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4079 return visitAnyMinMaxExpr(Expr);
4080 }
4081
4082 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4083 return visitAnyMinMaxExpr(Expr);
4084 }
4085
4086 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4087 return visitAnyMinMaxExpr(Expr);
4088 }
4089
4090 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4091 return visitAnyMinMaxExpr(Expr);
4092 }
4093
4094 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4095
4096 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4097};
4098
4099} // namespace
4100
4101static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4102 switch (Kind) {
4103 case scConstant:
4104 case scVScale:
4105 case scTruncate:
4106 case scZeroExtend:
4107 case scSignExtend:
4108 case scPtrToInt:
4109 case scAddExpr:
4110 case scMulExpr:
4111 case scUDivExpr:
4112 case scAddRecExpr:
4113 case scUMaxExpr:
4114 case scSMaxExpr:
4115 case scUMinExpr:
4116 case scSMinExpr:
4117 case scUnknown:
4118 // If any operand is poison, the whole expression is poison.
4119 return true;
4120 case scSequentialUMinExpr:
4121 // FIXME: if the *first* operand is poison, the whole expression is poison.
4122 return false; // Pessimistically, say that it does not propagate poison.
4123 case scCouldNotCompute:
4124 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4125 }
4126 llvm_unreachable("Unknown SCEV kind!");
4127}
4128
4129namespace {
4130// The only way poison may be introduced in a SCEV expression is from a
4131// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4132// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4133// introduce poison -- they encode guaranteed, non-speculated knowledge.
4134//
4135// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4136// with the notable exception of umin_seq, where only poison from the first
4137// operand is (unconditionally) propagated.
4138struct SCEVPoisonCollector {
4139 bool LookThroughMaybePoisonBlocking;
4140 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4141 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4142 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4143
4144 bool follow(const SCEV *S) {
4145 if (!LookThroughMaybePoisonBlocking &&
4146 !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4147 return false;
4148
4149 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4150 if (!isGuaranteedNotToBePoison(SU->getValue()))
4151 MaybePoison.insert(SU);
4152 }
4153 return true;
4154 }
4155 bool isDone() const { return false; }
4156};
4157} // namespace
4158
4159/// Return true if V is poison given that AssumedPoison is already poison.
4160static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4161 // First collect all SCEVs that might result in AssumedPoison to be poison.
4162 // We need to look through potentially poison-blocking operations here,
4163 // because we want to find all SCEVs that *might* result in poison, not only
4164 // those that are *required* to.
4165 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4166 visitAll(AssumedPoison, PC1);
4167
4168 // AssumedPoison is never poison. As the assumption is false, the implication
4169 // is true. Don't bother walking the other SCEV in this case.
4170 if (PC1.MaybePoison.empty())
4171 return true;
4172
4173 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4174 // as well. We cannot look through potentially poison-blocking operations
4175 // here, as their arguments only *may* make the result poison.
4176 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4177 visitAll(S, PC2);
4178
4179 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4180 // it will also make S poison by being part of PC2.MaybePoison.
4181 return all_of(PC1.MaybePoison, [&](const SCEVUnknown *S) {
4182 return PC2.MaybePoison.contains(S);
4183 });
4184}
4185
4186void ScalarEvolution::getPoisonGeneratingValues(
4187 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4188 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4189 visitAll(S, PC);
4190 for (const SCEVUnknown *SU : PC.MaybePoison)
4191 Result.insert(SU->getValue());
4192}
4193
4194bool ScalarEvolution::canReuseInstruction(
4195 const SCEV *S, Instruction *I,
4196 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4197 // If the instruction cannot be poison, it's always safe to reuse.
4198 if (isGuaranteedNotToBePoison(I))
4199 return true;
4200
4201 // Otherwise, it is possible that I is more poisonous than S. Collect the
4202 // poison-contributors of S, and then check whether I has any additional
4203 // poison-contributors. Poison that is contributed through poison-generating
4204 // flags is handled by dropping those flags instead.
4205 SmallPtrSet<const Value *, 8> PoisonVals;
4206 getPoisonGeneratingValues(PoisonVals, S);
4207
4208 SmallVector<Value *> Worklist;
4209 SmallPtrSet<Value *, 8> Visited;
4210 Worklist.push_back(I);
4211 while (!Worklist.empty()) {
4212 Value *V = Worklist.pop_back_val();
4213 if (!Visited.insert(V).second)
4214 continue;
4215
4216 // Avoid walking large instruction graphs.
4217 if (Visited.size() > 16)
4218 return false;
4219
4220 // Either the value can't be poison, or the S would also be poison if it
4221 // is.
4222 if (PoisonVals.contains(V) || isGuaranteedNotToBePoison(V))
4223 continue;
4224
4225 auto *I = dyn_cast<Instruction>(V);
4226 if (!I)
4227 return false;
4228
4229 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4230 // can't replace an arbitrary add with disjoint or, even if we drop the
4231 // flag. We would need to convert the or into an add.
4232 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4233 if (PDI->isDisjoint())
4234 return false;
4235
4236 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4237 // because SCEV currently assumes it can't be poison. Remove this special
4238 // case once we properly model when vscale can be poison.
4239 if (auto *II = dyn_cast<IntrinsicInst>(I);
4240 II && II->getIntrinsicID() == Intrinsic::vscale)
4241 continue;
4242
4243 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4244 return false;
4245
4246 // If the instruction can't create poison, we can recurse to its operands.
4247 if (I->hasPoisonGeneratingFlagsOrMetadata())
4248 DropPoisonGeneratingInsts.push_back(I);
4249
4250 for (Value *Op : I->operands())
4251 Worklist.push_back(Op);
4252 }
4253 return true;
4254}
4255
4256const SCEV *
4257ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4258 SmallVectorImpl<const SCEV *> &Ops) {
4259 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4260 "Not a SCEVSequentialMinMaxExpr!");
4261 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4262 if (Ops.size() == 1)
4263 return Ops[0];
4264#ifndef NDEBUG
4265 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4266 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4267 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4268 "Operand types don't match!");
4269 assert(Ops[0]->getType()->isPointerTy() ==
4270 Ops[i]->getType()->isPointerTy() &&
4271 "min/max should be consistently pointerish");
4272 }
4273#endif
4274
4275 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4276 // so we can *NOT* do any kind of sorting of the expressions!
4277
4278 // Check if we have created the same expression before.
4279 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4280 return S;
4281
4282 // FIXME: there are *some* simplifications that we can do here.
4283
4284 // Keep only the first instance of an operand.
4285 {
4286 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4287 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4288 if (Changed)
4289 return getSequentialMinMaxExpr(Kind, Ops);
4290 }
4291
4292 // Check to see if one of the operands is of the same kind. If so, expand its
4293 // operands onto our operand list, and recurse to simplify.
4294 {
4295 unsigned Idx = 0;
4296 bool DeletedAny = false;
4297 while (Idx < Ops.size()) {
4298 if (Ops[Idx]->getSCEVType() != Kind) {
4299 ++Idx;
4300 continue;
4301 }
4302 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4303 Ops.erase(Ops.begin() + Idx);
4304 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4305 SMME->operands().end());
4306 DeletedAny = true;
4307 }
4308
4309 if (DeletedAny)
4310 return getSequentialMinMaxExpr(Kind, Ops);
4311 }
4312
4313 const SCEV *SaturationPoint;
4314 ICmpInst::Predicate Pred;
4315 switch (Kind) {
4316 case scSequentialUMinExpr:
4317 SaturationPoint = getZero(Ops[0]->getType());
4318 Pred = ICmpInst::ICMP_ULE;
4319 break;
4320 default:
4321 llvm_unreachable("Not a sequential min/max type.");
4322 }
4323
4324 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4325 // We can replace %x umin_seq %y with %x umin %y if either:
4326 // * %y being poison implies %x is also poison.
4327 // * %x cannot be the saturating value (e.g. zero for umin).
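 // For instance, if %x is known to be non-zero it can never hit the
 // saturation point, so "%x umin_seq %y" simplifies to "%x umin %y".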
4328 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4329 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4330 SaturationPoint)) {
4331 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4332 Ops[i - 1] = getMinMaxExpr(
4333 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4334 SeqOps);
4335 Ops.erase(Ops.begin() + i);
4336 return getSequentialMinMaxExpr(Kind, Ops);
4337 }
4338 // Fold %x umin_seq %y to %x if %x ule %y.
4339 // TODO: We might be able to prove the predicate for a later operand.
4340 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4341 Ops.erase(Ops.begin() + i);
4342 return getSequentialMinMaxExpr(Kind, Ops);
4343 }
4344 }
4345
4346 // Okay, it looks like we really DO need an expr. Check to see if we
4347 // already have one, otherwise create a new one.
4348 FoldingSetNodeID ID;
4349 ID.AddInteger(Kind);
4350 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4351 ID.AddPointer(Ops[i]);
4352 void *IP = nullptr;
4353 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4354 if (ExistingSCEV)
4355 return ExistingSCEV;
4356
4357 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4358 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4359 SCEV *S = new (SCEVAllocator)
4360 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4361
4362 UniqueSCEVs.InsertNode(S, IP);
4363 registerUser(S, Ops);
4364 return S;
4365}
4366
4367const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4368 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4369 return getSMaxExpr(Ops);
4370}
4371
4372const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4373 return getMinMaxExpr(scSMaxExpr, Ops);
4374}
4375
4376const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4377 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4378 return getUMaxExpr(Ops);
4379}
4380
4381const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4382 return getMinMaxExpr(scUMaxExpr, Ops);
4383}
4384
4385const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4386 const SCEV *RHS) {
4387 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4388 return getSMinExpr(Ops);
4389}
4390
4391const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4392 return getMinMaxExpr(scSMinExpr, Ops);
4393}
4394
4395const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4396 bool Sequential) {
4397 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4398 return getUMinExpr(Ops, Sequential);
4399}
4400
4401const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4402 bool Sequential) {
4403 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4404 : getMinMaxExpr(scUMinExpr, Ops);
4405}
4406
4407const SCEV *
4408ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4409 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4410 if (Size.isScalable())
4411 Res = getMulExpr(Res, getVScale(IntTy));
4412 return Res;
4413}
4414
4415const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4416 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4417}
4418
4419const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4420 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4421}
4422
4423const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4424 StructType *STy,
4425 unsigned FieldNo) {
4426 // We can bypass creating a target-independent constant expression and then
4427 // folding it back into a ConstantInt. This is just a compile-time
4428 // optimization.
4429 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4430 assert(!SL->getSizeInBits().isScalable() &&
4431 "Cannot get offset for structure containing scalable vector types");
4432 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4433}
4434
4435const SCEV *ScalarEvolution::getUnknown(Value *V) {
4436 // Don't attempt to do anything other than create a SCEVUnknown object
4437 // here. createSCEV only calls getUnknown after checking for all other
4438 // interesting possibilities, and any other code that calls getUnknown
4439 // is doing so in order to hide a value from SCEV canonicalization.
4440
4441 FoldingSetNodeID ID;
4442 ID.AddInteger(scUnknown);
4443 ID.AddPointer(V);
4444 void *IP = nullptr;
4445 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4446 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4447 "Stale SCEVUnknown in uniquing map!");
4448 return S;
4449 }
4450 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4451 FirstUnknown);
4452 FirstUnknown = cast<SCEVUnknown>(S);
4453 UniqueSCEVs.InsertNode(S, IP);
4454 return S;
4455}
4456
4457//===----------------------------------------------------------------------===//
4458// Basic SCEV Analysis and PHI Idiom Recognition Code
4459//
4460
4461/// Test if values of the given type are analyzable within the SCEV
4462/// framework. This primarily includes integer types, and it can optionally
4463/// include pointer types if the ScalarEvolution class has access to
4464/// target-specific information.
4465bool ScalarEvolution::isSCEVable(Type *Ty) const {
4466 // Integers and pointers are always SCEVable.
4467 return Ty->isIntOrPtrTy();
4468}
4469
4470/// Return the size in bits of the specified type, for which isSCEVable must
4471/// return true.
4473 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4474 if (Ty->isPointerTy())
4475 return getDataLayout().getIndexTypeSizeInBits(Ty);
4476 return getDataLayout().getTypeSizeInBits(Ty);
4477}
4478
4479/// Return a type with the same bitwidth as the given type and which represents
4480/// how SCEV will treat the given type, for which isSCEVable must return
4481/// true. For pointer types, this is the pointer index sized integer type.
4483 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4484
4485 if (Ty->isIntegerTy())
4486 return Ty;
4487
4488 // The only other support type is pointer.
4489 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4490 return getDataLayout().getIndexType(Ty);
4491}
4492
4493Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4494 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4495}
4496
4497bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4498 const SCEV *B) {
4499 /// For a valid use point to exist, the defining scope of one operand
4500 /// must dominate the other.
4501 bool PreciseA, PreciseB;
4502 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4503 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4504 if (!PreciseA || !PreciseB)
4505 // Can't tell.
4506 return false;
4507 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4508 DT.dominates(ScopeB, ScopeA);
4509}
4510
4511const SCEV *ScalarEvolution::getCouldNotCompute() {
4512 return CouldNotCompute.get();
4513}
4514
4515bool ScalarEvolution::checkValidity(const SCEV *S) const {
4516 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4517 auto *SU = dyn_cast<SCEVUnknown>(S);
4518 return SU && SU->getValue() == nullptr;
4519 });
4520
4521 return !ContainsNulls;
4522}
4523
4524bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4525 HasRecMapType::iterator I = HasRecMap.find(S);
4526 if (I != HasRecMap.end())
4527 return I->second;
4528
4529 bool FoundAddRec =
4530 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4531 HasRecMap.insert({S, FoundAddRec});
4532 return FoundAddRec;
4533}
4534
4535/// Return the ValueOffsetPair set for \p S. \p S can be represented
4536/// by the value and offset from any ValueOffsetPair in the set.
4537ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4538 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4539 if (SI == ExprValueMap.end())
4540 return std::nullopt;
4541 return SI->second.getArrayRef();
4542}
4543
4544/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4545/// cannot be used separately. eraseValueFromMap should be used to remove
4546/// V from ValueExprMap and ExprValueMap at the same time.
4547void ScalarEvolution::eraseValueFromMap(Value *V) {
4548 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4549 if (I != ValueExprMap.end()) {
4550 auto EVIt = ExprValueMap.find(I->second);
4551 bool Removed = EVIt->second.remove(V);
4552 (void) Removed;
4553 assert(Removed && "Value not in ExprValueMap?");
4554 ValueExprMap.erase(I);
4555 }
4556}
4557
4558void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4559 // A recursive query may have already computed the SCEV. It should be
4560 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4561 // inferred nowrap flags.
4562 auto It = ValueExprMap.find_as(V);
4563 if (It == ValueExprMap.end()) {
4564 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4565 ExprValueMap[S].insert(V);
4566 }
4567}
4568
4569/// Return an existing SCEV if it exists, otherwise analyze the expression and
4570/// create a new one.
4572 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4573
4574 if (const SCEV *S = getExistingSCEV(V))
4575 return S;
4576 return createSCEVIter(V);
4577}
4578
4580 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4581
4582 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4583 if (I != ValueExprMap.end()) {
4584 const SCEV *S = I->second;
4585 assert(checkValidity(S) &&
4586 "existing SCEV has not been properly invalidated");
4587 return S;
4588 }
4589 return nullptr;
4590}
4591
4592/// Return a SCEV corresponding to -V = -1*V
4593const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4594 SCEV::NoWrapFlags Flags) {
4595 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4596 return getConstant(
4597 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4598
4599 Type *Ty = V->getType();
4600 Ty = getEffectiveSCEVType(Ty);
4601 return getMulExpr(V, getMinusOne(Ty), Flags);
4602}
4603
4604/// If Expr computes ~A, return A else return nullptr
4605static const SCEV *MatchNotExpr(const SCEV *Expr) {
4606 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4607 if (!Add || Add->getNumOperands() != 2 ||
4608 !Add->getOperand(0)->isAllOnesValue())
4609 return nullptr;
4610
4611 const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4612 if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4613 !AddRHS->getOperand(0)->isAllOnesValue())
4614 return nullptr;
4615
4616 return AddRHS->getOperand(1);
4617}
4618
4619/// Return a SCEV corresponding to ~V = -1-V
4621 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4622
4623 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4624 return getConstant(
4625 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4626
4627 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4628 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4629 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4630 SmallVector<const SCEV *, 2> MatchedOperands;
4631 for (const SCEV *Operand : MME->operands()) {
4632 const SCEV *Matched = MatchNotExpr(Operand);
4633 if (!Matched)
4634 return (const SCEV *)nullptr;
4635 MatchedOperands.push_back(Matched);
4636 }
4637 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4638 MatchedOperands);
4639 };
4640 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4641 return Replaced;
4642 }
4643
4644 Type *Ty = V->getType();
4645 Ty = getEffectiveSCEVType(Ty);
4646 return getMinusSCEV(getMinusOne(Ty), V);
4647}
4648
4649const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4650 assert(P->getType()->isPointerTy());
4651
4652 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4653 // The base of an AddRec is the first operand.
4654 SmallVector<const SCEV *> Ops{AddRec->operands()};
4655 Ops[0] = removePointerBase(Ops[0]);
4656 // Don't try to transfer nowrap flags for now. We could in some cases
4657 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4658 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4659 }
4660 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4661 // The base of an Add is the pointer operand.
4662 SmallVector<const SCEV *> Ops{Add->operands()};
4663 const SCEV **PtrOp = nullptr;
4664 for (const SCEV *&AddOp : Ops) {
4665 if (AddOp->getType()->isPointerTy()) {
4666 assert(!PtrOp && "Cannot have multiple pointer ops");
4667 PtrOp = &AddOp;
4668 }
4669 }
4670 *PtrOp = removePointerBase(*PtrOp);
4671 // Don't try to transfer nowrap flags for now. We could in some cases
4672 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4673 return getAddExpr(Ops);
4674 }
4675 // Any other expression must be a pointer base.
4676 return getZero(P->getType());
4677}
4678
4679const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4680 SCEV::NoWrapFlags Flags,
4681 unsigned Depth) {
4682 // Fast path: X - X --> 0.
4683 if (LHS == RHS)
4684 return getZero(LHS->getType());
4685
4686 // If we subtract two pointers with different pointer bases, bail.
4687 // Eventually, we're going to add an assertion to getMulExpr that we
4688 // can't multiply by a pointer.
4689 if (RHS->getType()->isPointerTy()) {
4690 if (!LHS->getType()->isPointerTy() ||
4691 getPointerBase(LHS) != getPointerBase(RHS))
4692 return getCouldNotCompute();
4693 LHS = removePointerBase(LHS);
4694 RHS = removePointerBase(RHS);
4695 }
4696
4697 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4698 // makes it so that we cannot make much use of NUW.
4699 auto AddFlags = SCEV::FlagAnyWrap;
4700 const bool RHSIsNotMinSigned =
4701 !getSignedRangeMin(RHS).isMinSignedValue();
4702 if (hasFlags(Flags, SCEV::FlagNSW)) {
4703 // Let M be the minimum representable signed value. Then (-1)*RHS
4704 // signed-wraps if and only if RHS is M. That can happen even for
4705 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4706 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4707 // (-1)*RHS, we need to prove that RHS != M.
4708 //
4709 // If LHS is non-negative and we know that LHS - RHS does not
4710 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4711 // either by proving that RHS > M or that LHS >= 0.
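 // For instance, with i8 operands M is -128: (-1) * -128 wraps back to
 // -128, even though -1 - (-128) == 127 does not wrap.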
4712 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4713 AddFlags = SCEV::FlagNSW;
4714 }
4715 }
4716
4717 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4718 // RHS is NSW and LHS >= 0.
4719 //
4720 // The difficulty here is that the NSW flag may have been proven
4721 // relative to a loop that is to be found in a recurrence in LHS and
4722 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4723 // larger scope than intended.
4724 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4725
4726 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4727}
4728
4729const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4730 unsigned Depth) {
4731 Type *SrcTy = V->getType();
4732 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4733 "Cannot truncate or zero extend with non-integer arguments!");
4734 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4735 return V; // No conversion
4736 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4737 return getTruncateExpr(V, Ty, Depth);
4738 return getZeroExtendExpr(V, Ty, Depth);
4739}
4740
4741const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4742 unsigned Depth) {
4743 Type *SrcTy = V->getType();
4744 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4745 "Cannot truncate or zero extend with non-integer arguments!");
4746 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4747 return V; // No conversion
4748 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4749 return getTruncateExpr(V, Ty, Depth);
4750 return getSignExtendExpr(V, Ty, Depth);
4751}
4752
4753const SCEV *
4754ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot noop or zero extend with non-integer arguments!");
4758 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4759 "getNoopOrZeroExtend cannot truncate!");
4760 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4761 return V; // No conversion
4762 return getZeroExtendExpr(V, Ty);
4763}
4764
4765const SCEV *
4766ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot noop or sign extend with non-integer arguments!");
4770 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4771 "getNoopOrSignExtend cannot truncate!");
4772 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4773 return V; // No conversion
4774 return getSignExtendExpr(V, Ty);
4775}
4776
4777const SCEV *
4778ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4779 Type *SrcTy = V->getType();
4780 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4781 "Cannot noop or any extend with non-integer arguments!");
4782 assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4783 "getNoopOrAnyExtend cannot truncate!");
4784 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4785 return V; // No conversion
4786 return getAnyExtendExpr(V, Ty);
4787}
4788
4789const SCEV *
4790ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4791 Type *SrcTy = V->getType();
4792 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4793 "Cannot truncate or noop with non-integer arguments!");
4794 assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4795 "getTruncateOrNoop cannot extend!");
4796 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4797 return V; // No conversion
4798 return getTruncateExpr(V, Ty);
4799}
4800
4801const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4802 const SCEV *RHS) {
4803 const SCEV *PromotedLHS = LHS;
4804 const SCEV *PromotedRHS = RHS;
4805
4806 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4807 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4808 else
4809 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4810
4811 return getUMaxExpr(PromotedLHS, PromotedRHS);
4812}
4813
4814const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4815 const SCEV *RHS,
4816 bool Sequential) {
4817 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4818 return getUMinFromMismatchedTypes(Ops, Sequential);
4819}
4820
4821const SCEV *
4822ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4823 bool Sequential) {
4824 assert(!Ops.empty() && "At least one operand must be!");
4825 // Trivial case.
4826 if (Ops.size() == 1)
4827 return Ops[0];
4828
4829 // Find the max type first.
4830 Type *MaxType = nullptr;
4831 for (const auto *S : Ops)
4832 if (MaxType)
4833 MaxType = getWiderType(MaxType, S->getType());
4834 else
4835 MaxType = S->getType();
4836 assert(MaxType && "Failed to find maximum type!");
4837
4838 // Extend all ops to max type.
4839 SmallVector<const SCEV *, 2> PromotedOps;
4840 for (const auto *S : Ops)
4841 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4842
4843 // Generate umin.
4844 return getUMinExpr(PromotedOps, Sequential);
4845}
4846
4847const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4848 // A pointer operand may evaluate to a nonpointer expression, such as null.
4849 if (!V->getType()->isPointerTy())
4850 return V;
4851
4852 while (true) {
4853 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4854 V = AddRec->getStart();
4855 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4856 const SCEV *PtrOp = nullptr;
4857 for (const SCEV *AddOp : Add->operands()) {
4858 if (AddOp->getType()->isPointerTy()) {
4859 assert(!PtrOp && "Cannot have multiple pointer ops");
4860 PtrOp = AddOp;
4861 }
4862 }
4863 assert(PtrOp && "Must have pointer op");
4864 V = PtrOp;
4865 } else // Not something we can look further into.
4866 return V;
4867 }
4868}
4869
4870/// Push users of the given Instruction onto the given Worklist.
4871static void PushDefUseChildren(Instruction *I,
4872 SmallVectorImpl<Instruction *> &Worklist,
4873 SmallPtrSetImpl<Instruction *> &Visited) {
4874 // Push the def-use children onto the Worklist stack.
4875 for (User *U : I->users()) {
4876 auto *UserInsn = cast<Instruction>(U);
4877 if (Visited.insert(UserInsn).second)
4878 Worklist.push_back(UserInsn);
4879 }
4880}
4881
4882namespace {
4883
4884 /// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4885 /// use its start expression. If the loop is not L, the AddRec itself is used
4886 /// when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4887 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be
4888 /// done.
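/// For example (hypothetical SCEVs): rewriting {%start,+,4}<%L> with respect to
/// L yields %start, while a sub-expression {%a,+,1}<%Other> from another loop is
/// kept as-is when IgnoreOtherLoops is true and otherwise causes the whole
/// rewrite to fail.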
4889class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4890public:
4891 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4892 bool IgnoreOtherLoops = true) {
4893 SCEVInitRewriter Rewriter(L, SE);
4894 const SCEV *Result = Rewriter.visit(S);
4895 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4896 return SE.getCouldNotCompute();
4897 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4898 ? SE.getCouldNotCompute()
4899 : Result;
4900 }
4901
4902 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4903 if (!SE.isLoopInvariant(Expr, L))
4904 SeenLoopVariantSCEVUnknown = true;
4905 return Expr;
4906 }
4907
4908 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4909 // Only re-write AddRecExprs for this loop.
4910 if (Expr->getLoop() == L)
4911 return Expr->getStart();
4912 SeenOtherLoops = true;
4913 return Expr;
4914 }
4915
4916 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4917
4918 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4919
4920private:
4921 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4922 : SCEVRewriteVisitor(SE), L(L) {}
4923
4924 const Loop *L;
4925 bool SeenLoopVariantSCEVUnknown = false;
4926 bool SeenOtherLoops = false;
4927};
4928
4929 /// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4930 /// use its post-increment expression; if the loop is not L, the AddRec itself
4931 /// is used.
4932 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
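/// For example (a hypothetical SCEV): {%start,+,%step}<%L> is rewritten to
/// {%start + %step,+,%step}<%L>, i.e. the value after one more increment.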
4933class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4934public:
4935 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4936 SCEVPostIncRewriter Rewriter(L, SE);
4937 const SCEV *Result = Rewriter.visit(S);
4938 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4939 ? SE.getCouldNotCompute()
4940 : Result;
4941 }
4942
4943 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4944 if (!SE.isLoopInvariant(Expr, L))
4945 SeenLoopVariantSCEVUnknown = true;
4946 return Expr;
4947 }
4948
4949 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4950 // Only re-write AddRecExprs for this loop.
4951 if (Expr->getLoop() == L)
4952 return Expr->getPostIncExpr(SE);
4953 SeenOtherLoops = true;
4954 return Expr;
4955 }
4956
4957 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4958
4959 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4960
4961private:
4962 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4963 : SCEVRewriteVisitor(SE), L(L) {}
4964
4965 const Loop *L;
4966 bool SeenLoopVariantSCEVUnknown = false;
4967 bool SeenOtherLoops = false;
4968};
4969
4970 /// This class evaluates the compare condition by matching it against the
4971 /// condition of the loop latch. If there is a match, we assume a true value
4972 /// for the condition while building SCEV nodes.
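/// For example (hypothetical IR): if the latch branches back to the header when
/// "%cond" is true, a "select i1 %cond, i64 %a, i64 %b" feeding the backedge
/// value can be treated as %a while its SCEV is built.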
4973class SCEVBackedgeConditionFolder
4974 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4975public:
4976 static const SCEV *rewrite(const SCEV *S, const Loop *L,
4977 ScalarEvolution &SE) {
4978 bool IsPosBECond = false;
4979 Value *BECond = nullptr;
4980 if (BasicBlock *Latch = L->getLoopLatch()) {
4981 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4982 if (BI && BI->isConditional()) {
4983 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4984 "Both outgoing branches should not target same header!");
4985 BECond = BI->getCondition();
4986 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4987 } else {
4988 return S;
4989 }
4990 }
4991 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4992 return Rewriter.visit(S);
4993 }
4994
4995 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4996 const SCEV *Result = Expr;
4997 bool InvariantF = SE.isLoopInvariant(Expr, L);
4998
4999 if (!InvariantF) {
5000 Instruction *I = cast<Instruction>(Expr->getValue());
5001 switch (I->getOpcode()) {
5002 case Instruction::Select: {
5003 SelectInst *SI = cast<SelectInst>(I);
5004 std::optional<const SCEV *> Res =
5005 compareWithBackedgeCondition(SI->getCondition());
5006 if (Res) {
5007 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5008 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5009 }
5010 break;
5011 }
5012 default: {
5013 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5014 if (Res)
5015 Result = *Res;
5016 break;
5017 }
5018 }
5019 }
5020 return Result;
5021 }
5022
5023private:
5024 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5025 bool IsPosBECond, ScalarEvolution &SE)
5026 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5027 IsPositiveBECond(IsPosBECond) {}
5028
5029 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5030
5031 const Loop *L;
5032 /// Loop back condition.
5033 Value *BackedgeCond = nullptr;
5034 /// Set to true if loop back is on positive branch condition.
5035 bool IsPositiveBECond;
5036};
5037
5038std::optional<const SCEV *>
5039SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5040
5041 // If value matches the backedge condition for loop latch,
5042 // then return a constant evolution node based on loopback
5043 // branch taken.
5044 if (BackedgeCond == IC)
5045 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5046 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5047 return std::nullopt;
5048}
5049
5050class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5051public:
5052 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5053 ScalarEvolution &SE) {
5054 SCEVShiftRewriter Rewriter(L, SE);
5055 const SCEV *Result = Rewriter.visit(S);
5056 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5057 }
5058
5059 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5060 // Only allow AddRecExprs for this loop.
5061 if (!SE.isLoopInvariant(Expr, L))
5062 Valid = false;
5063 return Expr;
5064 }
5065
5066 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5067 if (Expr->getLoop() == L && Expr->isAffine())
5068 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5069 Valid = false;
5070 return Expr;
5071 }
5072
5073 bool isValid() { return Valid; }
5074
5075private:
5076 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5077 : SCEVRewriteVisitor(SE), L(L) {}
5078
5079 const Loop *L;
5080 bool Valid = true;
5081};
5082
5083} // end anonymous namespace
5084
5085SCEV::NoWrapFlags
5086ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5087 if (!AR->isAffine())
5088 return SCEV::FlagAnyWrap;
5089
5090 using OBO = OverflowingBinaryOperator;
5091
5092 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5093
5094 if (!AR->hasNoSelfWrap()) {
5095 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5096 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5097 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5098 const APInt &BECountAP = BECountMax->getAPInt();
5099 unsigned NoOverflowBitWidth =
5100 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5101 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5102 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5103 }
5104 }
5105
5106 if (!AR->hasNoSignedWrap()) {
5107 ConstantRange AddRecRange = getSignedRange(AR);
5108 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5109
5110 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5111 Instruction::Add, IncRange, OBO::NoSignedWrap);
5112 if (NSWRegion.contains(AddRecRange))
5113 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5114 }
5115
5116 if (!AR->hasNoUnsignedWrap()) {
5117 ConstantRange AddRecRange = getUnsignedRange(AR);
5118 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5119
5120 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5121 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5122 if (NUWRegion.contains(AddRecRange))
5123 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5124 }
5125
5126 return Result;
5127}
5128
5129SCEV::NoWrapFlags
5130ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5131 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5132
5133 if (AR->hasNoSignedWrap())
5134 return Result;
5135
5136 if (!AR->isAffine())
5137 return Result;
5138
5139 // This function can be expensive, only try to prove NSW once per AddRec.
5140 if (!SignedWrapViaInductionTried.insert(AR).second)
5141 return Result;
5142
5143 const SCEV *Step = AR->getStepRecurrence(*this);
5144 const Loop *L = AR->getLoop();
5145
5146 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5147 // Note that this serves two purposes: It filters out loops that are
5148 // simply not analyzable, and it covers the case where this code is
5149 // being called from within backedge-taken count analysis, such that
5150 // attempting to ask for the backedge-taken count would likely result
5151 // in infinite recursion. In the latter case, the analysis code will
5152 // cope with a conservative value, and it will take care to purge
5153 // that value once it has finished.
5154 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5155
5156 // Normally, in the cases we can prove no-overflow via a
5157 // backedge guarding condition, we can also compute a backedge
5158 // taken count for the loop. The exceptions are assumptions and
5159 // guards present in the loop -- SCEV is not great at exploiting
5160 // these to compute max backedge taken counts, but can still use
5161 // these to prove lack of overflow. Use this fact to avoid
5162 // doing extra work that may not pay off.
5163
5164 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5165 AC.assumptions().empty())
5166 return Result;
5167
5168 // If the backedge is guarded by a comparison with the pre-inc value the
5169 // addrec is safe. Also, if the entry is guarded by a comparison with the
5170 // start value and the backedge is guarded by a comparison with the post-inc
5171 // value, the addrec is safe.
5172 ICmpInst::Predicate Pred;
5173 const SCEV *OverflowLimit =
5174 getSignedOverflowLimitForStep(Step, &Pred, this);
5175 if (OverflowLimit &&
5176 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5177 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5178 Result = setFlags(Result, SCEV::FlagNSW);
5179 }
5180 return Result;
5181}
5182SCEV::NoWrapFlags
5183ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5184 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5185
5186 if (AR->hasNoUnsignedWrap())
5187 return Result;
5188
5189 if (!AR->isAffine())
5190 return Result;
5191
5192 // This function can be expensive, only try to prove NUW once per AddRec.
5193 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5194 return Result;
5195
5196 const SCEV *Step = AR->getStepRecurrence(*this);
5197 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5198 const Loop *L = AR->getLoop();
5199
5200 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5201 // Note that this serves two purposes: It filters out loops that are
5202 // simply not analyzable, and it covers the case where this code is
5203 // being called from within backedge-taken count analysis, such that
5204 // attempting to ask for the backedge-taken count would likely result
5205 // in infinite recursion. In the latter case, the analysis code will
5206 // cope with a conservative value, and it will take care to purge
5207 // that value once it has finished.
5208 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5209
5210 // Normally, in the cases we can prove no-overflow via a
5211 // backedge guarding condition, we can also compute a backedge
5212 // taken count for the loop. The exceptions are assumptions and
5213 // guards present in the loop -- SCEV is not great at exploiting
5214 // these to compute max backedge taken counts, but can still use
5215 // these to prove lack of overflow. Use this fact to avoid
5216 // doing extra work that may not pay off.
5217
5218 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5219 AC.assumptions().empty())
5220 return Result;
5221
5222 // If the backedge is guarded by a comparison with the pre-inc value the
5223 // addrec is safe. Also, if the entry is guarded by a comparison with the
5224 // start value and the backedge is guarded by a comparison with the post-inc
5225 // value, the addrec is safe.
5226 if (isKnownPositive(Step)) {
5227 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5228 getUnsignedRangeMax(Step));
5229 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5230 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5231 Result = setFlags(Result, SCEV::FlagNUW);
5232 }
5233 }
5234
5235 return Result;
5236}
5237
5238namespace {
5239
5240/// Represents an abstract binary operation. This may exist as a
5241/// normal instruction or constant expression, or may have been
5242/// derived from an expression tree.
5243struct BinaryOp {
5244 unsigned Opcode;
5245 Value *LHS;
5246 Value *RHS;
5247 bool IsNSW = false;
5248 bool IsNUW = false;
5249
5250 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5251 /// constant expression.
5252 Operator *Op = nullptr;
5253
5254 explicit BinaryOp(Operator *Op)
5255 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5256 Op(Op) {
5257 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5258 IsNSW = OBO->hasNoSignedWrap();
5259 IsNUW = OBO->hasNoUnsignedWrap();
5260 }
5261 }
5262
5263 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5264 bool IsNUW = false)
5265 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5266};
5267
5268} // end anonymous namespace
5269
5270/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5271static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5272 AssumptionCache &AC,
5273 const DominatorTree &DT,
5274 const Instruction *CxtI) {
5275 auto *Op = dyn_cast<Operator>(V);
5276 if (!Op)
5277 return std::nullopt;
5278
5279 // Implementation detail: all the cleverness here should happen without
5280 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5281 // SCEV expressions when possible, and we should not break that.
5282
5283 switch (Op->getOpcode()) {
5284 case Instruction::Add:
5285 case Instruction::Sub:
5286 case Instruction::Mul:
5287 case Instruction::UDiv:
5288 case Instruction::URem:
5289 case Instruction::And:
5290 case Instruction::AShr:
5291 case Instruction::Shl:
5292 return BinaryOp(Op);
5293
5294 case Instruction::Or: {
5295 // Convert or disjoint into add nuw nsw.
5296 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5297 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5298 /*IsNSW=*/true, /*IsNUW=*/true);
5299 return BinaryOp(Op);
5300 }
5301
5302 case Instruction::Xor:
5303 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5304 // If the RHS of the xor is a signmask, then this is just an add.
5305 // Instcombine turns add of signmask into xor as a strength reduction step.
5306 if (RHSC->getValue().isSignMask())
5307 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5308 // Binary `xor` is a bit-wise `add`.
5309 if (V->getType()->isIntegerTy(1))
5310 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5311 return BinaryOp(Op);
5312
5313 case Instruction::LShr:
5314 // Turn logical shift right of a constant into an unsigned divide.
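    // For example (hypothetical IR): "lshr i32 %x, 3" is mapped to the binary
    // op "%x /u 8", since a logical shift right by 3 divides by 2^3.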
5315 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5316 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5317
5318 // If the shift count is not less than the bitwidth, the result of
5319 // the shift is undefined. Don't try to analyze it, because the
5320 // resolution chosen here may differ from the resolution chosen in
5321 // other parts of the compiler.
5322 if (SA->getValue().ult(BitWidth)) {
5323 Constant *X =
5324 ConstantInt::get(SA->getContext(),
5325 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5326 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5327 }
5328 }
5329 return BinaryOp(Op);
5330
5331 case Instruction::ExtractValue: {
5332 auto *EVI = cast<ExtractValueInst>(Op);
5333 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5334 break;
5335
5336 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5337 if (!WO)
5338 break;
5339
5340 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5341 bool Signed = WO->isSigned();
5342 // TODO: Should add nuw/nsw flags for mul as well.
5343 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5344 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5345
5346 // Now that we know that all uses of the arithmetic-result component of
5347 // CI are guarded by the overflow check, we can go ahead and pretend
5348 // that the arithmetic is non-overflowing.
5349 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5350 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5351 }
5352
5353 default:
5354 break;
5355 }
5356
5357 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5358 // semantics as a Sub, return a binary sub expression.
5359 if (auto *II = dyn_cast<IntrinsicInst>(V))
5360 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5361 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5362
5363 return std::nullopt;
5364}
5365
5366/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5367/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5368/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5369/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5370/// follows one of the following patterns:
5371/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5372/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5373/// If the SCEV expression of \p Op conforms with one of the expected patterns
5374/// we return the type of the truncation operation, and indicate whether the
5375/// truncated type should be treated as signed/unsigned by setting
5376/// \p Signed to true/false, respectively.
5377static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5378 bool &Signed, ScalarEvolution &SE) {
5379 // The case where Op == SymbolicPHI (that is, with no type conversions on
5380 // the way) is handled by the regular add recurrence creating logic and
5381 // would have already been triggered in createAddRecForPHI. Reaching it here
5382 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5383 // because one of the other operands of the SCEVAddExpr updating this PHI is
5384 // not invariant).
5385 //
5386 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5387 // this case predicates that allow us to prove that Op == SymbolicPHI will
5388 // be added.
5389 if (Op == SymbolicPHI)
5390 return nullptr;
5391
5392 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5393 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5394 if (SourceBits != NewBits)
5395 return nullptr;
5396
5397 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
5398 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
5399 if (!SExt && !ZExt)
5400 return nullptr;
5401 const SCEVTruncateExpr *Trunc =
5402 SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
5403 : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
5404 if (!Trunc)
5405 return nullptr;
5406 const SCEV *X = Trunc->getOperand();
5407 if (X != SymbolicPHI)
5408 return nullptr;
5409 Signed = SExt != nullptr;
5410 return Trunc->getType();
5411}
5412
5413static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5414 if (!PN->getType()->isIntegerTy())
5415 return nullptr;
5416 const Loop *L = LI.getLoopFor(PN->getParent());
5417 if (!L || L->getHeader() != PN->getParent())
5418 return nullptr;
5419 return L;
5420}
5421
5422// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5423// computation that updates the phi follows the following pattern:
5424// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5425// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5426// If so, try to see if it can be rewritten as an AddRecExpr under some
5427// Predicates. If successful, return them as a pair. Also cache the results
5428// of the analysis.
5429//
5430// Example usage scenario:
5431// Say the Rewriter is called for the following SCEV:
5432// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5433// where:
5434// %X = phi i64 (%Start, %BEValue)
5435// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5436// and call this function with %SymbolicPHI = %X.
5437//
5438// The analysis will find that the value coming around the backedge has
5439// the following SCEV:
5440// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5441// Upon concluding that this matches the desired pattern, the function
5442// will return the pair {NewAddRec, SmallPredsVec} where:
5443// NewAddRec = {%Start,+,%Step}
5444// SmallPredsVec = {P1, P2, P3} as follows:
5445// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5446// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5447// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5448// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5449// under the predicates {P1,P2,P3}.
5450// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5451// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5452//
5453// TODO's:
5454//
5455// 1) Extend the Induction descriptor to also support inductions that involve
5456// casts: When needed (namely, when we are called in the context of the
5457// vectorizer induction analysis), a Set of cast instructions will be
5458// populated by this method, and provided back to isInductionPHI. This is
5459// needed to allow the vectorizer to properly record them to be ignored by
5460// the cost model and to avoid vectorizing them (otherwise these casts,
5461// which are redundant under the runtime overflow checks, will be
5462// vectorized, which can be costly).
5463//
5464// 2) Support additional induction/PHISCEV patterns: We also want to support
5465// inductions where the sext-trunc / zext-trunc operations (partly) occur
5466// after the induction update operation (the induction increment):
5467//
5468// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5469// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5470//
5471// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5472// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5473//
5474// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5475std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5476ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5477 SmallVector<const SCEVPredicate *, 3> Predicates;
5478
5479 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5480 // return an AddRec expression under some predicate.
5481
5482 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5483 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5484 assert(L && "Expecting an integer loop header phi");
5485
5486 // The loop may have multiple entrances or multiple exits; we can analyze
5487 // this phi as an addrec if it has a unique entry value and a unique
5488 // backedge value.
5489 Value *BEValueV = nullptr, *StartValueV = nullptr;
5490 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5491 Value *V = PN->getIncomingValue(i);
5492 if (L->contains(PN->getIncomingBlock(i))) {
5493 if (!BEValueV) {
5494 BEValueV = V;
5495 } else if (BEValueV != V) {
5496 BEValueV = nullptr;
5497 break;
5498 }
5499 } else if (!StartValueV) {
5500 StartValueV = V;
5501 } else if (StartValueV != V) {
5502 StartValueV = nullptr;
5503 break;
5504 }
5505 }
5506 if (!BEValueV || !StartValueV)
5507 return std::nullopt;
5508
5509 const SCEV *BEValue = getSCEV(BEValueV);
5510
5511 // If the value coming around the backedge is an add with the symbolic
5512 // value we just inserted, possibly with casts that we can ignore under
5513 // an appropriate runtime guard, then we found a simple induction variable!
5514 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5515 if (!Add)
5516 return std::nullopt;
5517
5518 // If there is a single occurrence of the symbolic value, possibly
5519 // casted, replace it with a recurrence.
5520 unsigned FoundIndex = Add->getNumOperands();
5521 Type *TruncTy = nullptr;
5522 bool Signed;
5523 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5524 if ((TruncTy =
5525 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5526 if (FoundIndex == e) {
5527 FoundIndex = i;
5528 break;
5529 }
5530
5531 if (FoundIndex == Add->getNumOperands())
5532 return std::nullopt;
5533
5534 // Create an add with everything but the specified operand.
5535 SmallVector<const SCEV *, 8> Ops;
5536 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5537 if (i != FoundIndex)
5538 Ops.push_back(Add->getOperand(i));
5539 const SCEV *Accum = getAddExpr(Ops);
5540
5541 // The runtime checks will not be valid if the step amount is
5542 // varying inside the loop.
5543 if (!isLoopInvariant(Accum, L))
5544 return std::nullopt;
5545
5546 // *** Part2: Create the predicates
5547
5548 // Analysis was successful: we have a phi-with-cast pattern for which we
5549 // can return an AddRec expression under the following predicates:
5550 //
5551 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5552 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5553 // P2: An Equal predicate that guarantees that
5554 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5555 // P3: An Equal predicate that guarantees that
5556 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5557 //
5558 // As we next prove, the above predicates guarantee that:
5559 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5560 //
5561 //
5562 // More formally, we want to prove that:
5563 // Expr(i+1) = Start + (i+1) * Accum
5564 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5565 //
5566 // Given that:
5567 // 1) Expr(0) = Start
5568 // 2) Expr(1) = Start + Accum
5569 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5570 // 3) Induction hypothesis (step i):
5571 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5572 //
5573 // Proof:
5574 // Expr(i+1) =
5575 // = Start + (i+1)*Accum
5576 // = (Start + i*Accum) + Accum
5577 // = Expr(i) + Accum
5578 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5579 // :: from step i
5580 //
5581 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5582 //
5583 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5584 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5585 // + Accum :: from P3
5586 //
5587 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5588 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5589 //
5590 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5591 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5592 //
5593 // By induction, the same applies to all iterations 1<=i<n:
5594 //
5595
5596 // Create a truncated addrec for which we will add a no overflow check (P1).
5597 const SCEV *StartVal = getSCEV(StartValueV);
5598 const SCEV *PHISCEV =
5599 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5600 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5601
5602 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5603 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5604 // will be constant.
5605 //
5606 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5607 // add P1.
5608 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5609 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5610 Signed ? SCEVWrapPredicate::IncrementNSSW
5611 : SCEVWrapPredicate::IncrementNUSW;
5612 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5613 Predicates.push_back(AddRecPred);
5614 }
5615
5616 // Create the Equal Predicates P2,P3:
5617
5618 // It is possible that the predicates P2 and/or P3 are computable at
5619 // compile time due to StartVal and/or Accum being constants.
5620 // If either one is, then we can check that now and escape if either P2
5621 // or P3 is false.
5622
5623 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5624 // for each of StartVal and Accum
5625 auto getExtendedExpr = [&](const SCEV *Expr,
5626 bool CreateSignExtend) -> const SCEV * {
5627 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5628 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5629 const SCEV *ExtendedExpr =
5630 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5631 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5632 return ExtendedExpr;
5633 };
5634
5635 // Given:
5636 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5637 // = getExtendedExpr(Expr)
5638 // Determine whether the predicate P: Expr == ExtendedExpr
5639 // is known to be false at compile time
5640 auto PredIsKnownFalse = [&](const SCEV *Expr,
5641 const SCEV *ExtendedExpr) -> bool {
5642 return Expr != ExtendedExpr &&
5643 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5644 };
5645
5646 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5647 if (PredIsKnownFalse(StartVal, StartExtended)) {
5648 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5649 return std::nullopt;
5650 }
5651
5652 // The Step is always Signed (because the overflow checks are either
5653 // NSSW or NUSW)
5654 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5655 if (PredIsKnownFalse(Accum, AccumExtended)) {
5656 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5657 return std::nullopt;
5658 }
5659
5660 auto AppendPredicate = [&](const SCEV *Expr,
5661 const SCEV *ExtendedExpr) -> void {
5662 if (Expr != ExtendedExpr &&
5663 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5664 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5665 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5666 Predicates.push_back(Pred);
5667 }
5668 };
5669
5670 AppendPredicate(StartVal, StartExtended);
5671 AppendPredicate(Accum, AccumExtended);
5672
5673 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5674 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5675 // into NewAR if it will also add the runtime overflow checks specified in
5676 // Predicates.
5677 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5678
5679 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5680 std::make_pair(NewAR, Predicates);
5681 // Remember the result of the analysis for this SCEV at this location.
5682 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5683 return PredRewrite;
5684}
5685
5686std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5687ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5688 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5689 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5690 if (!L)
5691 return std::nullopt;
5692
5693 // Check to see if we already analyzed this PHI.
5694 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5695 if (I != PredicatedSCEVRewrites.end()) {
5696 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5697 I->second;
5698 // Analysis was done before and failed to create an AddRec:
5699 if (Rewrite.first == SymbolicPHI)
5700 return std::nullopt;
5701 // Analysis was done before and succeeded to create an AddRec under
5702 // a predicate:
5703 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5704 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5705 return Rewrite;
5706 }
5707
5708 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5709 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5710
5711 // Record in the cache that the analysis failed
5712 if (!Rewrite) {
5713 SmallVector<const SCEVPredicate *, 3> Predicates;
5714 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5715 return std::nullopt;
5716 }
5717
5718 return Rewrite;
5719}
5720
5721// FIXME: This utility is currently required because the Rewriter currently
5722// does not rewrite this expression:
5723// {0, +, (sext ix (trunc iy to ix) to iy)}
5724// into {0, +, %step},
5725// even when the following Equal predicate exists:
5726// "%step == (sext ix (trunc iy to ix) to iy)".
5727bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5728 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5729 if (AR1 == AR2)
5730 return true;
5731
5732 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5733 if (Expr1 != Expr2 && !Preds->implies(SE.getEqualPredicate(Expr1, Expr2)) &&
5734 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1)))
5735 return false;
5736 return true;
5737 };
5738
5739 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5740 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5741 return false;
5742 return true;
5743}
5744
5745/// A helper function for createAddRecFromPHI to handle simple cases.
5746///
5747/// This function tries to find an AddRec expression for the simplest (yet most
5748/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5749/// If it fails, createAddRecFromPHI will use a more general, but slow,
5750/// technique for finding the AddRec expression.
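/// For example (hypothetical IR): "%iv = phi i64 [ 0, %ph ], [ %iv.next, %latch ]"
/// with "%iv.next = add nuw nsw i64 %iv, 4" yields {0,+,4}<nuw><nsw> for the
/// enclosing loop, with the no-wrap flags taken from the increment instruction.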
5751const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5752 Value *BEValueV,
5753 Value *StartValueV) {
5754 const Loop *L = LI.getLoopFor(PN->getParent());
5755 assert(L && L->getHeader() == PN->getParent());
5756 assert(BEValueV && StartValueV);
5757
5758 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5759 if (!BO)
5760 return nullptr;
5761
5762 if (BO->Opcode != Instruction::Add)
5763 return nullptr;
5764
5765 const SCEV *Accum = nullptr;
5766 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5767 Accum = getSCEV(BO->RHS);
5768 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5769 Accum = getSCEV(BO->LHS);
5770
5771 if (!Accum)
5772 return nullptr;
5773
5774 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5775 if (BO->IsNUW)
5776 Flags = setFlags(Flags, SCEV::FlagNUW);
5777 if (BO->IsNSW)
5778 Flags = setFlags(Flags, SCEV::FlagNSW);
5779
5780 const SCEV *StartVal = getSCEV(StartValueV);
5781 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5782 insertValueToMap(PN, PHISCEV);
5783
5784 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5785 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5786 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5787 proveNoWrapViaConstantRanges(AR)));
5788 }
5789
5790 // We can add Flags to the post-inc expression only if we
5791 // know that it is *undefined behavior* for BEValueV to
5792 // overflow.
5793 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5794 assert(isLoopInvariant(Accum, L) &&
5795 "Accum is defined outside L, but is not invariant?");
5796 if (isAddRecNeverPoison(BEInst, L))
5797 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5798 }
5799
5800 return PHISCEV;
5801}
5802
5803const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5804 const Loop *L = LI.getLoopFor(PN->getParent());
5805 if (!L || L->getHeader() != PN->getParent())
5806 return nullptr;
5807
5808 // The loop may have multiple entrances or multiple exits; we can analyze
5809 // this phi as an addrec if it has a unique entry value and a unique
5810 // backedge value.
5811 Value *BEValueV = nullptr, *StartValueV = nullptr;
5812 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5813 Value *V = PN->getIncomingValue(i);
5814 if (L->contains(PN->getIncomingBlock(i))) {
5815 if (!BEValueV) {
5816 BEValueV = V;
5817 } else if (BEValueV != V) {
5818 BEValueV = nullptr;
5819 break;
5820 }
5821 } else if (!StartValueV) {
5822 StartValueV = V;
5823 } else if (StartValueV != V) {
5824 StartValueV = nullptr;
5825 break;
5826 }
5827 }
5828 if (!BEValueV || !StartValueV)
5829 return nullptr;
5830
5831 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5832 "PHI node already processed?");
5833
5834 // First, try to find an AddRec expression without creating a fictitious
5835 // symbolic value for PN.
5836 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5837 return S;
5838
5839 // Handle PHI node value symbolically.
5840 const SCEV *SymbolicName = getUnknown(PN);
5841 insertValueToMap(PN, SymbolicName);
5842
5843 // Using this symbolic name for the PHI, analyze the value coming around
5844 // the back-edge.
5845 const SCEV *BEValue = getSCEV(BEValueV);
5846
5847 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5848 // has a special value for the first iteration of the loop.
5849
5850 // If the value coming around the backedge is an add with the symbolic
5851 // value we just inserted, then we found a simple induction variable!
5852 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5853 // If there is a single occurrence of the symbolic value, replace it
5854 // with a recurrence.
5855 unsigned FoundIndex = Add->getNumOperands();
5856 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5857 if (Add->getOperand(i) == SymbolicName)
5858 if (FoundIndex == e) {
5859 FoundIndex = i;
5860 break;
5861 }
5862
5863 if (FoundIndex != Add->getNumOperands()) {
5864 // Create an add with everything but the specified operand.
5865 SmallVector<const SCEV *, 8> Ops;
5866 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5867 if (i != FoundIndex)
5868 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5869 L, *this));
5870 const SCEV *Accum = getAddExpr(Ops);
5871
5872 // This is not a valid addrec if the step amount is varying each
5873 // loop iteration, but is not itself an addrec in this loop.
5874 if (isLoopInvariant(Accum, L) ||
5875 (isa<SCEVAddRecExpr>(Accum) &&
5876 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5877 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5878
5879 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5880 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5881 if (BO->IsNUW)
5882 Flags = setFlags(Flags, SCEV::FlagNUW);
5883 if (BO->IsNSW)
5884 Flags = setFlags(Flags, SCEV::FlagNSW);
5885 }
5886 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5887 // If the increment is an inbounds GEP, then we know the address
5888 // space cannot be wrapped around. We cannot make any guarantee
5889 // about signed or unsigned overflow because pointers are
5890 // unsigned but we may have a negative index from the base
5891 // pointer. We can guarantee that no unsigned wrap occurs if the
5892 // indices form a positive value.
5893 if (GEP->isInBounds() && GEP->getOperand(0) == PN) {
5894 Flags = setFlags(Flags, SCEV::FlagNW);
5895 if (isKnownPositive(Accum))
5896 Flags = setFlags(Flags, SCEV::FlagNUW);
5897 }
5898
5899 // We cannot transfer nuw and nsw flags from subtraction
5900 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5901 // for instance.
5902 }
5903
5904 const SCEV *StartVal = getSCEV(StartValueV);
5905 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5906
5907 // Okay, for the entire analysis of this edge we assumed the PHI
5908 // to be symbolic. We now need to go back and purge all of the
5909 // entries for the scalars that use the symbolic expression.
5910 forgetMemoizedResults(SymbolicName);
5911 insertValueToMap(PN, PHISCEV);
5912
5913 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5914 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5915 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5916 proveNoWrapViaConstantRanges(AR)));
5917 }
5918
5919 // We can add Flags to the post-inc expression only if we
5920 // know that it is *undefined behavior* for BEValueV to
5921 // overflow.
5922 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5923 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5924 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5925
5926 return PHISCEV;
5927 }
5928 }
5929 } else {
5930 // Otherwise, this could be a loop like this:
5931 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5932 // In this case, j = {1,+,1} and BEValue is j.
5933 // Because the other in-value of i (0) fits the evolution of BEValue
5934 // i really is an addrec evolution.
5935 //
5936 // We can generalize this saying that i is the shifted value of BEValue
5937 // by one iteration:
5938 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5939 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5940 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5941 if (Shifted != getCouldNotCompute() &&
5942 Start != getCouldNotCompute()) {
5943 const SCEV *StartVal = getSCEV(StartValueV);
5944 if (Start == StartVal) {
5945 // Okay, for the entire analysis of this edge we assumed the PHI
5946 // to be symbolic. We now need to go back and purge all of the
5947 // entries for the scalars that use the symbolic expression.
5948 forgetMemoizedResults(SymbolicName);
5949 insertValueToMap(PN, Shifted);
5950 return Shifted;
5951 }
5952 }
5953 }
5954
5955 // Remove the temporary PHI node SCEV that has been inserted while intending
5956 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5957 // as it will prevent later (possibly simpler) SCEV expressions from being
5958 // added to the ValueExprMap.
5959 eraseValueFromMap(PN);
5960
5961 return nullptr;
5962}
5963
5964// Try to match a control flow sequence that branches out at BI and merges back
5965// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5966// match.
5967static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5968 Value *&C, Value *&LHS, Value *&RHS) {
5969 C = BI->getCondition();
5970
5971 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5972 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5973
5974 if (!LeftEdge.isSingleEdge())
5975 return false;
5976
5977 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
5978
5979 Use &LeftUse = Merge->getOperandUse(0);
5980 Use &RightUse = Merge->getOperandUse(1);
5981
5982 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
5983 LHS = LeftUse;
5984 RHS = RightUse;
5985 return true;
5986 }
5987
5988 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
5989 LHS = RightUse;
5990 RHS = LeftUse;
5991 return true;
5992 }
5993
5994 return false;
5995}
5996
5997const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
5998 auto IsReachable =
5999 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6000 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6001 // Try to match
6002 //
6003 // br %cond, label %left, label %right
6004 // left:
6005 // br label %merge
6006 // right:
6007 // br label %merge
6008 // merge:
6009 // V = phi [ %x, %left ], [ %y, %right ]
6010 //
6011 // as "select %cond, %x, %y"
6012
6013 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6014 assert(IDom && "At least the entry block should dominate PN");
6015
6016 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6017 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6018
6019 if (BI && BI->isConditional() &&
6020 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6021 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6022 properlyDominates(getSCEV(RHS), PN->getParent()))
6023 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6024 }
6025
6026 return nullptr;
6027}
6028
6029const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6030 if (const SCEV *S = createAddRecFromPHI(PN))
6031 return S;
6032
6033 if (Value *V = simplifyInstruction(PN, {getDataLayout(), &TLI, &DT, &AC}))
6034 return getSCEV(V);
6035
6036 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6037 return S;
6038
6039 // If it's not a loop phi, we can't handle it yet.
6040 return getUnknown(PN);
6041}
6042
6043bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6044 SCEVTypes RootKind) {
6045 struct FindClosure {
6046 const SCEV *OperandToFind;
6047 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6048 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6049
6050 bool Found = false;
6051
6052 bool canRecurseInto(SCEVTypes Kind) const {
6053 // We can only recurse into the SCEV expression of the same effective type
6054 // as the type of our root SCEV expression, and into zero-extensions.
6055 return RootKind == Kind || NonSequentialRootKind == Kind ||
6056 scZeroExtend == Kind;
6057 };
6058
6059 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6060 : OperandToFind(OperandToFind), RootKind(RootKind),
6061 NonSequentialRootKind(
6062 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6063 RootKind)) {}
6064
6065 bool follow(const SCEV *S) {
6066 Found = S == OperandToFind;
6067
6068 return !isDone() && canRecurseInto(S->getSCEVType());
6069 }
6070
6071 bool isDone() const { return Found; }
6072 };
6073
6074 FindClosure FC(OperandToFind, RootKind);
6075 visitAll(Root, FC);
6076 return FC.Found;
6077}
6078
6079std::optional<const SCEV *>
6080ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6081 ICmpInst *Cond,
6082 Value *TrueVal,
6083 Value *FalseVal) {
6084 // Try to match some simple smax or umax patterns.
6085 auto *ICI = Cond;
6086
6087 Value *LHS = ICI->getOperand(0);
6088 Value *RHS = ICI->getOperand(1);
6089
6090 switch (ICI->getPredicate()) {
6091 case ICmpInst::ICMP_SLT:
6092 case ICmpInst::ICMP_SLE:
6093 case ICmpInst::ICMP_ULT:
6094 case ICmpInst::ICMP_ULE:
6095 std::swap(LHS, RHS);
6096 [[fallthrough]];
6097 case ICmpInst::ICMP_SGT:
6098 case ICmpInst::ICMP_SGE:
6099 case ICmpInst::ICMP_UGT:
6100 case ICmpInst::ICMP_UGE:
6101 // a > b ? a+x : b+x -> max(a, b)+x
6102 // a > b ? b+x : a+x -> min(a, b)+x
6103 {
6104 bool Signed = ICI->isSigned();
6105 const SCEV *LA = getSCEV(TrueVal);
6106 const SCEV *RA = getSCEV(FalseVal);
6107 const SCEV *LS = getSCEV(LHS);
6108 const SCEV *RS = getSCEV(RHS);
6109 if (LA->getType()->isPointerTy()) {
6110 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6111 // Need to make sure we can't produce weird expressions involving
6112 // negated pointers.
6113 if (LA == LS && RA == RS)
6114 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6115 if (LA == RS && RA == LS)
6116 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6117 }
6118 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6119 if (Op->getType()->isPointerTy()) {
6120 Op = getLosslessPtrToIntExpr(Op);
6121 if (isa<SCEVCouldNotCompute>(Op))
6122 return Op;
6123 }
6124 if (Signed)
6125 Op = getNoopOrSignExtend(Op, Ty);
6126 else
6127 Op = getNoopOrZeroExtend(Op, Ty);
6128 return Op;
6129 };
6130 LS = CoerceOperand(LS);
6131 RS = CoerceOperand(RS);
6132 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6133 break;
6134 const SCEV *LDiff = getMinusSCEV(LA, LS);
6135 const SCEV *RDiff = getMinusSCEV(RA, RS);
6136 if (LDiff == RDiff)
6137 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6138 LDiff);
6139 LDiff = getMinusSCEV(LA, RS);
6140 RDiff = getMinusSCEV(RA, LS);
6141 if (LDiff == RDiff)
6142 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6143 LDiff);
6144 }
6145 break;
6146 case ICmpInst::ICMP_NE:
6147 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6148 std::swap(TrueVal, FalseVal);
6149 [[fallthrough]];
6150 case ICmpInst::ICMP_EQ:
6151 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6152 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6153 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6154 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6155 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6156 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6157 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6158 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6159 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6160 return getAddExpr(getUMaxExpr(X, C), Y);
6161 }
6162 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6163 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6164 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6165 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6166 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6167 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6168 const SCEV *X = getSCEV(LHS);
6169 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6170 X = ZExt->getOperand();
6171 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6172 const SCEV *FalseValExpr = getSCEV(FalseVal);
6173 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6174 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6175 /*Sequential=*/true);
6176 }
6177 }
6178 break;
6179 default:
6180 break;
6181 }
6182
6183 return std::nullopt;
6184}
6185
6186static std::optional<const SCEV *>
6187createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6188 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6189 assert(CondExpr->getType()->isIntegerTy(1) &&
6190 TrueExpr->getType() == FalseExpr->getType() &&
6191 TrueExpr->getType()->isIntegerTy(1) &&
6192 "Unexpected operands of a select.");
6193
6194 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6195 // --> C + (umin_seq cond, x - C)
6196 //
6197 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6198 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6199 // --> C + (umin_seq ~cond, x - C)
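  // For example (hypothetical i1 values): "%c ? i1 %x : i1 false" has C = false,
  // so it becomes false + (umin_seq %c, %x - false) = (umin_seq %c, %x), a
  // poison-safe form of "%c & %x".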
6200
6201 // FIXME: while we can't legally model the case where both of the hands
6202 // are fully variable, we only require that the *difference* is constant.
6203 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6204 return std::nullopt;
6205
6206 const SCEV *X, *C;
6207 if (isa<SCEVConstant>(TrueExpr)) {
6208 CondExpr = SE->getNotSCEV(CondExpr);
6209 X = FalseExpr;
6210 C = TrueExpr;
6211 } else {
6212 X = TrueExpr;
6213 C = FalseExpr;
6214 }
6215 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6216 /*Sequential=*/true));
6217}
6218
6219static std::optional<const SCEV *>
6220createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6221 Value *FalseVal) {
6222 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6223 return std::nullopt;
6224
6225 const auto *SECond = SE->getSCEV(Cond);
6226 const auto *SETrue = SE->getSCEV(TrueVal);
6227 const auto *SEFalse = SE->getSCEV(FalseVal);
6228 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6229}
6230
6231const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6232 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6233 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6234 assert(TrueVal->getType() == FalseVal->getType() &&
6235 V->getType() == TrueVal->getType() &&
6236 "Types of select hands and of the result must match.");
6237
6238 // For now, only deal with i1-typed `select`s.
6239 if (!V->getType()->isIntegerTy(1))
6240 return getUnknown(V);
6241
6242 if (std::optional<const SCEV *> S =
6243 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6244 return *S;
6245
6246 return getUnknown(V);
6247}
6248
6249const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6250 Value *TrueVal,
6251 Value *FalseVal) {
6252 // Handle "constant" branch or select. This can occur for instance when a
6253 // loop pass transforms an inner loop and moves on to process the outer loop.
6254 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6255 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6256
6257 if (auto *I = dyn_cast<Instruction>(V)) {
6258 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6259 if (std::optional<const SCEV *> S =
6260 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6261 TrueVal, FalseVal))
6262 return *S;
6263 }
6264 }
6265
6266 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6267}
6268
6269/// Expand GEP instructions into add and multiply operations. This allows them
6270/// to be analyzed by regular SCEV code.
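/// For example (hypothetical IR, 64-bit pointers): "getelementptr i32, ptr %p,
/// i64 %i" is modeled as (%p + 4 * %i), since each index is scaled by the size
/// of the element type it steps over.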
6271const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6272 assert(GEP->getSourceElementType()->isSized() &&
6273 "GEP source element type must be sized");
6274
6275 SmallVector<const SCEV *, 4> IndexExprs;
6276 for (Value *Index : GEP->indices())
6277 IndexExprs.push_back(getSCEV(Index));
6278 return getGEPExpr(GEP, IndexExprs);
6279}
6280
6281APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
6282 uint32_t BitWidth = getTypeSizeInBits(S->getType());
6283 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6284 return TrailingZeros >= BitWidth
6285 ? APInt::getZero(BitWidth)
6286 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6287 };
6288 auto GetGCDMultiple = [this](const SCEVNAryExpr *N) {
6289 // The result is GCD of all operands results.
6290 APInt Res = getConstantMultiple(N->getOperand(0));
6291 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6292 Res = APIntOps::GreatestCommonDivisor(
6293 Res, getConstantMultiple(N->getOperand(I)));
6294 return Res;
6295 };
6296
6297 switch (S->getSCEVType()) {
6298 case scConstant:
6299 return cast<SCEVConstant>(S)->getAPInt();
6300 case scPtrToInt:
6301 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
6302 case scUDivExpr:
6303 case scVScale:
6304 return APInt(BitWidth, 1);
6305 case scTruncate: {
6306 // Only multiples that are a power of 2 will hold after truncation.
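    // For example (hypothetical values): a value that is a multiple of 12 in
    // i32 is only known to be a multiple of 4 after truncation to i8, since
    // only the power-of-2 factor is guaranteed to survive the wrap-around.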
6307 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6308 uint32_t TZ = getMinTrailingZeros(T->getOperand());
6309 return GetShiftedByZeros(TZ);
6310 }
6311 case scZeroExtend: {
6312 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6313 return getConstantMultiple(Z->getOperand()).zext(BitWidth);
6314 }
6315 case scSignExtend: {
6316 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6317 return getConstantMultiple(E->getOperand()).sext(BitWidth);
6318 }
6319 case scMulExpr: {
6320 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6321 if (M->hasNoUnsignedWrap()) {
6322 // The result is the product of all operand results.
6323 APInt Res = getConstantMultiple(M->getOperand(0));
6324 for (const SCEV *Operand : M->operands().drop_front())
6325 Res = Res * getConstantMultiple(Operand);
6326 return Res;
6327 }
6328
6329 // If there are no wrap guarantees, find the trailing zeros, which is the
6330 // sum of trailing zeros for all its operands.
6331 uint32_t TZ = 0;
6332 for (const SCEV *Operand : M->operands())
6333 TZ += getMinTrailingZeros(Operand);
6334 return GetShiftedByZeros(TZ);
6335 }
6336 case scAddExpr:
6337 case scAddRecExpr: {
6338 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6339 if (N->hasNoUnsignedWrap())
6340 return GetGCDMultiple(N);
6341 // Find the minimum number of trailing zeros, which is the minimum over all operands.
6342 uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
6343 for (const SCEV *Operand : N->operands().drop_front())
6344 TZ = std::min(TZ, getMinTrailingZeros(Operand));
6345 return GetShiftedByZeros(TZ);
6346 }
6347 case scUMaxExpr:
6348 case scSMaxExpr:
6349 case scUMinExpr:
6350 case scSMinExpr:
6351 case scSequentialUMinExpr:
6352 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6353 case scUnknown: {
6354 // ask ValueTracking for known bits
6355 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6356 unsigned Known =
6357 computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
6358 .countMinTrailingZeros();
6359 return GetShiftedByZeros(Known);
6360 }
6361 case scCouldNotCompute:
6362 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6363 }
6364 llvm_unreachable("Unknown SCEV kind!");
6365}
6366
6367APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
6368 auto I = ConstantMultipleCache.find(S);
6369 if (I != ConstantMultipleCache.end())
6370 return I->second;
6371
6372 APInt Result = getConstantMultipleImpl(S);
6373 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6374 assert(InsertPair.second && "Should insert a new key");
6375 return InsertPair.first->second;
6376}
6377
6378APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6379 APInt Multiple = getConstantMultiple(S);
6380 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6381}
6382
6383uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
6384 return std::min(getConstantMultiple(S).countTrailingZeros(),
6385 (unsigned)getTypeSizeInBits(S->getType()));
6386}
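// Illustration (hypothetical expression): for (8 + (16 * %n))<nuw> the
// constant multiple is GCD(8, multiple-of-16) = 8, so getMinTrailingZeros
// reports 3; without the no-wrap guarantee only the minimum trailing zeros
// of the operands would be used.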
6387
6388/// Helper method to assign a range to V from metadata present in the IR.
6389static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6390 if (Instruction *I = dyn_cast<Instruction>(V))
6391 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6392 return getConstantRangeFromMetadata(*MD);
6393
6394 return std::nullopt;
6395}
6396
6397void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6398 SCEV::NoWrapFlags Flags) {
6399 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6400 AddRec->setNoWrapFlags(Flags);
6401 UnsignedRanges.erase(AddRec);
6402 SignedRanges.erase(AddRec);
6403 ConstantMultipleCache.erase(AddRec);
6404 }
6405}
6406
6407ConstantRange ScalarEvolution::
6408getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6409 const DataLayout &DL = getDataLayout();
6410
6411 unsigned BitWidth = getTypeSizeInBits(U->getType());
6412 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6413
6414 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6415 // use information about the trip count to improve our available range. Note
6416 // that the trip count independent cases are already handled by known bits.
6417 // WARNING: The definition of recurrence used here is subtly different than
6418 // the one used by AddRec (and thus most of this file). Step is allowed to
6419 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6420 // and other addrecs in the same loop (for non-affine addrecs). The code
6421 // below intentionally handles the case where step is not loop invariant.
6422 auto *P = dyn_cast<PHINode>(U->getValue());
6423 if (!P)
6424 return FullSet;
6425
6426 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6427 // even the values that are not available in these blocks may come from them,
6428 // and this leads to a false-positive recurrence test.
6429 for (auto *Pred : predecessors(P->getParent()))
6430 if (!DT.isReachableFromEntry(Pred))
6431 return FullSet;
6432
6433 BinaryOperator *BO;
6434 Value *Start, *Step;
6435 if (!matchSimpleRecurrence(P, BO, Start, Step))
6436 return FullSet;
6437
6438 // If we found a recurrence in reachable code, we must be in a loop. Note
6439 // that BO might be in some subloop of L, and that's completely okay.
6440 auto *L = LI.getLoopFor(P->getParent());
6441 assert(L && L->getHeader() == P->getParent());
6442 if (!L->contains(BO->getParent()))
6443 // NOTE: This bailout should be an assert instead. However, asserting
6444 // the condition here exposes a case where LoopFusion is querying SCEV
6445 // with malformed loop information in the midst of the transform.
6446 // There doesn't appear to be an obvious fix, so for the moment bailout
6447 // until the caller issue can be fixed. PR49566 tracks the bug.
6448 return FullSet;
6449
6450 // TODO: Extend to other opcodes such as mul, and div
6451 switch (BO->getOpcode()) {
6452 default:
6453 return FullSet;
6454 case Instruction::AShr:
6455 case Instruction::LShr:
6456 case Instruction::Shl:
6457 break;
6458 };
6459
6460 if (BO->getOperand(0) != P)
6461 // TODO: Handle the power function forms some day.
6462 return FullSet;
6463
6464 unsigned TC = getSmallConstantMaxTripCount(L);
6465 if (!TC || TC >= BitWidth)
6466 return FullSet;
6467
6468 auto KnownStart = computeKnownBits(Start, DL, 0, &AC, nullptr, &DT);
6469 auto KnownStep = computeKnownBits(Step, DL, 0, &AC, nullptr, &DT);
6470 assert(KnownStart.getBitWidth() == BitWidth &&
6471 KnownStep.getBitWidth() == BitWidth);
6472
6473 // Compute total shift amount, being careful of overflow and bitwidths.
6474 auto MaxShiftAmt = KnownStep.getMaxValue();
6475 APInt TCAP(BitWidth, TC-1);
6476 bool Overflow = false;
6477 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6478 if (Overflow)
6479 return FullSet;
6480
6481 switch (BO->getOpcode()) {
6482 default:
6483 llvm_unreachable("filtered out above");
6484 case Instruction::AShr: {
6485 // For each ashr, three cases:
6486 // shift = 0 => unchanged value
6487 // saturation => 0 or -1
6488 // other => a value closer to zero (of the same sign)
6489 // Thus, the end value is closer to zero than the start.
6490 auto KnownEnd = KnownBits::ashr(KnownStart,
6491 KnownBits::makeConstant(TotalShift));
6492 if (KnownStart.isNonNegative())
6493 // Analogous to lshr (simply not yet canonicalized)
6494 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6495 KnownStart.getMaxValue() + 1);
6496 if (KnownStart.isNegative())
6497 // End >=u Start && End <=s Start
6498 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6499 KnownEnd.getMaxValue() + 1);
6500 break;
6501 }
6502 case Instruction::LShr: {
6503 // For each lshr, three cases:
6504 // shift = 0 => unchanged value
6505 // saturation => 0
6506 // other => a smaller positive number
6507 // Thus, the low end of the unsigned range is the last value produced.
6508 auto KnownEnd = KnownBits::lshr(KnownStart,
6509 KnownBits::makeConstant(TotalShift));
6510 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6511 KnownStart.getMaxValue() + 1);
6512 }
6513 case Instruction::Shl: {
6514 // Iff no bits are shifted out, value increases on every shift.
6515 auto KnownEnd = KnownBits::shl(KnownStart,
6516 KnownBits::makeConstant(TotalShift));
6517 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6518 return ConstantRange(KnownStart.getMinValue(),
6519 KnownEnd.getMaxValue() + 1);
6520 break;
6521 }
6522 };
6523 return FullSet;
6524}
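// Illustration (hypothetical IR): for a shift recurrence such as
//   %iv = phi i32 [ %start, %preheader ], [ %iv.next, %latch ]
//   %iv.next = lshr i32 %iv, 1
// with a known constant max trip count, each lshr only moves the value
// towards zero, so the returned range is bounded above by the maximum known
// value of %start instead of being the full i32 range.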
6525
6526const ConstantRange &
6527ScalarEvolution::getRangeRefIter(const SCEV *S,
6528 ScalarEvolution::RangeSignHint SignHint) {
6529 DenseMap<const SCEV *, ConstantRange> &Cache =
6530 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6531 : SignedRanges;
6532 SmallVector<const SCEV *> WorkList;
6533 SmallPtrSet<const SCEV *, 8> Seen;
6534
6535 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6536 // SCEVUnknown PHI node.
6537 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6538 if (!Seen.insert(Expr).second)
6539 return;
6540 if (Cache.contains(Expr))
6541 return;
6542 switch (Expr->getSCEVType()) {
6543 case scUnknown:
6544 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6545 break;
6546 [[fallthrough]];
6547 case scConstant:
6548 case scVScale:
6549 case scTruncate:
6550 case scZeroExtend:
6551 case scSignExtend:
6552 case scPtrToInt:
6553 case scAddExpr:
6554 case scMulExpr:
6555 case scUDivExpr:
6556 case scAddRecExpr:
6557 case scUMaxExpr:
6558 case scSMaxExpr:
6559 case scUMinExpr:
6560 case scSMinExpr:
6561 case scSequentialUMinExpr:
6562 WorkList.push_back(Expr);
6563 break;
6564 case scCouldNotCompute:
6565 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6566 }
6567 };
6568 AddToWorklist(S);
6569
6570 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6571 for (unsigned I = 0; I != WorkList.size(); ++I) {
6572 const SCEV *P = WorkList[I];
6573 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6574 // If it is not a `SCEVUnknown`, just recurse into operands.
6575 if (!UnknownS) {
6576 for (const SCEV *Op : P->operands())
6577 AddToWorklist(Op);
6578 continue;
6579 }
6580 // `SCEVUnknown`s require special treatment.
6581 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6582 if (!PendingPhiRangesIter.insert(P).second)
6583 continue;
6584 for (auto &Op : reverse(P->operands()))
6585 AddToWorklist(getSCEV(Op));
6586 }
6587 }
6588
6589 if (!WorkList.empty()) {
6590 // Use getRangeRef to compute ranges for items in the worklist in reverse
6591 // order. This will force ranges for earlier operands to be computed before
6592 // their users in most cases.
6593 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6594 getRangeRef(P, SignHint);
6595
6596 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6597 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6598 PendingPhiRangesIter.erase(P);
6599 }
6600 }
6601
6602 return getRangeRef(S, SignHint, 0);
6603}
6604
6605/// Determine the range for a particular SCEV. If SignHint is
6606/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6607/// with a "cleaner" unsigned (resp. signed) representation.
6608const ConstantRange &ScalarEvolution::getRangeRef(
6609 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6610 DenseMap<const SCEV *, ConstantRange> &Cache =
6611 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6612 : SignedRanges;
6613 ConstantRange::PreferredRangeType RangeType =
6614 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6615 : ConstantRange::Signed;
6616
6617 // See if we've computed this range already.
6618 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6619 if (I != Cache.end())
6620 return I->second;
6621
6622 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6623 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6624
6625 // Switch to iteratively computing the range for S, if it is part of a deeply
6626 // nested expression.
6627 if (Depth > RangeIterThreshold)
6628 return getRangeRefIter(S, SignHint);
6629
6630 unsigned BitWidth = getTypeSizeInBits(S->getType());
6631 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6632 using OBO = OverflowingBinaryOperator;
6633
6634 // If the value has known zeros, the maximum value will have those known zeros
6635 // as well.
6636 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6637 APInt Multiple = getNonZeroConstantMultiple(S);
6638 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6639 if (!Remainder.isZero())
6640 ConservativeResult =
6641 ConstantRange(APInt::getZero(BitWidth),
6642 APInt::getMaxValue(BitWidth) - Remainder + 1);
6643 }
6644 else {
6645 uint32_t TZ = getMinTrailingZeros(S);
6646 if (TZ != 0) {
6647 ConservativeResult = ConstantRange(
6648 APInt::getSignedMinValue(BitWidth),
6649 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6650 }
6651 }
6652
6653 switch (S->getSCEVType()) {
6654 case scConstant:
6655 llvm_unreachable("Already handled above.");
6656 case scVScale:
6657 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6658 case scTruncate: {
6659 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6660 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6661 return setRange(
6662 Trunc, SignHint,
6663 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6664 }
6665 case scZeroExtend: {
6666 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6667 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6668 return setRange(
6669 ZExt, SignHint,
6670 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6671 }
6672 case scSignExtend: {
6673 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6674 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6675 return setRange(
6676 SExt, SignHint,
6677 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6678 }
6679 case scPtrToInt: {
6680 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6681 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6682 return setRange(PtrToInt, SignHint, X);
6683 }
6684 case scAddExpr: {
6685 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6686 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6687 unsigned WrapType = OBO::AnyWrap;
6688 if (Add->hasNoSignedWrap())
6689 WrapType |= OBO::NoSignedWrap;
6690 if (Add->hasNoUnsignedWrap())
6691 WrapType |= OBO::NoUnsignedWrap;
6692 for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
6693 X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1),
6694 WrapType, RangeType);
6695 return setRange(Add, SignHint,
6696 ConservativeResult.intersectWith(X, RangeType));
6697 }
6698 case scMulExpr: {
6699 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6700 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6701 for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
6702 X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1));
6703 return setRange(Mul, SignHint,
6704 ConservativeResult.intersectWith(X, RangeType));
6705 }
6706 case scUDivExpr: {
6707 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6708 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6709 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6710 return setRange(UDiv, SignHint,
6711 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6712 }
6713 case scAddRecExpr: {
6714 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6715 // If there's no unsigned wrap, the value will never be less than its
6716 // initial value.
6717 if (AddRec->hasNoUnsignedWrap()) {
6718 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6719 if (!UnsignedMinValue.isZero())
6720 ConservativeResult = ConservativeResult.intersectWith(
6721 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6722 }
6723
6724 // If there's no signed wrap, and all the operands other than the initial
6725 // value have the same sign or are zero, the value won't ever be:
6726 // 1: smaller than the initial value if the operands are non-negative,
6727 // 2: bigger than the initial value if the operands are non-positive.
6728 // In both cases, the value cannot cross the signed min/max boundary.
6729 if (AddRec->hasNoSignedWrap()) {
6730 bool AllNonNeg = true;
6731 bool AllNonPos = true;
6732 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6733 if (!isKnownNonNegative(AddRec->getOperand(i)))
6734 AllNonNeg = false;
6735 if (!isKnownNonPositive(AddRec->getOperand(i)))
6736 AllNonPos = false;
6737 }
6738 if (AllNonNeg)
6739 ConservativeResult = ConservativeResult.intersectWith(
6740 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6741 APInt::getSignedMinValue(BitWidth)),
6742 RangeType);
6743 else if (AllNonPos)
6744 ConservativeResult = ConservativeResult.intersectWith(
6745 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6746 getSignedRangeMax(AddRec->getStart()) +
6747 1),
6748 RangeType);
6749 }
6750
6751 // TODO: non-affine addrec
6752 if (AddRec->isAffine()) {
6753 const SCEV *MaxBEScev =
6754 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6755 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6756 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6757
6758 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6759 // MaxBECount's active bits are all <= AddRec's bit width.
6760 if (MaxBECount.getBitWidth() > BitWidth &&
6761 MaxBECount.getActiveBits() <= BitWidth)
6762 MaxBECount = MaxBECount.trunc(BitWidth);
6763 else if (MaxBECount.getBitWidth() < BitWidth)
6764 MaxBECount = MaxBECount.zext(BitWidth);
6765
6766 if (MaxBECount.getBitWidth() == BitWidth) {
6767 auto RangeFromAffine = getRangeForAffineAR(
6768 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6769 ConservativeResult =
6770 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6771
6772 auto RangeFromFactoring = getRangeViaFactoring(
6773 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6774 ConservativeResult =
6775 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6776 }
6777 }
6778
6779 // Now try symbolic BE count and more powerful methods.
6780 if (UseExpensiveRangeSharpening) {
6781 const SCEV *SymbolicMaxBECount =
6782 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6783 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6784 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6785 AddRec->hasNoSelfWrap()) {
6786 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6787 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6788 ConservativeResult =
6789 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6790 }
6791 }
6792 }
6793
6794 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6795 }
6796 case scUMaxExpr:
6797 case scSMaxExpr:
6798 case scUMinExpr:
6799 case scSMinExpr:
6800 case scSequentialUMinExpr: {
6801 Intrinsic::ID ID;
6802 switch (S->getSCEVType()) {
6803 case scUMaxExpr:
6804 ID = Intrinsic::umax;
6805 break;
6806 case scSMaxExpr:
6807 ID = Intrinsic::smax;
6808 break;
6809 case scUMinExpr:
6810 case scSequentialUMinExpr:
6811 ID = Intrinsic::umin;
6812 break;
6813 case scSMinExpr:
6814 ID = Intrinsic::smin;
6815 break;
6816 default:
6817 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6818 }
6819
6820 const auto *NAry = cast<SCEVNAryExpr>(S);
6821 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6822 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6823 X = X.intrinsic(
6824 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6825 return setRange(S, SignHint,
6826 ConservativeResult.intersectWith(X, RangeType));
6827 }
6828 case scUnknown: {
6829 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6830 Value *V = U->getValue();
6831
6832 // Check if the IR explicitly contains !range metadata.
6833 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6834 if (MDRange)
6835 ConservativeResult =
6836 ConservativeResult.intersectWith(*MDRange, RangeType);
6837
6838 // Use facts about recurrences in the underlying IR. Note that add
6839 // recurrences are AddRecExprs and thus don't hit this path. This
6840 // primarily handles shift recurrences.
6841 auto CR = getRangeForUnknownRecurrence(U);
6842 ConservativeResult = ConservativeResult.intersectWith(CR);
6843
6844 // See if ValueTracking can give us a useful range.
6845 const DataLayout &DL = getDataLayout();
6846 KnownBits Known = computeKnownBits(V, DL, 0, &AC, nullptr, &DT);
6847 if (Known.getBitWidth() != BitWidth)
6848 Known = Known.zextOrTrunc(BitWidth);
6849
6850 // ValueTracking may be able to compute a tighter result for the number of
6851 // sign bits than for the value of those sign bits.
6852 unsigned NS = ComputeNumSignBits(V, DL, 0, &AC, nullptr, &DT);
6853 if (U->getType()->isPointerTy()) {
6854 // If the pointer size is larger than the index size type, this can cause
6855 // NS to be larger than BitWidth. So compensate for this.
6856 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6857 int ptrIdxDiff = ptrSize - BitWidth;
6858 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6859 NS -= ptrIdxDiff;
6860 }
6861
6862 if (NS > 1) {
6863 // If we know any of the sign bits, we know all of the sign bits.
6864 if (!Known.Zero.getHiBits(NS).isZero())
6865 Known.Zero.setHighBits(NS);
6866 if (!Known.One.getHiBits(NS).isZero())
6867 Known.One.setHighBits(NS);
6868 }
6869
6870 if (Known.getMinValue() != Known.getMaxValue() + 1)
6871 ConservativeResult = ConservativeResult.intersectWith(
6872 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6873 RangeType);
6874 if (NS > 1)
6875 ConservativeResult = ConservativeResult.intersectWith(
6876 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6877 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6878 RangeType);
6879
6880 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6881 // Strengthen the range if the underlying IR value is a
6882 // global/alloca/heap allocation using the size of the object.
6883 ObjectSizeOpts Opts;
6884 Opts.RoundToAlign = false;
6885 Opts.NullIsUnknownSize = true;
6886 uint64_t ObjSize;
6887 if ((isa<GlobalVariable>(V) || isa<AllocaInst>(V) ||
6888 isAllocationFn(V, &TLI)) &&
6889 getObjectSize(V, ObjSize, DL, &TLI, Opts) && ObjSize > 1) {
6890 // The highest address the object can start is ObjSize bytes before the
6891 // end (unsigned max value). If this value is not a multiple of the
6892 // alignment, the last possible start value is the next lowest multiple
6893 // of the alignment. Note: The computations below cannot overflow,
6894 // because if they would there's no possible start address for the
6895 // object.
6896 APInt MaxVal = APInt::getMaxValue(BitWidth) - APInt(BitWidth, ObjSize);
6897 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6898 uint64_t Rem = MaxVal.urem(Align);
6899 MaxVal -= APInt(BitWidth, Rem);
6900 APInt MinVal = APInt::getZero(BitWidth);
6901 if (llvm::isKnownNonZero(V, DL))
6902 MinVal = Align;
6903 ConservativeResult = ConservativeResult.intersectWith(
6904 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
6905 }
6906 }
6907
6908 // The range of a Phi is a subset of the union of the ranges of its inputs.
6909 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
6910 // Make sure that we do not run over cycled Phis.
6911 if (PendingPhiRanges.insert(Phi).second) {
6912 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
6913
6914 for (const auto &Op : Phi->operands()) {
6915 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
6916 RangeFromOps = RangeFromOps.unionWith(OpRange);
6917 // No point to continue if we already have a full set.
6918 if (RangeFromOps.isFullSet())
6919 break;
6920 }
6921 ConservativeResult =
6922 ConservativeResult.intersectWith(RangeFromOps, RangeType);
6923 bool Erased = PendingPhiRanges.erase(Phi);
6924 assert(Erased && "Failed to erase Phi properly?");
6925 (void)Erased;
6926 }
6927 }
6928
6929 // vscale can't be equal to zero
6930 if (const auto *II = dyn_cast<IntrinsicInst>(V))
6931 if (II->getIntrinsicID() == Intrinsic::vscale) {
6932 ConstantRange Disallowed = APInt::getZero(BitWidth);
6933 ConservativeResult = ConservativeResult.difference(Disallowed);
6934 }
6935
6936 return setRange(U, SignHint, std::move(ConservativeResult));
6937 }
6938 case scCouldNotCompute:
6939 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6940 }
6941
6942 return setRange(S, SignHint, std::move(ConservativeResult));
6943}
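// Example of the computation above (hypothetical expression): for
// (5 + (zext i8 %x to i32)) the zext operand contributes [0, 256), and
// adding the constant 5 yields the unsigned range [5, 261).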
6944
6945// Given a StartRange, Step and MaxBECount for an expression compute a range of
6946// values that the expression can take. Initially, the expression has a value
6947// from StartRange and then is changed by Step up to MaxBECount times. The
6948// Signed argument defines whether we treat Step as signed or unsigned.
6949static ConstantRange getRangeForAffineARHelper(APInt Step,
6950 const ConstantRange &StartRange,
6951 const APInt &MaxBECount,
6952 bool Signed) {
6953 unsigned BitWidth = Step.getBitWidth();
6954 assert(BitWidth == StartRange.getBitWidth() &&
6955 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
6956 // If either Step or MaxBECount is 0, then the expression won't change, and we
6957 // just need to return the initial range.
6958 if (Step == 0 || MaxBECount == 0)
6959 return StartRange;
6960
6961 // If we don't know anything about the initial value (i.e. StartRange is
6962 // FullRange), then we don't know anything about the final range either.
6963 // Return FullRange.
6964 if (StartRange.isFullSet())
6965 return ConstantRange::getFull(BitWidth);
6966
6967 // If Step is signed and negative, then we use its absolute value, but we also
6968 // note that we're moving in the opposite direction.
6969 bool Descending = Signed && Step.isNegative();
6970
6971 if (Signed)
6972 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
6973 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
6974 // These equations hold true due to the well-defined wrap-around behavior of
6975 // APInt.
6976 Step = Step.abs();
6977
6978 // Check if Offset is more than full span of BitWidth. If it is, the
6979 // expression is guaranteed to overflow.
6980 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
6981 return ConstantRange::getFull(BitWidth);
6982
6983 // Offset is by how much the expression can change. Checks above guarantee no
6984 // overflow here.
6985 APInt Offset = Step * MaxBECount;
6986
6987 // Minimum value of the final range will match the minimal value of StartRange
6988 // if the expression is increasing and will be decreased by Offset otherwise.
6989 // Maximum value of the final range will match the maximal value of StartRange
6990 // if the expression is decreasing and will be increased by Offset otherwise.
6991 APInt StartLower = StartRange.getLower();
6992 APInt StartUpper = StartRange.getUpper() - 1;
6993 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
6994 : (StartUpper + std::move(Offset));
6995
6996 // It's possible that the new minimum/maximum value will fall into the initial
6997 // range (due to wrap around). This means that the expression can take any
6998 // value in this bitwidth, and we have to return full range.
6999 if (StartRange.contains(MovedBoundary))
7000 return ConstantRange::getFull(BitWidth);
7001
7002 APInt NewLower =
7003 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7004 APInt NewUpper =
7005 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7006 NewUpper += 1;
7007
7008 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7009 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7010}
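// Worked example (hypothetical numbers): Step = 2, StartRange = [0, 4),
// MaxBECount = 10, Signed = false. Offset = 2 * 10 = 20, the expression is
// ascending, and 20 does not wrap back into [0, 4), so the result is
// [0, 3 + 20 + 1) = [0, 24).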
7011
7012ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7013 const SCEV *Step,
7014 const APInt &MaxBECount) {
7015 assert(getTypeSizeInBits(Start->getType()) ==
7016 getTypeSizeInBits(Step->getType()) &&
7017 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7018 "mismatched bit widths");
7019
7020 // First, consider step signed.
7021 ConstantRange StartSRange = getSignedRange(Start);
7022 ConstantRange StepSRange = getSignedRange(Step);
7023
7024 // If Step can be both positive and negative, we need to find ranges for the
7025 // maximum absolute step values in both directions and union them.
7026 ConstantRange SR = getRangeForAffineARHelper(
7027 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7028 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7029 StartSRange, MaxBECount,
7030 /* Signed = */ true));
7031
7032 // Next, consider step unsigned.
7033 ConstantRange UR = getRangeForAffineARHelper(
7034 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7035 /* Signed = */ false);
7036
7037 // Finally, intersect signed and unsigned ranges.
7038 return SR.intersectWith(UR, ConstantRange::Smallest);
7039}
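// For example (hypothetical AddRec): {0,+,2} with MaxBECount = 10 gives the
// helper range [0, 21) for both the signed and unsigned interpretations of
// the step, so the intersected result is [0, 21).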
7040
7041ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7042 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7043 ScalarEvolution::RangeSignHint SignHint) {
7044 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7045 assert(AddRec->hasNoSelfWrap() &&
7046 "This only works for non-self-wrapping AddRecs!");
7047 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7048 const SCEV *Step = AddRec->getStepRecurrence(*this);
7049 // Only deal with constant step to save compile time.
7050 if (!isa<SCEVConstant>(Step))
7051 return ConstantRange::getFull(BitWidth);
7052 // Let's make sure that we can prove that we do not self-wrap during
7053 // MaxBECount iterations. We need this because MaxBECount is a maximum
7054 // iteration count estimate, and we might have inferred nw from some exit for
7055 // which we do not know the max exit count (or from some other side reasoning).
7056 // TODO: Turn into assert at some point.
7057 if (getTypeSizeInBits(MaxBECount->getType()) >
7058 getTypeSizeInBits(AddRec->getType()))
7059 return ConstantRange::getFull(BitWidth);
7060 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7061 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7062 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7063 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7064 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7065 MaxItersWithoutWrap))
7066 return ConstantRange::getFull(BitWidth);
7067
7068 ICmpInst::Predicate LEPred =
7069 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7070 ICmpInst::Predicate GEPred =
7071 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7072 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7073
7074 // We know that there is no self-wrap. Let's take Start and End values and
7075 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7076 // the iteration. They either lie inside the range [Min(Start, End),
7077 // Max(Start, End)] or outside it:
7078 //
7079 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7080 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7081 //
7082 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7083 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7084 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7085 // Start <= End and step is positive, or Start >= End and step is negative.
7086 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7087 ConstantRange StartRange = getRangeRef(Start, SignHint);
7088 ConstantRange EndRange = getRangeRef(End, SignHint);
7089 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7090 // If they already cover full iteration space, we will know nothing useful
7091 // even if we prove what we want to prove.
7092 if (RangeBetween.isFullSet())
7093 return RangeBetween;
7094 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7095 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7096 : RangeBetween.isWrappedSet();
7097 if (IsWrappedSet)
7098 return ConstantRange::getFull(BitWidth);
7099
7100 if (isKnownPositive(Step) &&
7101 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7102 return RangeBetween;
7103 if (isKnownNegative(Step) &&
7104 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7105 return RangeBetween;
7106 return ConstantRange::getFull(BitWidth);
7107}
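// Sketch of the idea (hypothetical AddRec): for {%n,+,1}<nsw><nw> with a
// symbolic max backedge-taken count %m, End = %n + %m; if Start <= End can
// be proven from constant ranges, every intermediate value stays inside
// union(range(%n), range(%n + %m)), which is what this returns.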
7108
7109ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7110 const SCEV *Step,
7111 const APInt &MaxBECount) {
7112 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7113 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7114
7115 unsigned BitWidth = MaxBECount.getBitWidth();
7116 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7117 getTypeSizeInBits(Step->getType()) == BitWidth &&
7118 "mismatched bit widths");
7119
7120 struct SelectPattern {
7121 Value *Condition = nullptr;
7122 APInt TrueValue;
7123 APInt FalseValue;
7124
7125 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7126 const SCEV *S) {
7127 std::optional<unsigned> CastOp;
7128 APInt Offset(BitWidth, 0);
7129
7131 "Should be!");
7132
7133 // Peel off a constant offset:
7134 if (auto *SA = dyn_cast<SCEVAddExpr>(S)) {
7135 // In the future we could consider being smarter here and handle
7136 // {Start+Step,+,Step} too.
7137 if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0)))
7138 return;
7139
7140 Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt();
7141 S = SA->getOperand(1);
7142 }
7143
7144 // Peel off a cast operation
7145 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7146 CastOp = SCast->getSCEVType();
7147 S = SCast->getOperand();
7148 }
7149
7150 using namespace llvm::PatternMatch;
7151
7152 auto *SU = dyn_cast<SCEVUnknown>(S);
7153 const APInt *TrueVal, *FalseVal;
7154 if (!SU ||
7155 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7156 m_APInt(FalseVal)))) {
7157 Condition = nullptr;
7158 return;
7159 }
7160
7161 TrueValue = *TrueVal;
7162 FalseValue = *FalseVal;
7163
7164 // Re-apply the cast we peeled off earlier
7165 if (CastOp)
7166 switch (*CastOp) {
7167 default:
7168 llvm_unreachable("Unknown SCEV cast type!");
7169
7170 case scTruncate:
7171 TrueValue = TrueValue.trunc(BitWidth);
7172 FalseValue = FalseValue.trunc(BitWidth);
7173 break;
7174 case scZeroExtend:
7175 TrueValue = TrueValue.zext(BitWidth);
7176 FalseValue = FalseValue.zext(BitWidth);
7177 break;
7178 case scSignExtend:
7179 TrueValue = TrueValue.sext(BitWidth);
7180 FalseValue = FalseValue.sext(BitWidth);
7181 break;
7182 }
7183
7184 // Re-apply the constant offset we peeled off earlier
7185 TrueValue += Offset;
7186 FalseValue += Offset;
7187 }
7188
7189 bool isRecognized() { return Condition != nullptr; }
7190 };
7191
7192 SelectPattern StartPattern(*this, BitWidth, Start);
7193 if (!StartPattern.isRecognized())
7194 return ConstantRange::getFull(BitWidth);
7195
7196 SelectPattern StepPattern(*this, BitWidth, Step);
7197 if (!StepPattern.isRecognized())
7198 return ConstantRange::getFull(BitWidth);
7199
7200 if (StartPattern.Condition != StepPattern.Condition) {
7201 // We don't handle this case today; but we could, by considering four
7202 // possibilities below instead of two. I'm not sure if there are cases where
7203 // that will help over what getRange already does, though.
7204 return ConstantRange::getFull(BitWidth);
7205 }
7206
7207 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7208 // construct arbitrary general SCEV expressions here. This function is called
7209 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7210 // say) can end up caching a suboptimal value.
7211
7212 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7213 // C2352 and C2512 (otherwise it isn't needed).
7214
7215 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7216 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7217 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7218 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7219
7220 ConstantRange TrueRange =
7221 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7222 ConstantRange FalseRange =
7223 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7224
7225 return TrueRange.unionWith(FalseRange);
7226}
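// For example (hypothetical values): Start = (%c ? 0 : 10) and
// Step = (%c ? 1 : 2) with MaxBECount = 5 factor into the two affine
// recurrences {0,+,1} and {10,+,2}, whose ranges [0, 6) and [10, 21)
// union to [0, 21).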
7227
7228SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7229 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7230 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7231
7232 // Return early if there are no flags to propagate to the SCEV.
7233 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7234 if (BinOp->hasNoUnsignedWrap())
7235 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7236 if (BinOp->hasNoSignedWrap())
7237 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7238 if (Flags == SCEV::FlagAnyWrap)
7239 return SCEV::FlagAnyWrap;
7240
7241 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7242}
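// For instance (hypothetical IR): for "%a = add nsw i32 %x, %y" this returns
// FlagNSW only when isSCEVExprNeverPoison proves that %a executes whenever
// its defining scope is entered; otherwise the flag is dropped so that other
// instructions mapping to the same SCEV are not over-constrained.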
7243
7244const Instruction *
7245ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7246 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7247 return &*AddRec->getLoop()->getHeader()->begin();
7248 if (auto *U = dyn_cast<SCEVUnknown>(S))
7249 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7250 return I;
7251 return nullptr;
7252}
7253
7254const Instruction *
7255ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7256 bool &Precise) {
7257 Precise = true;
7258 // Do a bounded search of the def relation of the requested SCEVs.
7259 SmallSet<const SCEV *, 16> Visited;
7260 SmallVector<const SCEV *> Worklist;
7261 auto pushOp = [&](const SCEV *S) {
7262 if (!Visited.insert(S).second)
7263 return;
7264 // Threshold of 30 here is arbitrary.
7265 if (Visited.size() > 30) {
7266 Precise = false;
7267 return;
7268 }
7269 Worklist.push_back(S);
7270 };
7271
7272 for (const auto *S : Ops)
7273 pushOp(S);
7274
7275 const Instruction *Bound = nullptr;
7276 while (!Worklist.empty()) {
7277 auto *S = Worklist.pop_back_val();
7278 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7279 if (!Bound || DT.dominates(Bound, DefI))
7280 Bound = DefI;
7281 } else {
7282 for (const auto *Op : S->operands())
7283 pushOp(Op);
7284 }
7285 }
7286 return Bound ? Bound : &*F.getEntryBlock().begin();
7287}
7288
7289const Instruction *
7290ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7291 bool Discard;
7292 return getDefiningScopeBound(Ops, Discard);
7293}
7294
7295bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7296 const Instruction *B) {
7297 if (A->getParent() == B->getParent() &&
7298 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7299 B->getIterator()))
7300 return true;
7301
7302 auto *BLoop = LI.getLoopFor(B->getParent());
7303 if (BLoop && BLoop->getHeader() == B->getParent() &&
7304 BLoop->getLoopPreheader() == A->getParent() &&
7305 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7306 A->getParent()->end()) &&
7307 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7308 B->getIterator()))
7309 return true;
7310 return false;
7311}
7312
7313
7314bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7315 // Only proceed if we can prove that I does not yield poison.
7316 if (!programUndefinedIfPoison(I))
7317 return false;
7318
7319 // At this point we know that if I is executed, then it does not wrap
7320 // according to at least one of NSW or NUW. If I is not executed, then we do
7321 // not know if the calculation that I represents would wrap. Multiple
7322 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7323 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7324 // derived from other instructions that map to the same SCEV. We cannot make
7325 // that guarantee for cases where I is not executed. So we need to find an
7326 // upper bound on the defining scope for the SCEV, and prove that I is
7327 // executed every time we enter that scope. When the bounding scope is a
7328 // loop (the common case), this is equivalent to proving I executes on every
7329 // iteration of that loop.
7330 SmallVector<const SCEV *> SCEVOps;
7331 for (const Use &Op : I->operands()) {
7332 // I could be an extractvalue from a call to an overflow intrinsic.
7333 // TODO: We can do better here in some cases.
7334 if (isSCEVable(Op->getType()))
7335 SCEVOps.push_back(getSCEV(Op));
7336 }
7337 auto *DefI = getDefiningScopeBound(SCEVOps);
7338 return isGuaranteedToTransferExecutionTo(DefI, I);
7339}
7340
7341bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7342 // If we know that \c I can never be poison period, then that's enough.
7343 if (isSCEVExprNeverPoison(I))
7344 return true;
7345
7346 // If the loop only has one exit, then we know that, if the loop is entered,
7347 // any instruction dominating that exit will be executed. If any such
7348 // instruction would result in UB, the addrec cannot be poison.
7349 //
7350 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7351 // also handles uses outside the loop header (they just need to dominate the
7352 // single exit).
7353
7354 auto *ExitingBB = L->getExitingBlock();
7355 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7356 return false;
7357
7360
7361 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7362 // things that are known to be poison under that assumption go on the
7363 // Worklist.
7364 KnownPoison.insert(I);
7365 Worklist.push_back(I);
7366
7367 while (!Worklist.empty()) {
7368 const Instruction *Poison = Worklist.pop_back_val();
7369
7370 for (const Use &U : Poison->uses()) {
7371 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7372 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7373 DT.dominates(PoisonUser->getParent(), ExitingBB))
7374 return true;
7375
7376 if (propagatesPoison(U) && L->contains(PoisonUser))
7377 if (KnownPoison.insert(PoisonUser).second)
7378 Worklist.push_back(PoisonUser);
7379 }
7380 }
7381
7382 return false;
7383}
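// Example of the reasoning (hypothetical IR): if the post-increment value
// feeds "getelementptr inbounds" followed by a load that dominates the sole
// exit, poison in the add recurrence would trigger UB on every iteration, so
// the no-wrap flags derived from it are safe to keep.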
7384
7385ScalarEvolution::LoopProperties
7386ScalarEvolution::getLoopProperties(const Loop *L) {
7387 using LoopProperties = ScalarEvolution::LoopProperties;
7388
7389 auto Itr = LoopPropertiesCache.find(L);
7390 if (Itr == LoopPropertiesCache.end()) {
7391 auto HasSideEffects = [](Instruction *I) {
7392 if (auto *SI = dyn_cast<StoreInst>(I))
7393 return !SI->isSimple();
7394
7395 return I->mayThrow() || I->mayWriteToMemory();
7396 };
7397
7398 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7399 /*HasNoSideEffects*/ true};
7400
7401 for (auto *BB : L->getBlocks())
7402 for (auto &I : *BB) {
7403 if (I.mayThrow() || !I.willReturn())
7404 LP.HasNoAbnormalExits = false;
7405 if (HasSideEffects(&I))
7406 LP.HasNoSideEffects = false;
7407 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7408 break; // We're already as pessimistic as we can get.
7409 }
7410
7411 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7412 assert(InsertPair.second && "We just checked!");
7413 Itr = InsertPair.first;
7414 }
7415
7416 return Itr->second;
7417}
7418
7419bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7420 // A mustprogress loop without side effects must be finite.
7421 // TODO: The check used here is very conservative. It's only *specific*
7422 // side effects which are well defined in infinite loops.
7423 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7424}
7425
7426const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7427 // Worklist item with a Value and a bool indicating whether all operands have
7428 // been visited already.
7429 using PointerTy = PointerIntPair<Value *, 1, bool>;
7430 SmallVector<PointerTy> Stack;
7431
7432 Stack.emplace_back(V, true);
7433 Stack.emplace_back(V, false);
7434 while (!Stack.empty()) {
7435 auto E = Stack.pop_back_val();
7436 Value *CurV = E.getPointer();
7437
7438 if (getExistingSCEV(CurV))
7439 continue;
7440
7441 SmallVector<Value *> Ops;
7442 const SCEV *CreatedSCEV = nullptr;
7443 // If all operands have been visited already, create the SCEV.
7444 if (E.getInt()) {
7445 CreatedSCEV = createSCEV(CurV);
7446 } else {
7447 // Otherwise get the operands we need to create SCEVs for before creating
7448 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7449 // just use it.
7450 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7451 }
7452
7453 if (CreatedSCEV) {
7454 insertValueToMap(CurV, CreatedSCEV);
7455 } else {
7456 // Queue CurV for SCEV creation, followed by its operands, which need to
7457 // be constructed first.
7458 Stack.emplace_back(CurV, true);
7459 for (Value *Op : Ops)
7460 Stack.emplace_back(Op, false);
7461 }
7462 }
7463
7464 return getExistingSCEV(V);
7465}
7466
7467const SCEV *
7468ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7469 if (!isSCEVable(V->getType()))
7470 return getUnknown(V);
7471
7472 if (Instruction *I = dyn_cast<Instruction>(V)) {
7473 // Don't attempt to analyze instructions in blocks that aren't
7474 // reachable. Such instructions don't matter, and they aren't required
7475 // to obey basic rules for definitions dominating uses which this
7476 // analysis depends on.
7477 if (!DT.isReachableFromEntry(I->getParent()))
7478 return getUnknown(PoisonValue::get(V->getType()));
7479 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7480 return getConstant(CI);
7481 else if (isa<GlobalAlias>(V))
7482 return getUnknown(V);
7483 else if (!isa<ConstantExpr>(V))
7484 return getUnknown(V);
7485
7486 Operator *U = cast<Operator>(V);
7487 if (auto BO =
7488 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7489 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7490 switch (BO->Opcode) {
7491 case Instruction::Add:
7492 case Instruction::Mul: {
7493 // For additions and multiplications, traverse add/mul chains for which we
7494 // can potentially create a single SCEV, to reduce the number of
7495 // get{Add,Mul}Expr calls.
7496 do {
7497 if (BO->Op) {
7498 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7499 Ops.push_back(BO->Op);
7500 break;
7501 }
7502 }
7503 Ops.push_back(BO->RHS);
7504 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7505 dyn_cast<Instruction>(V));
7506 if (!NewBO ||
7507 (BO->Opcode == Instruction::Add &&
7508 (NewBO->Opcode != Instruction::Add &&
7509 NewBO->Opcode != Instruction::Sub)) ||
7510 (BO->Opcode == Instruction::Mul &&
7511 NewBO->Opcode != Instruction::Mul)) {
7512 Ops.push_back(BO->LHS);
7513 break;
7514 }
7515 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7516 // requires a SCEV for the LHS.
7517 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7518 auto *I = dyn_cast<Instruction>(BO->Op);
7519 if (I && programUndefinedIfPoison(I)) {
7520 Ops.push_back(BO->LHS);
7521 break;
7522 }
7523 }
7524 BO = NewBO;
7525 } while (true);
7526 return nullptr;
7527 }
7528 case Instruction::Sub:
7529 case Instruction::UDiv:
7530 case Instruction::URem:
7531 break;
7532 case Instruction::AShr:
7533 case Instruction::Shl:
7534 case Instruction::Xor:
7535 if (!IsConstArg)
7536 return nullptr;
7537 break;
7538 case Instruction::And:
7539 case Instruction::Or:
7540 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7541 return nullptr;
7542 break;
7543 case Instruction::LShr:
7544 return getUnknown(V);
7545 default:
7546 llvm_unreachable("Unhandled binop");
7547 break;
7548 }
7549
7550 Ops.push_back(BO->LHS);
7551 Ops.push_back(BO->RHS);
7552 return nullptr;
7553 }
7554
7555 switch (U->getOpcode()) {
7556 case Instruction::Trunc:
7557 case Instruction::ZExt:
7558 case Instruction::SExt:
7559 case Instruction::PtrToInt:
7560 Ops.push_back(U->getOperand(0));
7561 return nullptr;
7562
7563 case Instruction::BitCast:
7564 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7565 Ops.push_back(U->getOperand(0));
7566 return nullptr;
7567 }
7568 return getUnknown(V);
7569
7570 case Instruction::SDiv:
7571 case Instruction::SRem:
7572 Ops.push_back(U->getOperand(0));
7573 Ops.push_back(U->getOperand(1));
7574 return nullptr;
7575
7576 case Instruction::GetElementPtr:
7577 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7578 "GEP source element type must be sized");
7579 for (Value *Index : U->operands())
7580 Ops.push_back(Index);
7581 return nullptr;
7582
7583 case Instruction::IntToPtr:
7584 return getUnknown(V);
7585
7586 case Instruction::PHI:
7587 // Keep constructing SCEVs for phis recursively for now.
7588 return nullptr;
7589
7590 case Instruction::Select: {
7591 // Check if U is a select that can be simplified to a SCEVUnknown.
7592 auto CanSimplifyToUnknown = [this, U]() {
7593 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7594 return false;
7595
7596 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7597 if (!ICI)
7598 return false;
7599 Value *LHS = ICI->getOperand(0);
7600 Value *RHS = ICI->getOperand(1);
7601 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7602 ICI->getPredicate() == CmpInst::ICMP_NE) {
7603 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7604 return true;
7605 } else if (getTypeSizeInBits(LHS->getType()) >
7606 getTypeSizeInBits(U->getType()))
7607 return true;
7608 return false;
7609 };
7610 if (CanSimplifyToUnknown())
7611 return getUnknown(U);
7612
7613 for (Value *Inc : U->operands())
7614 Ops.push_back(Inc);
7615 return nullptr;
7616 break;
7617 }
7618 case Instruction::Call:
7619 case Instruction::Invoke:
7620 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7621 Ops.push_back(RV);
7622 return nullptr;
7623 }
7624
7625 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7626 switch (II->getIntrinsicID()) {
7627 case Intrinsic::abs:
7628 Ops.push_back(II->getArgOperand(0));
7629 return nullptr;
7630 case Intrinsic::umax:
7631 case Intrinsic::umin:
7632 case Intrinsic::smax:
7633 case Intrinsic::smin:
7634 case Intrinsic::usub_sat:
7635 case Intrinsic::uadd_sat:
7636 Ops.push_back(II->getArgOperand(0));
7637 Ops.push_back(II->getArgOperand(1));
7638 return nullptr;
7639 case Intrinsic::start_loop_iterations:
7640 case Intrinsic::annotation:
7641 case Intrinsic::ptr_annotation:
7642 Ops.push_back(II->getArgOperand(0));
7643 return nullptr;
7644 default:
7645 break;
7646 }
7647 }
7648 break;
7649 }
7650
7651 return nullptr;
7652}
7653
7654const SCEV *ScalarEvolution::createSCEV(Value *V) {
7655 if (!isSCEVable(V->getType()))
7656 return getUnknown(V);
7657
7658 if (Instruction *I = dyn_cast<Instruction>(V)) {
7659 // Don't attempt to analyze instructions in blocks that aren't
7660 // reachable. Such instructions don't matter, and they aren't required
7661 // to obey basic rules for definitions dominating uses which this
7662 // analysis depends on.
7663 if (!DT.isReachableFromEntry(I->getParent()))
7664 return getUnknown(PoisonValue::get(V->getType()));
7665 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7666 return getConstant(CI);
7667 else if (isa<GlobalAlias>(V))
7668 return getUnknown(V);
7669 else if (!isa<ConstantExpr>(V))
7670 return getUnknown(V);
7671
7672 const SCEV *LHS;
7673 const SCEV *RHS;
7674
7675 Operator *U = cast<Operator>(V);
7676 if (auto BO =
7677 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7678 switch (BO->Opcode) {
7679 case Instruction::Add: {
7680 // The simple thing to do would be to just call getSCEV on both operands
7681 // and call getAddExpr with the result. However if we're looking at a
7682 // bunch of things all added together, this can be quite inefficient,
7683 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7684 // Instead, gather up all the operands and make a single getAddExpr call.
7685 // LLVM IR canonical form means we need only traverse the left operands.
7686 SmallVector<const SCEV *, 4> AddOps;
7687 do {
7688 if (BO->Op) {
7689 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7690 AddOps.push_back(OpSCEV);
7691 break;
7692 }
7693
7694 // If a NUW or NSW flag can be applied to the SCEV for this
7695 // addition, then compute the SCEV for this addition by itself
7696 // with a separate call to getAddExpr. We need to do that
7697 // instead of pushing the operands of the addition onto AddOps,
7698 // since the flags are only known to apply to this particular
7699 // addition - they may not apply to other additions that can be
7700 // formed with operands from AddOps.
7701 const SCEV *RHS = getSCEV(BO->RHS);
7702 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7703 if (Flags != SCEV::FlagAnyWrap) {
7704 const SCEV *LHS = getSCEV(BO->LHS);
7705 if (BO->Opcode == Instruction::Sub)
7706 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7707 else
7708 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7709 break;
7710 }
7711 }
7712
7713 if (BO->Opcode == Instruction::Sub)
7714 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7715 else
7716 AddOps.push_back(getSCEV(BO->RHS));
7717
7718 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7719 dyn_cast<Instruction>(V));
7720 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7721 NewBO->Opcode != Instruction::Sub)) {
7722 AddOps.push_back(getSCEV(BO->LHS));
7723 break;
7724 }
7725 BO = NewBO;
7726 } while (true);
7727
7728 return getAddExpr(AddOps);
7729 }
7730
7731 case Instruction::Mul: {
7732 SmallVector<const SCEV *, 4> MulOps;
7733 do {
7734 if (BO->Op) {
7735 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7736 MulOps.push_back(OpSCEV);
7737 break;
7738 }
7739
7740 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7741 if (Flags != SCEV::FlagAnyWrap) {
7742 LHS = getSCEV(BO->LHS);
7743 RHS = getSCEV(BO->RHS);
7744 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7745 break;
7746 }
7747 }
7748
7749 MulOps.push_back(getSCEV(BO->RHS));
7750 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7751 dyn_cast<Instruction>(V));
7752 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7753 MulOps.push_back(getSCEV(BO->LHS));
7754 break;
7755 }
7756 BO = NewBO;
7757 } while (true);
7758
7759 return getMulExpr(MulOps);
7760 }
7761 case Instruction::UDiv:
7762 LHS = getSCEV(BO->LHS);
7763 RHS = getSCEV(BO->RHS);
7764 return getUDivExpr(LHS, RHS);
7765 case Instruction::URem:
7766 LHS = getSCEV(BO->LHS);
7767 RHS = getSCEV(BO->RHS);
7768 return getURemExpr(LHS, RHS);
7769 case Instruction::Sub: {
7770 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7771 if (BO->Op)
7772 Flags = getNoWrapFlagsFromUB(BO->Op);
7773 LHS = getSCEV(BO->LHS);
7774 RHS = getSCEV(BO->RHS);
7775 return getMinusSCEV(LHS, RHS, Flags);
7776 }
7777 case Instruction::And:
7778 // For an expression like x&255 that merely masks off the high bits,
7779 // use zext(trunc(x)) as the SCEV expression.
7780 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7781 if (CI->isZero())
7782 return getSCEV(BO->RHS);
7783 if (CI->isMinusOne())
7784 return getSCEV(BO->LHS);
7785 const APInt &A = CI->getValue();
7786
7787 // Instcombine's ShrinkDemandedConstant may strip bits out of
7788 // constants, obscuring what would otherwise be a low-bits mask.
7789 // Use computeKnownBits to compute what ShrinkDemandedConstant
7790 // knew about to reconstruct a low-bits mask value.
7791 unsigned LZ = A.countl_zero();
7792 unsigned TZ = A.countr_zero();
7793 unsigned BitWidth = A.getBitWidth();
7794 KnownBits Known(BitWidth);
7795 computeKnownBits(BO->LHS, Known, getDataLayout(),
7796 0, &AC, nullptr, &DT);
7797
7798 APInt EffectiveMask =
7799 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7800 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7801 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7802 const SCEV *LHS = getSCEV(BO->LHS);
7803 const SCEV *ShiftedLHS = nullptr;
7804 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7805 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7806 // For an expression like (x * 8) & 8, simplify the multiply.
7807 unsigned MulZeros = OpC->getAPInt().countr_zero();
7808 unsigned GCD = std::min(MulZeros, TZ);
7809 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7810 SmallVector<const SCEV *, 4> MulOps;
7811 MulOps.push_back(getConstant(OpC->getAPInt().lshr(GCD)));
7812 append_range(MulOps, LHSMul->operands().drop_front());
7813 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7814 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7815 }
7816 }
7817 if (!ShiftedLHS)
7818 ShiftedLHS = getUDivExpr(LHS, MulCount);
7819 return getMulExpr(
7820 getZeroExtendExpr(
7821 getTruncateExpr(ShiftedLHS,
7822 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7823 BO->LHS->getType()),
7824 MulCount);
7825 }
7826 }
7827 // Binary `and` is a bit-wise `umin`.
7828 if (BO->LHS->getType()->isIntegerTy(1)) {
7829 LHS = getSCEV(BO->LHS);
7830 RHS = getSCEV(BO->RHS);
7831 return getUMinExpr(LHS, RHS);
7832 }
7833 break;
7834
7835 case Instruction::Or:
7836 // Binary `or` is a bit-wise `umax`.
7837 if (BO->LHS->getType()->isIntegerTy(1)) {
7838 LHS = getSCEV(BO->LHS);
7839 RHS = getSCEV(BO->RHS);
7840 return getUMaxExpr(LHS, RHS);
7841 }
7842 break;
7843
7844 case Instruction::Xor:
7845 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7846 // If the RHS of xor is -1, then this is a not operation.
7847 if (CI->isMinusOne())
7848 return getNotSCEV(getSCEV(BO->LHS));
7849
7850 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7851 // This is a variant of the check for xor with -1, and it handles
7852 // the case where instcombine has trimmed non-demanded bits out
7853 // of an xor with -1.
7854 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7855 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7856 if (LBO->getOpcode() == Instruction::And &&
7857 LCI->getValue() == CI->getValue())
7858 if (const SCEVZeroExtendExpr *Z =
7859 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7860 Type *UTy = BO->LHS->getType();
7861 const SCEV *Z0 = Z->getOperand();
7862 Type *Z0Ty = Z0->getType();
7863 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7864
7865 // If C is a low-bits mask, the zero extend is serving to
7866 // mask off the high bits. Complement the operand and
7867 // re-apply the zext.
7868 if (CI->getValue().isMask(Z0TySize))
7869 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7870
7871 // If C is a single bit, it may be in the sign-bit position
7872 // before the zero-extend. In this case, represent the xor
7873 // using an add, which is equivalent, and re-apply the zext.
7874 APInt Trunc = CI->getValue().trunc(Z0TySize);
7875 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7876 Trunc.isSignMask())
7877 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7878 UTy);
7879 }
7880 }
7881 break;
7882
7883 case Instruction::Shl:
7884 // Turn shift left of a constant amount into a multiply.
7885 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
7886 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
7887
7888 // If the shift count is not less than the bitwidth, the result of
7889 // the shift is undefined. Don't try to analyze it, because the
7890 // resolution chosen here may differ from the resolution chosen in
7891 // other parts of the compiler.
7892 if (SA->getValue().uge(BitWidth))
7893 break;
7894
7895 // We can safely preserve the nuw flag in all cases. It's also safe to
7896 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
7897 // requires special handling. It can be preserved as long as we're not
7898 // left shifting by bitwidth - 1.
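      // For instance, in i8 "shl nsw -1, 7" is a well-defined -128 (the bits
      // shifted out all agree with the resulting sign bit), but the
      // equivalent "mul nsw -1, -128" overflows signed i8, so nsw may only be
      // carried over when the shift amount is < bitwidth - 1 (or when nuw
      // also holds).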
7899 auto Flags = SCEV::FlagAnyWrap;
7900 if (BO->Op) {
7901 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
7902 if ((MulFlags & SCEV::FlagNSW) &&
7903 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
7904 Flags = setFlags(Flags, SCEV::FlagNSW);
7905 if (MulFlags & SCEV::FlagNUW)
7906 Flags = setFlags(Flags, SCEV::FlagNUW);
7907 }
7908
7909 ConstantInt *X = ConstantInt::get(
7910 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
7911 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
7912 }
7913 break;
7914
7915 case Instruction::AShr:
7916 // AShr X, C, where C is a constant.
7917 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
7918 if (!CI)
7919 break;
7920
7921 Type *OuterTy = BO->LHS->getType();
7922 unsigned BitWidth = getTypeSizeInBits(OuterTy);
7923 // If the shift count is not less than the bitwidth, the result of
7924 // the shift is undefined. Don't try to analyze it, because the
7925 // resolution chosen here may differ from the resolution chosen in
7926 // other parts of the compiler.
7927 if (CI->getValue().uge(BitWidth))
7928 break;
7929
7930 if (CI->isZero())
7931 return getSCEV(BO->LHS); // shift by zero --> noop
7932
7933 uint64_t AShrAmt = CI->getZExtValue();
7934 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
7935
7936 Operator *L = dyn_cast<Operator>(BO->LHS);
7937 const SCEV *AddTruncateExpr = nullptr;
7938 ConstantInt *ShlAmtCI = nullptr;
7939 const SCEV *AddConstant = nullptr;
7940
7941 if (L && L->getOpcode() == Instruction::Add) {
7942 // X = Shl A, n
7943 // Y = Add X, c
7944 // Z = AShr Y, m
7945 // n, c and m are constants.
7946
7947 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7948 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7949 if (LShift && LShift->getOpcode() == Instruction::Shl) {
7950 if (AddOperandCI) {
7951 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7952 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7953 // Since we truncate to TruncTy, the AddConstant should be of the
7954 // same type, so create a new Constant with the same type as TruncTy.
7955 // Also, the Add constant should be shifted right by the AShr amount.
7956 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7957 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
7958 // We model the expression as sext(add(trunc(A), c << n)); since the
7959 // sext(trunc) part is already handled below, we create an
7960 // AddExpr(TruncExp) which will be used later.
7961 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7962 }
7963 }
7964 } else if (L && L->getOpcode() == Instruction::Shl) {
7965 // X = Shl A, n
7966 // Y = AShr X, m
7967 // Both n and m are constant.
7968
7969 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7970 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7971 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7972 }
7973
7974 if (AddTruncateExpr && ShlAmtCI) {
7975 // We can merge the two given cases into a single SCEV statement;
7976 // in case n = m, the mul expression will be 2^0, so it gets resolved to
7977 // a simpler case. The following code handles the two cases:
7978 //
7979 // 1) For a two-shift sext-inreg, i.e. n = m,
7980 // use sext(trunc(x)) as the SCEV expression.
7981 //
7982 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7983 // expression. We already checked that ShlAmt < BitWidth, so
7984 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7985 // ShlAmt - AShrAmt < Amt.
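      // For instance, in i32 with n = 4 and m = 2, "ashr (shl %a, 4), 2"
      // becomes sext (mul (trunc %a to i30), 4) to i32: TruncTy is i30 and
      // the multiplier is 1 << (4 - 2).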
7986 const APInt &ShlAmt = ShlAmtCI->getValue();
7987 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
7988 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7989 ShlAmtCI->getZExtValue() - AShrAmt);
7990 const SCEV *CompositeExpr =
7991 getMulExpr(AddTruncateExpr, getConstant(Mul));
7992 if (L->getOpcode() != Instruction::Shl)
7993 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
7994
7995 return getSignExtendExpr(CompositeExpr, OuterTy);
7996 }
7997 }
7998 break;
7999 }
8000 }
8001
8002 switch (U->getOpcode()) {
8003 case Instruction::Trunc:
8004 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8005
8006 case Instruction::ZExt:
8007 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8008
8009 case Instruction::SExt:
8010 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8011 dyn_cast<Instruction>(V))) {
8012 // The NSW flag of a subtract does not always survive the conversion to
8013 // A + (-1)*B. By pushing sign extension onto its operands we are much
8014 // more likely to preserve NSW and allow later AddRec optimisations.
8015 //
8016 // NOTE: This is effectively duplicating this logic from getSignExtend:
8017 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8018 // but by that point the NSW information has potentially been lost.
8019 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8020 Type *Ty = U->getType();
8021 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8022 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8023 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8024 }
8025 }
8026 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8027
8028 case Instruction::BitCast:
8029 // BitCasts are no-op casts so we just eliminate the cast.
8030 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8031 return getSCEV(U->getOperand(0));
8032 break;
8033
8034 case Instruction::PtrToInt: {
8035 // A pointer-to-integer cast is straightforward, so do model it.
8036 const SCEV *Op = getSCEV(U->getOperand(0));
8037 Type *DstIntTy = U->getType();
8038 // But only if effective SCEV (integer) type is wide enough to represent
8039 // all possible pointer values.
8040 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8041 if (isa<SCEVCouldNotCompute>(IntOp))
8042 return getUnknown(V);
8043 return IntOp;
8044 }
8045 case Instruction::IntToPtr:
8046 // Just don't deal with inttoptr casts.
8047 return getUnknown(V);
8048
8049 case Instruction::SDiv:
8050 // If both operands are non-negative, this is just a udiv.
8051 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8052 isKnownNonNegative(getSCEV(U->getOperand(1))))
8053 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8054 break;
8055
8056 case Instruction::SRem:
8057 // If both operands are non-negative, this is just a urem.
8058 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8059 isKnownNonNegative(getSCEV(U->getOperand(1))))
8060 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8061 break;
8062
8063 case Instruction::GetElementPtr:
8064 return createNodeForGEP(cast<GEPOperator>(U));
8065
8066 case Instruction::PHI:
8067 return createNodeForPHI(cast<PHINode>(U));
8068
8069 case Instruction::Select:
8070 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8071 U->getOperand(2));
8072
8073 case Instruction::Call:
8074 case Instruction::Invoke:
8075 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8076 return getSCEV(RV);
8077
8078 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8079 switch (II->getIntrinsicID()) {
8080 case Intrinsic::abs:
8081 return getAbsExpr(
8082 getSCEV(II->getArgOperand(0)),
8083 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8084 case Intrinsic::umax:
8085 LHS = getSCEV(II->getArgOperand(0));
8086 RHS = getSCEV(II->getArgOperand(1));
8087 return getUMaxExpr(LHS, RHS);
8088 case Intrinsic::umin:
8089 LHS = getSCEV(II->getArgOperand(0));
8090 RHS = getSCEV(II->getArgOperand(1));
8091 return getUMinExpr(LHS, RHS);
8092 case Intrinsic::smax:
8093 LHS = getSCEV(II->getArgOperand(0));
8094 RHS = getSCEV(II->getArgOperand(1));
8095 return getSMaxExpr(LHS, RHS);
8096 case Intrinsic::smin:
8097 LHS = getSCEV(II->getArgOperand(0));
8098 RHS = getSCEV(II->getArgOperand(1));
8099 return getSMinExpr(LHS, RHS);
8100 case Intrinsic::usub_sat: {
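      // usub.sat(X, Y) is modelled as X - umin(X, Y): the result is X - Y
      // when Y <= X and 0 otherwise, matching the saturating semantics, and
      // the subtraction cannot wrap (hence FlagNUW).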
8101 const SCEV *X = getSCEV(II->getArgOperand(0));
8102 const SCEV *Y = getSCEV(II->getArgOperand(1));
8103 const SCEV *ClampedY = getUMinExpr(X, Y);
8104 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8105 }
8106 case Intrinsic::uadd_sat: {
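      // uadd.sat(X, Y) is modelled as umin(X, ~Y) + Y: when X <= ~Y the sum
      // X + Y cannot wrap and is returned as-is; otherwise the clamp gives
      // ~Y + Y == -1, the saturated all-ones value. Either way the add is nuw.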
8107 const SCEV *X = getSCEV(II->getArgOperand(0));
8108 const SCEV *Y = getSCEV(II->getArgOperand(1));
8109 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8110 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8111 }
8112 case Intrinsic::start_loop_iterations:
8113 case Intrinsic::annotation:
8114 case Intrinsic::ptr_annotation:
8115 // A start_loop_iterations, llvm.annotation, or llvm.ptr.annotation is
8116 // just equivalent to its first operand for SCEV purposes.
8117 return getSCEV(II->getArgOperand(0));
8118 case Intrinsic::vscale:
8119 return getVScale(II->getType());
8120 default:
8121 break;
8122 }
8123 }
8124 break;
8125 }
8126
8127 return getUnknown(V);
8128}
8129
8130//===----------------------------------------------------------------------===//
8131// Iteration Count Computation Code
8132//
8133
8134 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8135 if (isa<SCEVCouldNotCompute>(ExitCount))
8136 return getCouldNotCompute();
8137
8138 auto *ExitCountType = ExitCount->getType();
8139 assert(ExitCountType->isIntegerTy());
8140 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8141 1 + ExitCountType->getScalarSizeInBits());
8142 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8143}
8144
8145 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8146 Type *EvalTy,
8147 const Loop *L) {
8148 if (isa<SCEVCouldNotCompute>(ExitCount))
8149 return getCouldNotCompute();
8150
8151 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8152 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8153
8154 auto CanAddOneWithoutOverflow = [&]() {
8155 ConstantRange ExitCountRange =
8156 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8157 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8158 return true;
8159
8160 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8161 getMinusOne(ExitCount->getType()));
8162 };
8163
8164 // If we need to zero extend the backedge count, check if we can add one to
8165 // it prior to zero extending without overflow. Provided this is safe, it
8166 // allows better simplification of the +1.
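  // For instance, an i32 exit count of UINT32_MAX needs a trip count of 2^32,
  // which only fits after widening; when the exit count is known to differ
  // from -1, zext(ExitCount + 1) is equivalent and typically folds better
  // than zext(ExitCount) + 1.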
8167 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8168 return getZeroExtendExpr(
8169 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8170
8171 // Get the total trip count from the count by adding 1. This may wrap.
8172 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8173}
8174
8175static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8176 if (!ExitCount)
8177 return 0;
8178
8179 ConstantInt *ExitConst = ExitCount->getValue();
8180
8181 // Guard against huge trip counts.
8182 if (ExitConst->getValue().getActiveBits() > 32)
8183 return 0;
8184
8185 // In case of integer overflow, this returns 0, which is correct.
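  // (For instance, an exact exit count of 0xFFFFFFFF would need a trip count
  // of 2^32, which does not fit in unsigned; wrapping to 0 reports it as
  // unknown.)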
8186 return ((unsigned)ExitConst->getZExtValue()) + 1;
8187}
8188
8189 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8190 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8191 return getConstantTripCount(ExitCount);
8192}
8193
8194unsigned
8195 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8196 const BasicBlock *ExitingBlock) {
8197 assert(ExitingBlock && "Must pass a non-null exiting block!");
8198 assert(L->isLoopExiting(ExitingBlock) &&
8199 "Exiting block must actually branch out of the loop!");
8200 const SCEVConstant *ExitCount =
8201 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8202 return getConstantTripCount(ExitCount);
8203}
8204
8205 unsigned ScalarEvolution::getSmallConstantMaxTripCount(const Loop *L) {
8206 const auto *MaxExitCount =
8207 dyn_cast<SCEVConstant>(getConstantMaxBackedgeTakenCount(L));
8208 return getConstantTripCount(MaxExitCount);
8209}
8210
8211 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8212 SmallVector<BasicBlock *, 8> ExitingBlocks;
8213 L->getExitingBlocks(ExitingBlocks);
8214
8215 std::optional<unsigned> Res;
8216 for (auto *ExitingBB : ExitingBlocks) {
8217 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8218 if (!Res)
8219 Res = Multiple;
8220 Res = (unsigned)std::gcd(*Res, Multiple);
8221 }
8222 return Res.value_or(1);
8223}
8224
8225 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8226 const SCEV *ExitCount) {
8227 if (ExitCount == getCouldNotCompute())
8228 return 1;
8229
8230 // Get the trip count
8231 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8232
8233 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8234 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8235 // the greatest power of 2 divisor less than 2^32.
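  // For instance, a known multiple of 3 * 2^33 has 35 active bits, so we
  // return 1 << 31, the largest representable power-of-two divisor.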
8236 return Multiple.getActiveBits() > 32
8237 ? 1U << std::min((unsigned)31, Multiple.countTrailingZeros())
8238 : (unsigned)Multiple.zextOrTrunc(32).getZExtValue();
8239}
8240
8241/// Returns the largest constant divisor of the trip count of this loop as a
8242/// normal unsigned value, if possible. This means that the actual trip count is
8243/// always a multiple of the returned value (don't forget the trip count could
8244/// very well be zero as well!).
8245///
8246/// Returns 1 if the trip count is unknown or not guaranteed to be the
8247/// multiple of a constant (which is also the case if the trip count is simply
8248 /// constant; use getSmallConstantTripCount for that case). It will also return 1
8249/// if the trip count is very large (>= 2^32).
8250///
8251/// As explained in the comments for getSmallConstantTripCount, this assumes
8252/// that control exits the loop via ExitingBlock.
8253unsigned
8255 const BasicBlock *ExitingBlock) {
8256 assert(ExitingBlock && "Must pass a non-null exiting block!");
8257 assert(L->isLoopExiting(ExitingBlock) &&
8258 "Exiting block must actually branch out of the loop!");
8259 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8260 return getSmallConstantTripMultiple(L, ExitCount);
8261}
8262
8263 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8264 const BasicBlock *ExitingBlock,
8265 ExitCountKind Kind) {
8266 switch (Kind) {
8267 case Exact:
8268 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8269 case SymbolicMaximum:
8270 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8271 case ConstantMaximum:
8272 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8273 };
8274 llvm_unreachable("Invalid ExitCountKind!");
8275}
8276
8277 const SCEV *
8278 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
8279 SmallVector<const SCEVPredicate *, 4> &Preds) {
8280 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8281}
8282
8283 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8284 ExitCountKind Kind) {
8285 switch (Kind) {
8286 case Exact:
8287 return getBackedgeTakenInfo(L).getExact(L, this);
8288 case ConstantMaximum:
8289 return getBackedgeTakenInfo(L).getConstantMax(this);
8290 case SymbolicMaximum:
8291 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8292 };
8293 llvm_unreachable("Invalid ExitCountKind!");
8294}
8295
8296 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8297 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8298}
8299
8300/// Push PHI nodes in the header of the given loop onto the given Worklist.
8301 static void PushLoopPHIs(const Loop *L,
8302 SmallVectorImpl<Instruction *> &Worklist,
8303 SmallPtrSetImpl<Instruction *> &Visited) {
8304 BasicBlock *Header = L->getHeader();
8305
8306 // Push all Loop-header PHIs onto the Worklist stack.
8307 for (PHINode &PN : Header->phis())
8308 if (Visited.insert(&PN).second)
8309 Worklist.push_back(&PN);
8310}
8311
8312const ScalarEvolution::BackedgeTakenInfo &
8313ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8314 auto &BTI = getBackedgeTakenInfo(L);
8315 if (BTI.hasFullInfo())
8316 return BTI;
8317
8318 auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8319
8320 if (!Pair.second)
8321 return Pair.first->second;
8322
8323 BackedgeTakenInfo Result =
8324 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8325
8326 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8327}
8328
8329ScalarEvolution::BackedgeTakenInfo &
8330ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8331 // Initially insert an invalid entry for this loop. If the insertion
8332 // succeeds, proceed to actually compute a backedge-taken count and
8333 // update the value. The temporary CouldNotCompute value tells SCEV
8334 // code elsewhere that it shouldn't attempt to request a new
8335 // backedge-taken count, which could result in infinite recursion.
8336 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8337 BackedgeTakenCounts.insert({L, BackedgeTakenInfo()});
8338 if (!Pair.second)
8339 return Pair.first->second;
8340
8341 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8342 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8343 // must be cleared in this scope.
8344 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8345
8346 // Now that we know more about the trip count for this loop, forget any
8347 // existing SCEV values for PHI nodes in this loop since they are only
8348 // conservative estimates made without the benefit of trip count
8349 // information. This invalidation is not necessary for correctness, and is
8350 // only done to produce more precise results.
8351 if (Result.hasAnyInfo()) {
8352 // Invalidate any expression using an addrec in this loop.
8353 SmallVector<const SCEV *, 8> ToForget;
8354 auto LoopUsersIt = LoopUsers.find(L);
8355 if (LoopUsersIt != LoopUsers.end())
8356 append_range(ToForget, LoopUsersIt->second);
8357 forgetMemoizedResults(ToForget);
8358
8359 // Invalidate constant-evolved loop header phis.
8360 for (PHINode &PN : L->getHeader()->phis())
8361 ConstantEvolutionLoopExitValue.erase(&PN);
8362 }
8363
8364 // Re-lookup the insert position, since the call to
8365 // computeBackedgeTakenCount above could result in a
8366 // recursive call to getBackedgeTakenInfo (on a different
8367 // loop), which would invalidate the iterator computed
8368 // earlier.
8369 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8370}
8371
8372 void ScalarEvolution::forgetAllLoops() {
8373 // This method is intended to forget all info about loops. It should
8374 // invalidate caches as if the following happened:
8375 // - The trip counts of all loops have changed arbitrarily
8376 // - Every llvm::Value has been updated in place to produce a different
8377 // result.
8378 BackedgeTakenCounts.clear();
8379 PredicatedBackedgeTakenCounts.clear();
8380 BECountUsers.clear();
8381 LoopPropertiesCache.clear();
8382 ConstantEvolutionLoopExitValue.clear();
8383 ValueExprMap.clear();
8384 ValuesAtScopes.clear();
8385 ValuesAtScopesUsers.clear();
8386 LoopDispositions.clear();
8387 BlockDispositions.clear();
8388 UnsignedRanges.clear();
8389 SignedRanges.clear();
8390 ExprValueMap.clear();
8391 HasRecMap.clear();
8392 ConstantMultipleCache.clear();
8393 PredicatedSCEVRewrites.clear();
8394 FoldCache.clear();
8395 FoldCacheUser.clear();
8396}
8397 void ScalarEvolution::visitAndClearUsers(
8398 SmallVectorImpl<Instruction *> &Worklist,
8399 SmallPtrSetImpl<Instruction *> &Visited,
8400 SmallVectorImpl<const SCEV *> &ToForget) {
8401 while (!Worklist.empty()) {
8402 Instruction *I = Worklist.pop_back_val();
8403 if (!isSCEVable(I->getType()))
8404 continue;
8405
8406 ValueExprMapType::iterator It =
8407 ValueExprMap.find_as(static_cast<Value *>(I));
8408 if (It != ValueExprMap.end()) {
8409 eraseValueFromMap(It->first);
8410 ToForget.push_back(It->second);
8411 if (PHINode *PN = dyn_cast<PHINode>(I))
8412 ConstantEvolutionLoopExitValue.erase(PN);
8413 }
8414
8415 PushDefUseChildren(I, Worklist, Visited);
8416 }
8417}
8418
8419 void ScalarEvolution::forgetLoop(const Loop *L) {
8420 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8421 SmallVector<Instruction *, 32> Worklist;
8422 SmallPtrSet<Instruction *, 16> Visited;
8423 SmallVector<const SCEV *, 16> ToForget;
8424
8425 // Iterate over all the loops and sub-loops to drop SCEV information.
8426 while (!LoopWorklist.empty()) {
8427 auto *CurrL = LoopWorklist.pop_back_val();
8428
8429 // Drop any stored trip count value.
8430 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8431 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8432
8433 // Drop information about predicated SCEV rewrites for this loop.
8434 for (auto I = PredicatedSCEVRewrites.begin();
8435 I != PredicatedSCEVRewrites.end();) {
8436 std::pair<const SCEV *, const Loop *> Entry = I->first;
8437 if (Entry.second == CurrL)
8438 PredicatedSCEVRewrites.erase(I++);
8439 else
8440 ++I;
8441 }
8442
8443 auto LoopUsersItr = LoopUsers.find(CurrL);
8444 if (LoopUsersItr != LoopUsers.end()) {
8445 ToForget.insert(ToForget.end(), LoopUsersItr->second.begin(),
8446 LoopUsersItr->second.end());
8447 }
8448
8449 // Drop information about expressions based on loop-header PHIs.
8450 PushLoopPHIs(CurrL, Worklist, Visited);
8451 visitAndClearUsers(Worklist, Visited, ToForget);
8452
8453 LoopPropertiesCache.erase(CurrL);
8454 // Forget all contained loops too, to avoid dangling entries in the
8455 // ValuesAtScopes map.
8456 LoopWorklist.append(CurrL->begin(), CurrL->end());
8457 }
8458 forgetMemoizedResults(ToForget);
8459}
8460
8461 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8462 forgetLoop(L->getOutermostLoop());
8463}
8464
8465 void ScalarEvolution::forgetValue(Value *V) {
8466 Instruction *I = dyn_cast<Instruction>(V);
8467 if (!I) return;
8468
8469 // Drop information about expressions based on loop-header PHIs.
8470 SmallVector<Instruction *, 16> Worklist;
8471 SmallPtrSet<Instruction *, 8> Visited;
8472 SmallVector<const SCEV *, 8> ToForget;
8473 Worklist.push_back(I);
8474 Visited.insert(I);
8475 visitAndClearUsers(Worklist, Visited, ToForget);
8476
8477 forgetMemoizedResults(ToForget);
8478}
8479
8480 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8481 if (!isSCEVable(V->getType()))
8482 return;
8483
8484 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8485 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8486 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8487 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8488 if (const SCEV *S = getExistingSCEV(V)) {
8489 struct InvalidationRootCollector {
8490 Loop *L;
8491 SmallVector<const SCEV *, 8> Roots;
8492
8493 InvalidationRootCollector(Loop *L) : L(L) {}
8494
8495 bool follow(const SCEV *S) {
8496 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8497 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8498 if (L->contains(I))
8499 Roots.push_back(S);
8500 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8501 if (L->contains(AddRec->getLoop()))
8502 Roots.push_back(S);
8503 }
8504 return true;
8505 }
8506 bool isDone() const { return false; }
8507 };
8508
8509 InvalidationRootCollector C(L);
8510 visitAll(S, C);
8511 forgetMemoizedResults(C.Roots);
8512 }
8513
8514 // Also perform the normal invalidation.
8515 forgetValue(V);
8516}
8517
8518void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8519
8520 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8521 // Unless a specific value is passed to invalidation, completely clear both
8522 // caches.
8523 if (!V) {
8524 BlockDispositions.clear();
8525 LoopDispositions.clear();
8526 return;
8527 }
8528
8529 if (!isSCEVable(V->getType()))
8530 return;
8531
8532 const SCEV *S = getExistingSCEV(V);
8533 if (!S)
8534 return;
8535
8536 // Invalidate the block and loop dispositions cached for S. Dispositions of
8537 // S's users may change if S's disposition changes (i.e. a user may change to
8538 // loop-invariant, if S changes to loop invariant), so also invalidate
8539 // dispositions of S's users recursively.
8540 SmallVector<const SCEV *, 8> Worklist = {S};
8541 SmallPtrSet<const SCEV *, 8> Seen = {S};
8542 while (!Worklist.empty()) {
8543 const SCEV *Curr = Worklist.pop_back_val();
8544 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8545 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8546 if (!LoopDispoRemoved && !BlockDispoRemoved)
8547 continue;
8548 auto Users = SCEVUsers.find(Curr);
8549 if (Users != SCEVUsers.end())
8550 for (const auto *User : Users->second)
8551 if (Seen.insert(User).second)
8552 Worklist.push_back(User);
8553 }
8554}
8555
8556/// Get the exact loop backedge taken count considering all loop exits. A
8557/// computable result can only be returned for loops with all exiting blocks
8558/// dominating the latch. howFarToZero assumes that the limit of each loop test
8559/// is never skipped. This is a valid assumption as long as the loop exits via
8560/// that test. For precise results, it is the caller's responsibility to specify
8561/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8562const SCEV *
8563 ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
8564 SmallVector<const SCEVPredicate *, 4> *Preds) const {
8565 // If any exits were not computable, the loop is not computable.
8566 if (!isComplete() || ExitNotTaken.empty())
8567 return SE->getCouldNotCompute();
8568
8569 const BasicBlock *Latch = L->getLoopLatch();
8570 // All exiting blocks we have collected must dominate the only backedge.
8571 if (!Latch)
8572 return SE->getCouldNotCompute();
8573
8574 // All exiting blocks we have gathered dominate the loop's latch, so the exact
8575 // trip count is simply the minimum of all these calculated exit counts.
8576 SmallVector<const SCEV *, 2> Ops;
8577 for (const auto &ENT : ExitNotTaken) {
8578 const SCEV *BECount = ENT.ExactNotTaken;
8579 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8580 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8581 "We should only have known counts for exiting blocks that dominate "
8582 "latch!");
8583
8584 Ops.push_back(BECount);
8585
8586 if (Preds)
8587 for (const auto *P : ENT.Predicates)
8588 Preds->push_back(P);
8589
8590 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8591 "Predicate should be always true!");
8592 }
8593
8594 // If an earlier exit exits on the first iteration (exit count zero), then
8595 // a later poison exit count should not propagate into the result. These are
8596 // exactly the semantics provided by umin_seq.
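  // For instance, umin_seq(0, poison) is 0, whereas a plain umin would let
  // the poison later-exit count leak into the result.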
8597 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8598}
8599
8600/// Get the exact not taken count for this loop exit.
8601const SCEV *
8602ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
8603 ScalarEvolution *SE) const {
8604 for (const auto &ENT : ExitNotTaken)
8605 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8606 return ENT.ExactNotTaken;
8607
8608 return SE->getCouldNotCompute();
8609}
8610
8611const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8612 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8613 for (const auto &ENT : ExitNotTaken)
8614 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8615 return ENT.ConstantMaxNotTaken;
8616
8617 return SE->getCouldNotCompute();
8618}
8619
8620const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8621 const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
8622 for (const auto &ENT : ExitNotTaken)
8623 if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
8624 return ENT.SymbolicMaxNotTaken;
8625
8626 return SE->getCouldNotCompute();
8627}
8628
8629/// getConstantMax - Get the constant max backedge taken count for the loop.
8630const SCEV *
8631ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
8632 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8633 return !ENT.hasAlwaysTruePredicate();
8634 };
8635
8636 if (!getConstantMax() || any_of(ExitNotTaken, PredicateNotAlwaysTrue))
8637 return SE->getCouldNotCompute();
8638
8639 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8640 isa<SCEVConstant>(getConstantMax())) &&
8641 "No point in having a non-constant max backedge taken count!");
8642 return getConstantMax();
8643}
8644
8645const SCEV *
8646ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
8647 ScalarEvolution *SE) {
8648 if (!SymbolicMax)
8649 SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
8650 return SymbolicMax;
8651}
8652
8653bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8654 ScalarEvolution *SE) const {
8655 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8656 return !ENT.hasAlwaysTruePredicate();
8657 };
8658 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8659}
8660
8661 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8662 : ExitLimit(E, E, E, false, std::nullopt) {}
8663
8664 ScalarEvolution::ExitLimit::ExitLimit(
8665 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8666 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8667 ArrayRef<const SmallPtrSetImpl<const SCEVPredicate *> *> PredSetList)
8668 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8669 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8670 // If we prove the max count is zero, so is the symbolic bound. This happens
8671 // in practice due to differences in a) how context sensitive we've chosen
8672 // to be and b) how we reason about bounds implied by UB.
8673 if (ConstantMaxNotTaken->isZero()) {
8674 this->ExactNotTaken = ExactNotTaken = ConstantMaxNotTaken;
8675 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8676 }
8677
8678 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8679 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8680 "Exact is not allowed to be less precise than Constant Max");
8681 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8682 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8683 "Exact is not allowed to be less precise than Symbolic Max");
8684 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8685 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8686 "Symbolic Max is not allowed to be less precise than Constant Max");
8687 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8688 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8689 "No point in having a non-constant max backedge taken count!");
8690 for (const auto *PredSet : PredSetList)
8691 for (const auto *P : *PredSet)
8692 addPredicate(P);
8693 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8694 "Backedge count should be int");
8695 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8696 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8697 "Max backedge count should be int");
8698}
8699
8700 ScalarEvolution::ExitLimit::ExitLimit(
8701 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8702 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8703 const SmallPtrSetImpl<const SCEVPredicate *> &PredSet)
8704 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8705 { &PredSet }) {}
8706
8707/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8708/// computable exit into a persistent ExitNotTakenInfo array.
8709 ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8710 ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
8711 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8712 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8713 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8714
8715 ExitNotTaken.reserve(ExitCounts.size());
8716 std::transform(ExitCounts.begin(), ExitCounts.end(),
8717 std::back_inserter(ExitNotTaken),
8718 [&](const EdgeExitInfo &EEI) {
8719 BasicBlock *ExitBB = EEI.first;
8720 const ExitLimit &EL = EEI.second;
8721 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8722 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8723 EL.Predicates);
8724 });
8725 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8726 isa<SCEVConstant>(ConstantMax)) &&
8727 "No point in having a non-constant max backedge taken count!");
8728}
8729
8730/// Compute the number of times the backedge of the specified loop will execute.
8731ScalarEvolution::BackedgeTakenInfo
8732ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8733 bool AllowPredicates) {
8734 SmallVector<BasicBlock *, 8> ExitingBlocks;
8735 L->getExitingBlocks(ExitingBlocks);
8736
8737 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8738
8739 SmallVector<EdgeExitInfo, 4> ExitCounts;
8740 bool CouldComputeBECount = true;
8741 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8742 const SCEV *MustExitMaxBECount = nullptr;
8743 const SCEV *MayExitMaxBECount = nullptr;
8744 bool MustExitMaxOrZero = false;
8745
8746 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8747 // and compute maxBECount.
8748 // Do a union of all the predicates here.
8749 for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
8750 BasicBlock *ExitBB = ExitingBlocks[i];
8751
8752 // We canonicalize untaken exits to br (constant), ignore them so that
8753 // proving an exit untaken doesn't negatively impact our ability to reason
8754 // about the loop as a whole.
8755 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8756 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8757 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8758 if (ExitIfTrue == CI->isZero())
8759 continue;
8760 }
8761
8762 ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates);
8763
8764 assert((AllowPredicates || EL.Predicates.empty()) &&
8765 "Predicated exit limit when predicates are not allowed!");
8766
8767 // 1. For each exit that can be computed, add an entry to ExitCounts.
8768 // CouldComputeBECount is true only if all exits can be computed.
8769 if (EL.ExactNotTaken != getCouldNotCompute())
8770 ++NumExitCountsComputed;
8771 else
8772 // We couldn't compute an exact value for this exit, so
8773 // we won't be able to compute an exact value for the loop.
8774 CouldComputeBECount = false;
8775 // Remember exit count if either exact or symbolic is known. Because
8776 // Exact always implies symbolic, only check symbolic.
8777 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8778 ExitCounts.emplace_back(ExitBB, EL);
8779 else {
8780 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8781 "Exact is known but symbolic isn't?");
8782 ++NumExitCountsNotComputed;
8783 }
8784
8785 // 2. Derive the loop's MaxBECount from each exit's max number of
8786 // non-exiting iterations. Partition the loop exits into two kinds:
8787 // LoopMustExits and LoopMayExits.
8788 //
8789 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8790 // is a LoopMayExit. If any computable LoopMustExit is found, then
8791 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8792 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8793 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8794 // any
8795 // computable EL.ConstantMaxNotTaken.
8796 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8797 DT.dominates(ExitBB, Latch)) {
8798 if (!MustExitMaxBECount) {
8799 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8800 MustExitMaxOrZero = EL.MaxOrZero;
8801 } else {
8802 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8803 EL.ConstantMaxNotTaken);
8804 }
8805 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8806 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8807 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8808 else {
8809 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8810 EL.ConstantMaxNotTaken);
8811 }
8812 }
8813 }
8814 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8815 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8816 // The loop backedge will be taken the maximum or zero times if there's
8817 // a single exit that must be taken the maximum or zero times.
8818 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8819
8820 // Remember which SCEVs are used in exit limits for invalidation purposes.
8821 // We only care about non-constant SCEVs here, so we can ignore
8822 // EL.ConstantMaxNotTaken
8823 // and MaxBECount, which must be SCEVConstant.
8824 for (const auto &Pair : ExitCounts) {
8825 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8826 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8827 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8828 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8829 {L, AllowPredicates});
8830 }
8831 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8832 MaxBECount, MaxOrZero);
8833}
8834
8835 ScalarEvolution::ExitLimit
8836 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
8837 bool AllowPredicates) {
8838 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
8839 // If our exiting block does not dominate the latch, then its connection with
8840 // loop's exit limit may be far from trivial.
8841 const BasicBlock *Latch = L->getLoopLatch();
8842 if (!Latch || !DT.dominates(ExitingBlock, Latch))
8843 return getCouldNotCompute();
8844
8845 bool IsOnlyExit = (L->getExitingBlock() != nullptr);
8846 Instruction *Term = ExitingBlock->getTerminator();
8847 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
8848 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
8849 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8850 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
8851 "It should have one successor in loop and one exit block!");
8852 // Proceed to the next level to examine the exit condition expression.
8853 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
8854 /*ControlsOnlyExit=*/IsOnlyExit,
8855 AllowPredicates);
8856 }
8857
8858 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
8859 // For switch, make sure that there is a single exit from the loop.
8860 BasicBlock *Exit = nullptr;
8861 for (auto *SBB : successors(ExitingBlock))
8862 if (!L->contains(SBB)) {
8863 if (Exit) // Multiple exit successors.
8864 return getCouldNotCompute();
8865 Exit = SBB;
8866 }
8867 assert(Exit && "Exiting block must have at least one exit");
8868 return computeExitLimitFromSingleExitSwitch(
8869 L, SI, Exit,
8870 /*ControlsOnlyExit=*/IsOnlyExit);
8871 }
8872
8873 return getCouldNotCompute();
8874}
8875
8876 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
8877 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
8878 bool AllowPredicates) {
8879 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
8880 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
8881 ControlsOnlyExit, AllowPredicates);
8882}
8883
8884std::optional<ScalarEvolution::ExitLimit>
8885ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
8886 bool ExitIfTrue, bool ControlsOnlyExit,
8887 bool AllowPredicates) {
8888 (void)this->L;
8889 (void)this->ExitIfTrue;
8890 (void)this->AllowPredicates;
8891
8892 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8893 this->AllowPredicates == AllowPredicates &&
8894 "Variance in assumed invariant key components!");
8895 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
8896 if (Itr == TripCountMap.end())
8897 return std::nullopt;
8898 return Itr->second;
8899}
8900
8901void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
8902 bool ExitIfTrue,
8903 bool ControlsOnlyExit,
8904 bool AllowPredicates,
8905 const ExitLimit &EL) {
8906 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
8907 this->AllowPredicates == AllowPredicates &&
8908 "Variance in assumed invariant key components!");
8909
8910 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
8911 assert(InsertResult.second && "Expected successful insertion!");
8912 (void)InsertResult;
8913 (void)ExitIfTrue;
8914}
8915
8916ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
8917 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8918 bool ControlsOnlyExit, bool AllowPredicates) {
8919
8920 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
8921 AllowPredicates))
8922 return *MaybeEL;
8923
8924 ExitLimit EL = computeExitLimitFromCondImpl(
8925 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
8926 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
8927 return EL;
8928}
8929
8930ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
8931 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8932 bool ControlsOnlyExit, bool AllowPredicates) {
8933 // Handle BinOp conditions (And, Or).
8934 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
8935 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
8936 return *LimitFromBinOp;
8937
8938 // With an icmp, it may be feasible to compute an exact backedge-taken count.
8939 // Proceed to the next level to examine the icmp.
8940 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
8941 ExitLimit EL =
8942 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
8943 if (EL.hasFullInfo() || !AllowPredicates)
8944 return EL;
8945
8946 // Try again, but use SCEV predicates this time.
8947 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
8948 ControlsOnlyExit,
8949 /*AllowPredicates=*/true);
8950 }
8951
8952 // Check for a constant condition. These are normally stripped out by
8953 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
8954 // preserve the CFG and is temporarily leaving constant conditions
8955 // in place.
8956 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
8957 if (ExitIfTrue == !CI->getZExtValue())
8958 // The backedge is always taken.
8959 return getCouldNotCompute();
8960 // The backedge is never taken.
8961 return getZero(CI->getType());
8962 }
8963
8964 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
8965 // with a constant step, we can form an equivalent icmp predicate and figure
8966 // out how many iterations will be taken before we exit.
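  // For instance, exiting when "uadd.with.overflow(%x, 3)" overflows in i8 is
  // the same as exiting when %x u>= 253, since the no-overflow region of the
  // addition is [0, 253).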
8967 const WithOverflowInst *WO;
8968 const APInt *C;
8969 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
8970 match(WO->getRHS(), m_APInt(C))) {
8971 ConstantRange NWR =
8972 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
8973 WO->getNoWrapKind());
8974 CmpInst::Predicate Pred;
8975 APInt NewRHSC, Offset;
8976 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
8977 if (!ExitIfTrue)
8978 Pred = ICmpInst::getInversePredicate(Pred);
8979 auto *LHS = getSCEV(WO->getLHS());
8980 if (Offset != 0)
8981 LHS = getAddExpr(LHS, getConstant(Offset));
8982 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
8983 ControlsOnlyExit, AllowPredicates);
8984 if (EL.hasAnyInfo())
8985 return EL;
8986 }
8987
8988 // If it's not an integer or pointer comparison then compute it the hard way.
8989 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
8990}
8991
8992std::optional<ScalarEvolution::ExitLimit>
8993ScalarEvolution::computeExitLimitFromCondFromBinOp(
8994 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
8995 bool ControlsOnlyExit, bool AllowPredicates) {
8996 // Check if the controlling expression for this loop is an And or Or.
8997 Value *Op0, *Op1;
8998 bool IsAnd = false;
8999 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9000 IsAnd = true;
9001 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9002 IsAnd = false;
9003 else
9004 return std::nullopt;
9005
9006 // EitherMayExit is true in these two cases:
9007 // br (and Op0 Op1), loop, exit
9008 // br (or Op0 Op1), exit, loop
9009 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9010 ExitLimit EL0 = computeExitLimitFromCondCached(
9011 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9012 AllowPredicates);
9013 ExitLimit EL1 = computeExitLimitFromCondCached(
9014 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9015 AllowPredicates);
9016
9017 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
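  // For instance, "and i1 %X, true" exits exactly when %X does, so the exit
  // limit computed for %X can be reused directly; likewise "or i1 %X, false".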
9018 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9019 if (isa<ConstantInt>(Op1))
9020 return Op1 == NeutralElement ? EL0 : EL1;
9021 if (isa<ConstantInt>(Op0))
9022 return Op0 == NeutralElement ? EL1 : EL0;
9023
9024 const SCEV *BECount = getCouldNotCompute();
9025 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9026 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9027 if (EitherMayExit) {
9028 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9029 // Both conditions must be the same for the loop to continue executing.
9030 // Choose the less conservative count.
9031 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9032 EL1.ExactNotTaken != getCouldNotCompute()) {
9033 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9034 UseSequentialUMin);
9035 }
9036 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9037 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9038 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9039 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9040 else
9041 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9042 EL1.ConstantMaxNotTaken);
9043 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9044 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9045 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9046 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9047 else
9048 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9049 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9050 } else {
9051 // Both conditions must be the same at the same time for the loop to exit.
9052 // For now, be conservative.
9053 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9054 BECount = EL0.ExactNotTaken;
9055 }
9056
9057 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9058 // to be more aggressive when computing BECount than when computing
9059 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9060 // and
9061 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9062 // EL1.ConstantMaxNotTaken to not.
9063 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9064 !isa<SCEVCouldNotCompute>(BECount))
9065 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9066 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9067 SymbolicMaxBECount =
9068 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9069 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9070 { &EL0.Predicates, &EL1.Predicates });
9071}
9072
9073ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9074 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9075 bool AllowPredicates) {
9076 // If the condition was exit on true, convert the condition to exit on false.
9077 ICmpInst::Predicate Pred;
9078 if (!ExitIfTrue)
9079 Pred = ExitCond->getPredicate();
9080 else
9081 Pred = ExitCond->getInversePredicate();
9082 const ICmpInst::Predicate OriginalPred = Pred;
9083
9084 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9085 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9086
9087 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9088 AllowPredicates);
9089 if (EL.hasAnyInfo())
9090 return EL;
9091
9092 auto *ExhaustiveCount =
9093 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9094
9095 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9096 return ExhaustiveCount;
9097
9098 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9099 ExitCond->getOperand(1), L, OriginalPred);
9100}
9101ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9102 const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
9103 bool ControlsOnlyExit, bool AllowPredicates) {
9104
9105 // Try to evaluate any dependencies out of the loop.
9106 LHS = getSCEVAtScope(LHS, L);
9107 RHS = getSCEVAtScope(RHS, L);
9108
9109 // At this point, we would like to compute how many iterations of the
9110 // loop the predicate will return true for these inputs.
9111 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9112 // If there is a loop-invariant, force it into the RHS.
9113 std::swap(LHS, RHS);
9114 Pred = ICmpInst::getSwappedPredicate(Pred);
9115 }
9116
9117 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9118 loopIsFiniteByAssumption(L);
9119 // Simplify the operands before analyzing them.
9120 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9121
9122 // If we have a comparison of a chrec against a constant, try to use value
9123 // ranges to answer this query.
9124 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9125 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9126 if (AddRec->getLoop() == L) {
9127 // Form the constant range.
9128 ConstantRange CompRange =
9129 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9130
9131 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9132 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9133 }
9134
9135 // If this loop must exit based on this condition (or execute undefined
9136 // behaviour), and we can prove the test sequence produced must repeat
9137 // the same values on self-wrap of the IV, then we can infer that IV
9138 // doesn't self wrap because if it did, we'd have an infinite (undefined)
9139 // loop.
9140 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9141 // TODO: We can peel off any functions which are invertible *in L*. Loop
9142 // invariant terms are effectively constants for our purposes here.
9143 auto *InnerLHS = LHS;
9144 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9145 InnerLHS = ZExt->getOperand();
9146 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS)) {
9147 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
9148 if (!AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9149 StrideC && StrideC->getAPInt().isPowerOf2()) {
9150 auto Flags = AR->getNoWrapFlags();
9151 Flags = setFlags(Flags, SCEV::FlagNW);
9152 SmallVector<const SCEV *> Operands{AR->operands()};
9153 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9154 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9155 }
9156 }
9157 }
9158
9159 switch (Pred) {
9160 case ICmpInst::ICMP_NE: { // while (X != Y)
9161 // Convert to: while (X-Y != 0)
9162 if (LHS->getType()->isPointerTy()) {
9163 LHS = getLosslessPtrToIntExpr(LHS);
9164 if (isa<SCEVCouldNotCompute>(LHS))
9165 return LHS;
9166 }
9167 if (RHS->getType()->isPointerTy()) {
9168 RHS = getLosslessPtrToIntExpr(RHS);
9169 if (isa<SCEVCouldNotCompute>(RHS))
9170 return RHS;
9171 }
9172 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9173 AllowPredicates);
9174 if (EL.hasAnyInfo())
9175 return EL;
9176 break;
9177 }
9178 case ICmpInst::ICMP_EQ: { // while (X == Y)
9179 // Convert to: while (X-Y == 0)
9180 if (LHS->getType()->isPointerTy()) {
9181 LHS = getLosslessPtrToIntExpr(LHS);
9182 if (isa<SCEVCouldNotCompute>(LHS))
9183 return LHS;
9184 }
9185 if (RHS->getType()->isPointerTy()) {
9186 RHS = getLosslessPtrToIntExpr(RHS);
9187 if (isa<SCEVCouldNotCompute>(RHS))
9188 return RHS;
9189 }
9190 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9191 if (EL.hasAnyInfo()) return EL;
9192 break;
9193 }
9194 case ICmpInst::ICMP_SLE:
9195 case ICmpInst::ICMP_ULE:
9196 // Since the loop is finite, an invariant RHS cannot include the boundary
9197 // value, otherwise it would loop forever.
9198 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9199 !isLoopInvariant(RHS, L))
9200 break;
9201 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9202 [[fallthrough]];
9203 case ICmpInst::ICMP_SLT:
9204 case ICmpInst::ICMP_ULT: { // while (X < Y)
9205 bool IsSigned = ICmpInst::isSigned(Pred);
9206 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9207 AllowPredicates);
9208 if (EL.hasAnyInfo())
9209 return EL;
9210 break;
9211 }
9212 case ICmpInst::ICMP_SGE:
9213 case ICmpInst::ICMP_UGE:
9214 // Since the loop is finite, an invariant RHS cannot include the boundary
9215 // value, otherwise it would loop forever.
9216 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9217 !isLoopInvariant(RHS, L))
9218 break;
9219 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9220 [[fallthrough]];
9221 case ICmpInst::ICMP_SGT:
9222 case ICmpInst::ICMP_UGT: { // while (X > Y)
9223 bool IsSigned = ICmpInst::isSigned(Pred);
9224 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9225 AllowPredicates);
9226 if (EL.hasAnyInfo())
9227 return EL;
9228 break;
9229 }
9230 default:
9231 break;
9232 }
9233
9234 return getCouldNotCompute();
9235}
9236
9237 ScalarEvolution::ExitLimit
9238 ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9239 SwitchInst *Switch,
9240 BasicBlock *ExitingBlock,
9241 bool ControlsOnlyExit) {
9242 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9243
9244 // Give up if the exit is the default dest of a switch.
9245 if (Switch->getDefaultDest() == ExitingBlock)
9246 return getCouldNotCompute();
9247
9248 assert(L->contains(Switch->getDefaultDest()) &&
9249 "Default case must not exit the loop!");
9250 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9251 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9252
9253 // while (X != Y) --> while (X-Y != 0)
9254 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9255 if (EL.hasAnyInfo())
9256 return EL;
9257
9258 return getCouldNotCompute();
9259}
9260
9261static ConstantInt *
9262 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9263 ScalarEvolution &SE) {
9264 const SCEV *InVal = SE.getConstant(C);
9265 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9266 assert(isa<SCEVConstant>(Val) &&
9267 "Evaluation of SCEV at constant didn't fold correctly?");
9268 return cast<SCEVConstant>(Val)->getValue();
9269}
9270
9271ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9272 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9273 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9274 if (!RHS)
9275 return getCouldNotCompute();
9276
9277 const BasicBlock *Latch = L->getLoopLatch();
9278 if (!Latch)
9279 return getCouldNotCompute();
9280
9281 const BasicBlock *Predecessor = L->getLoopPredecessor();
9282 if (!Predecessor)
9283 return getCouldNotCompute();
9284
9285 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9286 // Return LHS in OutLHS and shift_op in OutOpCode.
9287 auto MatchPositiveShift =
9288 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9289
9290 using namespace PatternMatch;
9291
9292 ConstantInt *ShiftAmt;
9293 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9294 OutOpCode = Instruction::LShr;
9295 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9296 OutOpCode = Instruction::AShr;
9297 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9298 OutOpCode = Instruction::Shl;
9299 else
9300 return false;
9301
9302 return ShiftAmt->getValue().isStrictlyPositive();
9303 };
9304
9305 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9306 //
9307 // loop:
9308 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9309 // %iv.shifted = lshr i32 %iv, <positive constant>
9310 //
9311 // Return true on a successful match. Return the corresponding PHI node (%iv
9312 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9313 auto MatchShiftRecurrence =
9314 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9315 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9316
9317 {
9318 Instruction::BinaryOps OpC;
9319 Value *V;
9320
9321 // If we encounter a shift instruction, "peel off" the shift operation,
9322 // and remember that we did so. Later when we inspect %iv's backedge
9323 // value, we will make sure that the backedge value uses the same
9324 // operation.
9325 //
9326 // Note: the peeled shift operation does not have to be the same
9327 // instruction as the one feeding into the PHI's backedge value. We only
9328 // really care about it being the same *kind* of shift instruction --
9329 // that's all that is required for our later inferences to hold.
9330 if (MatchPositiveShift(LHS, V, OpC)) {
9331 PostShiftOpCode = OpC;
9332 LHS = V;
9333 }
9334 }
9335
9336 PNOut = dyn_cast<PHINode>(LHS);
9337 if (!PNOut || PNOut->getParent() != L->getHeader())
9338 return false;
9339
9340 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9341 Value *OpLHS;
9342
9343 return
9344 // The backedge value for the PHI node must be a shift by a positive
9345 // amount
9346 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9347
9348 // of the PHI node itself
9349 OpLHS == PNOut &&
9350
9351 // and the kind of shift should match the kind of shift we peeled
9352 // off, if any.
9353 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9354 };
9355
9356 PHINode *PN;
9357 Instruction::BinaryOps OpCode;
9358 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9359 return getCouldNotCompute();
9360
9361 const DataLayout &DL = getDataLayout();
9362
9363 // The key rationale for this optimization is that for some kinds of shift
9364 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9365 // within a finite number of iterations. If the condition guarding the
9366 // backedge (in the sense that the backedge is taken if the condition is true)
9367 // is false for the value the shift recurrence stabilizes to, then we know
9368 // that the backedge is taken only a finite number of times.
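  // For instance, {K,lshr,3} over i32 reaches 0 after at most ceil(32/3) = 11
  // steps, so if the backedge condition is false once the value is 0, the
  // backedge can be taken at most bitwidth(K) times.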
9369
9370 ConstantInt *StableValue = nullptr;
9371 switch (OpCode) {
9372 default:
9373 llvm_unreachable("Impossible case!");
9374
9375 case Instruction::AShr: {
9376 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9377 // bitwidth(K) iterations.
9378 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9379 KnownBits Known = computeKnownBits(FirstValue, DL, 0, &AC,
9380 Predecessor->getTerminator(), &DT);
9381 auto *Ty = cast<IntegerType>(RHS->getType());
9382 if (Known.isNonNegative())
9383 StableValue = ConstantInt::get(Ty, 0);
9384 else if (Known.isNegative())
9385 StableValue = ConstantInt::get(Ty, -1, true);
9386 else
9387 return getCouldNotCompute();
9388
9389 break;
9390 }
9391 case Instruction::LShr:
9392 case Instruction::Shl:
9393 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9394 // stabilize to 0 in at most bitwidth(K) iterations.
9395 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9396 break;
9397 }
9398
9399 auto *Result =
9400 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9401 assert(Result->getType()->isIntegerTy(1) &&
9402 "Otherwise cannot be an operand to a branch instruction");
9403
9404 if (Result->isZeroValue()) {
9405 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9406 const SCEV *UpperBound =
9407 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9408 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9409 }
9410
9411 return getCouldNotCompute();
9412}
9413
9414/// Return true if we can constant fold an instruction of the specified type,
9415/// assuming that all operands were constants.
9416static bool CanConstantFold(const Instruction *I) {
9417 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
9418 isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9419 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9420 return true;
9421
9422 if (const CallInst *CI = dyn_cast<CallInst>(I))
9423 if (const Function *F = CI->getCalledFunction())
9424 return canConstantFoldCallTo(CI, F);
9425 return false;
9426}
9427
9428/// Determine whether this instruction can constant evolve within this loop
9429/// assuming its operands can all constant evolve.
9430static bool canConstantEvolve(Instruction *I, const Loop *L) {
9431 // An instruction outside of the loop can't be derived from a loop PHI.
9432 if (!L->contains(I)) return false;
9433
9434 if (isa<PHINode>(I)) {
9435 // We don't currently keep track of the control flow needed to evaluate
9436 // PHIs, so we cannot handle PHIs inside of loops.
9437 return L->getHeader() == I->getParent();
9438 }
9439
9440 // If we won't be able to constant fold this expression even if the operands
9441 // are constants, bail early.
9442 return CanConstantFold(I);
9443}
9444
9445/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9446/// recursing through each instruction operand until reaching a loop header phi.
9447 static PHINode *
9448 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9449 DenseMap<Instruction *, PHINode *> &PHIMap,
9450 unsigned Depth) {
9451 if (Depth > MaxConstantEvolvingDepth)
9452 return nullptr;
9453
9454 // Otherwise, we can evaluate this instruction if all of its operands are
9455 // constant or derived from a PHI node themselves.
9456 PHINode *PHI = nullptr;
9457 for (Value *Op : UseInst->operands()) {
9458 if (isa<Constant>(Op)) continue;
9459
9460 Instruction *OpInst = dyn_cast<Instruction>(Op);
9461 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9462
9463 PHINode *P = dyn_cast<PHINode>(OpInst);
9464 if (!P)
9465 // If this operand is already visited, reuse the prior result.
9466 // We may have P != PHI if this is the deepest point at which the
9467 // inconsistent paths meet.
9468 P = PHIMap.lookup(OpInst);
9469 if (!P) {
9470 // Recurse and memoize the results, whether a phi is found or not.
9471 // This recursive call invalidates pointers into PHIMap.
9472 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9473 PHIMap[OpInst] = P;
9474 }
9475 if (!P)
9476 return nullptr; // Not evolving from PHI
9477 if (PHI && PHI != P)
9478 return nullptr; // Evolving from multiple different PHIs.
9479 PHI = P;
9480 }
9481 // This is an expression evolving from a constant PHI!
9482 return PHI;
9483}
9484
9485/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9486/// in the loop that V is derived from. We allow arbitrary operations along the
9487/// way, but the operands of an operation must either be constants or a value
9488/// derived from a constant PHI. If this expression does not fit with these
9489/// constraints, return null.
9490 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9491 Instruction *I = dyn_cast<Instruction>(V);
9492 if (!I || !canConstantEvolve(I, L)) return nullptr;
9493
9494 if (PHINode *PN = dyn_cast<PHINode>(I))
9495 return PN;
9496
9497 // Record non-constant instructions contained by the loop.
9498 DenseMap<Instruction *, PHINode *> PHIMap;
9499 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9500}
9501
9502/// EvaluateExpression - Given an expression that passes the
9503/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9504/// in the loop has the value PHIVal. If we can't fold this expression for some
9505/// reason, return null.
9506 static Constant *EvaluateExpression(Value *V, const Loop *L,
9507 DenseMap<Instruction *, Constant *> &Vals,
9508 const DataLayout &DL,
9509 const TargetLibraryInfo *TLI) {
9510 // Convenient constant check, but redundant for recursive calls.
9511 if (Constant *C = dyn_cast<Constant>(V)) return C;
9512 Instruction *I = dyn_cast<Instruction>(V);
9513 if (!I) return nullptr;
9514
9515 if (Constant *C = Vals.lookup(I)) return C;
9516
9517 // An instruction inside the loop depends on a value outside the loop that we
9518 // weren't given a mapping for, or a value such as a call inside the loop.
9519 if (!canConstantEvolve(I, L)) return nullptr;
9520
9521 // An unmapped PHI can be due to a branch or another loop inside this loop,
9522 // or due to this not being the initial iteration through a loop where we
9523 // couldn't compute the evolution of this particular PHI last time.
9524 if (isa<PHINode>(I)) return nullptr;
9525
9526 std::vector<Constant*> Operands(I->getNumOperands());
9527
9528 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9529 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9530 if (!Operand) {
9531 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9532 if (!Operands[i]) return nullptr;
9533 continue;
9534 }
9535 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9536 Vals[Operand] = C;
9537 if (!C) return nullptr;
9538 Operands[i] = C;
9539 }
9540
9541 return ConstantFoldInstOperands(I, Operands, DL, TLI);
9542}
9543
9544
9545// If every incoming value to PN except the one for BB is a specific Constant,
9546// return that, else return nullptr.
9547 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9548 Constant *IncomingVal = nullptr;
9549
9550 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9551 if (PN->getIncomingBlock(i) == BB)
9552 continue;
9553
9554 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9555 if (!CurrentVal)
9556 return nullptr;
9557
9558 if (IncomingVal != CurrentVal) {
9559 if (IncomingVal)
9560 return nullptr;
9561 IncomingVal = CurrentVal;
9562 }
9563 }
9564
9565 return IncomingVal;
9566}
9567
9568/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9569/// in the header of its containing loop, we know the loop executes a
9570/// constant number of times, and the PHI node is just a recurrence
9571/// involving constants, fold it.
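/// For instance (hypothetical IR), with %iv = phi i32 [ 0, %preheader ],
/// [ %iv.next, %latch ], %iv.next = add i32 %iv, 3, and a backedge-taken
/// count of 4, the symbolic execution below returns i32 12 for %iv.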
9572Constant *
9573ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9574 const APInt &BEs,
9575 const Loop *L) {
9576 auto I = ConstantEvolutionLoopExitValue.find(PN);
9577 if (I != ConstantEvolutionLoopExitValue.end())
9578 return I->second;
9579
9580 if (BEs.ugt(MaxBruteForceIterations))
9581 return ConstantEvolutionLoopExitValue[PN] = nullptr; // Not going to evaluate it.
9582
9583 Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
9584
9586 BasicBlock *Header = L->getHeader();
9587 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9588
9589 BasicBlock *Latch = L->getLoopLatch();
9590 if (!Latch)
9591 return nullptr;
9592
9593 for (PHINode &PHI : Header->phis()) {
9594 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9595 CurrentIterVals[&PHI] = StartCST;
9596 }
9597 if (!CurrentIterVals.count(PN))
9598 return RetVal = nullptr;
9599
9600 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9601
9602 // Execute the loop symbolically to determine the exit value.
9603 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9604 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9605
9606 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9607 unsigned IterationNum = 0;
9608 const DataLayout &DL = getDataLayout();
9609 for (; ; ++IterationNum) {
9610 if (IterationNum == NumIterations)
9611 return RetVal = CurrentIterVals[PN]; // Got exit value!
9612
9613 // Compute the value of the PHIs for the next iteration.
9614 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9615 DenseMap<Instruction *, Constant *> NextIterVals;
9616 Constant *NextPHI =
9617 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9618 if (!NextPHI)
9619 return nullptr; // Couldn't evaluate!
9620 NextIterVals[PN] = NextPHI;
9621
9622 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9623
9624 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9625 // cease to be able to evaluate one of them or if they stop evolving,
9626 // because that doesn't necessarily prevent us from computing PN.
9627 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9628 for (const auto &I : CurrentIterVals) {
9629 PHINode *PHI = dyn_cast<PHINode>(I.first);
9630 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9631 PHIsToCompute.emplace_back(PHI, I.second);
9632 }
9633 // We use two distinct loops because EvaluateExpression may invalidate any
9634 // iterators into CurrentIterVals.
9635 for (const auto &I : PHIsToCompute) {
9636 PHINode *PHI = I.first;
9637 Constant *&NextPHI = NextIterVals[PHI];
9638 if (!NextPHI) { // Not already computed.
9639 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9640 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9641 }
9642 if (NextPHI != I.second)
9643 StoppedEvolving = false;
9644 }
9645
9646 // If all entries in CurrentIterVals == NextIterVals then we can stop
9647 // iterating, the loop can't continue to change.
9648 if (StoppedEvolving)
9649 return RetVal = CurrentIterVals[PN];
9650
9651 CurrentIterVals.swap(NextIterVals);
9652 }
9653}
9654
9655const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9656 Value *Cond,
9657 bool ExitWhen) {
9658 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9659 if (!PN) return getCouldNotCompute();
9660
9661 // If the loop is canonicalized, the PHI will have exactly two entries.
9662 // That's the only form we support here.
9663 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9664
9665 DenseMap<Instruction *, Constant *> CurrentIterVals;
9666 BasicBlock *Header = L->getHeader();
9667 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9668
9669 BasicBlock *Latch = L->getLoopLatch();
9670 assert(Latch && "Should follow from NumIncomingValues == 2!");
9671
9672 for (PHINode &PHI : Header->phis()) {
9673 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9674 CurrentIterVals[&PHI] = StartCST;
9675 }
9676 if (!CurrentIterVals.count(PN))
9677 return getCouldNotCompute();
9678
9679 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9680 // the loop symbolically to determine when the condition gets a value of
9681 // "ExitWhen".
9682 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9683 const DataLayout &DL = getDataLayout();
9684 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9685 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9686 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9687
9688 // Couldn't symbolically evaluate.
9689 if (!CondVal) return getCouldNotCompute();
9690
9691 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9692 ++NumBruteForceTripCountsComputed;
9693 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9694 }
9695
9696 // Update all the PHI nodes for the next iteration.
9697 DenseMap<Instruction *, Constant *> NextIterVals;
9698
9699 // Create a list of which PHIs we need to compute. We want to do this before
9700 // calling EvaluateExpression on them because that may invalidate iterators
9701 // into CurrentIterVals.
9702 SmallVector<PHINode *, 8> PHIsToCompute;
9703 for (const auto &I : CurrentIterVals) {
9704 PHINode *PHI = dyn_cast<PHINode>(I.first);
9705 if (!PHI || PHI->getParent() != Header) continue;
9706 PHIsToCompute.push_back(PHI);
9707 }
9708 for (PHINode *PHI : PHIsToCompute) {
9709 Constant *&NextPHI = NextIterVals[PHI];
9710 if (NextPHI) continue; // Already computed!
9711
9712 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9713 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9714 }
9715 CurrentIterVals.swap(NextIterVals);
9716 }
9717
9718 // Too many iterations were needed to evaluate.
9719 return getCouldNotCompute();
9720}
9721
9722const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9723 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9724 ValuesAtScopes[V];
9725 // Check to see if we've folded this expression at this loop before.
9726 for (auto &LS : Values)
9727 if (LS.first == L)
9728 return LS.second ? LS.second : V;
9729
9730 Values.emplace_back(L, nullptr);
9731
9732 // Otherwise compute it.
9733 const SCEV *C = computeSCEVAtScope(V, L);
9734 for (auto &LS : reverse(ValuesAtScopes[V]))
9735 if (LS.first == L) {
9736 LS.second = C;
9737 if (!isa<SCEVConstant>(C))
9738 ValuesAtScopesUsers[C].push_back({L, V});
9739 break;
9740 }
9741 return C;
9742}
9743
9744/// This builds up a Constant using the ConstantExpr interface. That way, we
9745/// will return Constants for objects which aren't represented by a
9746/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9747/// Returns NULL if the SCEV isn't representable as a Constant.
9748 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9749 switch (V->getSCEVType()) {
9750 case scCouldNotCompute:
9751 case scAddRecExpr:
9752 case scVScale:
9753 return nullptr;
9754 case scConstant:
9755 return cast<SCEVConstant>(V)->getValue();
9756 case scUnknown:
9757 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9758 case scPtrToInt: {
9759 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9760 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9761 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9762
9763 return nullptr;
9764 }
9765 case scTruncate: {
9766 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9767 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9768 return ConstantExpr::getTrunc(CastOp, ST->getType());
9769 return nullptr;
9770 }
9771 case scAddExpr: {
9772 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9773 Constant *C = nullptr;
9774 for (const SCEV *Op : SA->operands()) {
9775 Constant *OpC = BuildConstantFromSCEV(Op);
9776 if (!OpC)
9777 return nullptr;
9778 if (!C) {
9779 C = OpC;
9780 continue;
9781 }
9782 assert(!C->getType()->isPointerTy() &&
9783 "Can only have one pointer, and it must be last");
9784 if (OpC->getType()->isPointerTy()) {
9785 // The offsets have been converted to bytes. We can add bytes using
9786 // an i8 GEP.
9787 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9788 OpC, C);
9789 } else {
9790 C = ConstantExpr::getAdd(C, OpC);
9791 }
9792 }
9793 return C;
9794 }
9795 case scMulExpr:
9796 case scSignExtend:
9797 case scZeroExtend:
9798 case scUDivExpr:
9799 case scSMaxExpr:
9800 case scUMaxExpr:
9801 case scSMinExpr:
9802 case scUMinExpr:
9803 case scSequentialUMinExpr:
9804 return nullptr;
9805 }
9806 llvm_unreachable("Unknown SCEV kind!");
9807}
9808
9809const SCEV *
9810ScalarEvolution::getWithOperands(const SCEV *S,
9811 SmallVectorImpl<const SCEV *> &NewOps) {
9812 switch (S->getSCEVType()) {
9813 case scTruncate:
9814 case scZeroExtend:
9815 case scSignExtend:
9816 case scPtrToInt:
9817 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
9818 case scAddRecExpr: {
9819 auto *AddRec = cast<SCEVAddRecExpr>(S);
9820 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
9821 }
9822 case scAddExpr:
9823 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
9824 case scMulExpr:
9825 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
9826 case scUDivExpr:
9827 return getUDivExpr(NewOps[0], NewOps[1]);
9828 case scUMaxExpr:
9829 case scSMaxExpr:
9830 case scUMinExpr:
9831 case scSMinExpr:
9832 return getMinMaxExpr(S->getSCEVType(), NewOps);
9833 case scSequentialUMinExpr:
9834 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
9835 case scConstant:
9836 case scVScale:
9837 case scUnknown:
9838 return S;
9839 case scCouldNotCompute:
9840 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
9841 }
9842 llvm_unreachable("Unknown SCEV kind!");
9843}
9844
9845const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
9846 switch (V->getSCEVType()) {
9847 case scConstant:
9848 case scVScale:
9849 return V;
9850 case scAddRecExpr: {
9851 // If this is a loop recurrence for a loop that does not contain L, then we
9852 // are dealing with the final value computed by the loop.
9853 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
9854 // First, attempt to evaluate each operand.
9855 // Avoid performing the look-up in the common case where the specified
9856 // expression has no loop-variant portions.
9857 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
9858 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
9859 if (OpAtScope == AddRec->getOperand(i))
9860 continue;
9861
9862 // Okay, at least one of these operands is loop variant but might be
9863 // foldable. Build a new instance of the folded commutative expression.
9864 SmallVector<const SCEV *, 8> NewOps;
9865 NewOps.reserve(AddRec->getNumOperands());
9866 append_range(NewOps, AddRec->operands().take_front(i));
9867 NewOps.push_back(OpAtScope);
9868 for (++i; i != e; ++i)
9869 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
9870
9871 const SCEV *FoldedRec = getAddRecExpr(
9872 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
9873 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
9874 // The addrec may be folded to a nonrecurrence, for example, if the
9875 // induction variable is multiplied by zero after constant folding. Go
9876 // ahead and return the folded value.
9877 if (!AddRec)
9878 return FoldedRec;
9879 break;
9880 }
9881
9882 // If the scope is outside the addrec's loop, evaluate it by using the
9883 // loop exit value of the addrec.
9884 if (!AddRec->getLoop()->contains(L)) {
9885 // To evaluate this recurrence, we need to know how many times the AddRec
9886 // loop iterates. Compute this now.
9887 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
9888 if (BackedgeTakenCount == getCouldNotCompute())
9889 return AddRec;
9890
9891 // Then, evaluate the AddRec.
9892 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
9893 }
9894
9895 return AddRec;
9896 }
9897 case scTruncate:
9898 case scZeroExtend:
9899 case scSignExtend:
9900 case scPtrToInt:
9901 case scAddExpr:
9902 case scMulExpr:
9903 case scUDivExpr:
9904 case scUMaxExpr:
9905 case scSMaxExpr:
9906 case scUMinExpr:
9907 case scSMinExpr:
9908 case scSequentialUMinExpr: {
9909 ArrayRef<const SCEV *> Ops = V->operands();
9910 // Avoid performing the look-up in the common case where the specified
9911 // expression has no loop-variant portions.
9912 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
9913 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
9914 if (OpAtScope != Ops[i]) {
9915 // Okay, at least one of these operands is loop variant but might be
9916 // foldable. Build a new instance of the folded commutative expression.
9917 SmallVector<const SCEV *, 8> NewOps;
9918 NewOps.reserve(Ops.size());
9919 append_range(NewOps, Ops.take_front(i));
9920 NewOps.push_back(OpAtScope);
9921
9922 for (++i; i != e; ++i) {
9923 OpAtScope = getSCEVAtScope(Ops[i], L);
9924 NewOps.push_back(OpAtScope);
9925 }
9926
9927 return getWithOperands(V, NewOps);
9928 }
9929 }
9930 // If we got here, all operands are loop invariant.
9931 return V;
9932 }
9933 case scUnknown: {
9934 // If this instruction is evolved from a constant-evolving PHI, compute the
9935 // exit value from the loop without using SCEVs.
9936 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
9937 Instruction *I = dyn_cast<Instruction>(SU->getValue());
9938 if (!I)
9939 return V; // This is some other type of SCEVUnknown, just return it.
9940
9941 if (PHINode *PN = dyn_cast<PHINode>(I)) {
9942 const Loop *CurrLoop = this->LI[I->getParent()];
9943 // Looking for loop exit value.
9944 if (CurrLoop && CurrLoop->getParentLoop() == L &&
9945 PN->getParent() == CurrLoop->getHeader()) {
9946 // Okay, there is no closed form solution for the PHI node. Check
9947 // to see if the loop that contains it has a known backedge-taken
9948 // count. If so, we may be able to force computation of the exit
9949 // value.
9950 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
9951 // This trivial case can show up in some degenerate cases where
9952 // the incoming IR has not yet been fully simplified.
9953 if (BackedgeTakenCount->isZero()) {
9954 Value *InitValue = nullptr;
9955 bool MultipleInitValues = false;
9956 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
9957 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
9958 if (!InitValue)
9959 InitValue = PN->getIncomingValue(i);
9960 else if (InitValue != PN->getIncomingValue(i)) {
9961 MultipleInitValues = true;
9962 break;
9963 }
9964 }
9965 }
9966 if (!MultipleInitValues && InitValue)
9967 return getSCEV(InitValue);
9968 }
9969 // Do we have a loop invariant value flowing around the backedge
9970 // for a loop which must execute the backedge?
9971 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
9972 isKnownNonZero(BackedgeTakenCount) &&
9973 PN->getNumIncomingValues() == 2) {
9974
9975 unsigned InLoopPred =
9976 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
9977 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
9978 if (CurrLoop->isLoopInvariant(BackedgeVal))
9979 return getSCEV(BackedgeVal);
9980 }
9981 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
9982 // Okay, we know how many times the containing loop executes. If
9983 // this is a constant evolving PHI node, get the final value at
9984 // the specified iteration number.
9985 Constant *RV =
9986 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
9987 if (RV)
9988 return getSCEV(RV);
9989 }
9990 }
9991 }
9992
9993 // Okay, this is an expression that we cannot symbolically evaluate
9994 // into a SCEV. Check to see if it's possible to symbolically evaluate
9995 // the arguments into constants, and if so, try to constant propagate the
9996 // result. This is particularly useful for computing loop exit values.
9997 if (!CanConstantFold(I))
9998 return V; // This is some other type of SCEVUnknown, just return it.
9999
10000 SmallVector<Constant *, 4> Operands;
10001 Operands.reserve(I->getNumOperands());
10002 bool MadeImprovement = false;
10003 for (Value *Op : I->operands()) {
10004 if (Constant *C = dyn_cast<Constant>(Op)) {
10005 Operands.push_back(C);
10006 continue;
10007 }
10008
10009 // If any of the operands is non-constant and if they are
10010 // non-integer and non-pointer, don't even try to analyze them
10011 // with scev techniques.
10012 if (!isSCEVable(Op->getType()))
10013 return V;
10014
10015 const SCEV *OrigV = getSCEV(Op);
10016 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10017 MadeImprovement |= OrigV != OpV;
10018
10019 Constant *C = BuildConstantFromSCEV(OpV);
10020 if (!C)
10021 return V;
10022 assert(C->getType() == Op->getType() && "Type mismatch");
10023 Operands.push_back(C);
10024 }
10025
10026 // Check to see if getSCEVAtScope actually made an improvement.
10027 if (!MadeImprovement)
10028 return V; // This is some other type of SCEVUnknown, just return it.
10029
10030 Constant *C = nullptr;
10031 const DataLayout &DL = getDataLayout();
10032 C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
10033 if (!C)
10034 return V;
10035 return getSCEV(C);
10036 }
10037 case scCouldNotCompute:
10038 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10039 }
10040 llvm_unreachable("Unknown SCEV type!");
10041}
10042
10043 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10044 return getSCEVAtScope(getSCEV(V), L);
10045}
10046
10047const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10048 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10049 return stripInjectiveFunctions(ZExt->getOperand());
10050 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10051 return stripInjectiveFunctions(SExt->getOperand());
10052 return S;
10053}
10054
10055/// Finds the minimum unsigned root of the following equation:
10056///
10057/// A * X = B (mod N)
10058///
10059/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10060/// A and B isn't important.
10061///
10062/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
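///
/// As an illustration of the procedure below (numbers chosen for the example,
/// BW = 4 so N = 16): solving 6 * X = 4 (mod 16) gives D = gcd(6, 16) = 2;
/// B = 4 is divisible by 2; the inverse of (6/2) = 3 modulo (16/2) = 8 is 3;
/// so X = (3 * 4 mod 16) / 2 = 6, and indeed 6 * 6 = 36 = 4 (mod 16). The
/// other root, 14, is larger, so 6 is the minimum unsigned root.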
10063static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10064 ScalarEvolution &SE) {
10065 uint32_t BW = A.getBitWidth();
10066 assert(BW == SE.getTypeSizeInBits(B->getType()));
10067 assert(A != 0 && "A must be non-zero.");
10068
10069 // 1. D = gcd(A, N)
10070 //
10071 // The gcd of A and N may have only one prime factor: 2. The number of
10072 // trailing zeros in A is its multiplicity
10073 uint32_t Mult2 = A.countr_zero();
10074 // D = 2^Mult2
10075
10076 // 2. Check if B is divisible by D.
10077 //
10078 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10079 // is not less than multiplicity of this prime factor for D.
10080 if (SE.getMinTrailingZeros(B) < Mult2)
10081 return SE.getCouldNotCompute();
10082
10083 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10084 // modulo (N / D).
10085 //
10086 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10087 // (N / D) in general. The inverse itself always fits into BW bits, though,
10088 // so we immediately truncate it.
10089 APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
10090 APInt Mod(BW + 1, 0);
10091 Mod.setBit(BW - Mult2); // Mod = N / D
10092 APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
10093
10094 // 4. Compute the minimum unsigned root of the equation:
10095 // I * (B / D) mod (N / D)
10096 // To simplify the computation, we factor out the divide by D:
10097 // (I * B mod N) / D
10098 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10099 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10100}
10101
10102/// For a given quadratic addrec, generate coefficients of the corresponding
10103/// quadratic equation, multiplied by a common value to ensure that they are
10104/// integers.
10105/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10106/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10107/// were multiplied by, and BitWidth is the bit width of the original addrec
10108/// coefficients.
10109/// This function returns std::nullopt if the addrec coefficients are not
10110 /// compile-time constants.
10111 static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10112 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10113 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10114 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10115 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10116 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10117 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10118 << *AddRec << '\n');
10119
10120 // We currently can only solve this if the coefficients are constants.
10121 if (!LC || !MC || !NC) {
10122 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10123 return std::nullopt;
10124 }
10125
10126 APInt L = LC->getAPInt();
10127 APInt M = MC->getAPInt();
10128 APInt N = NC->getAPInt();
10129 assert(!N.isZero() && "This is not a quadratic addrec");
10130
10131 unsigned BitWidth = LC->getAPInt().getBitWidth();
10132 unsigned NewWidth = BitWidth + 1;
10133 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10134 << BitWidth << '\n');
10135 // The sign-extension (as opposed to a zero-extension) here matches the
10136 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10137 N = N.sext(NewWidth);
10138 M = M.sext(NewWidth);
10139 L = L.sext(NewWidth);
10140
10141 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10142 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10143 // L+M, L+2M+N, L+3M+3N, ...
10144 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10145 //
10146 // The equation Acc = 0 is then
10147 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10148 // In a quadratic form it becomes:
10149 // N n^2 + (2M-N) n + 2L = 0.
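//
// Illustrative check (coefficients chosen for the example): for {1,+,2,+,2}
// the accumulated value after n iterations is 1 + 2n + n(n-1), and the
// equation built here is 2n^2 + 2n + 2 = 0, i.e. the original multiplied by
// T = 2, with A = 2, B = 2, C = 2.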
10150
10151 APInt A = N;
10152 APInt B = 2 * M - A;
10153 APInt C = 2 * L;
10154 APInt T = APInt(NewWidth, 2);
10155 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10156 << "x + " << C << ", coeff bw: " << NewWidth
10157 << ", multiplied by " << T << '\n');
10158 return std::make_tuple(A, B, C, T, BitWidth);
10159}
10160
10161/// Helper function to compare optional APInts:
10162/// (a) if X and Y both exist, return min(X, Y),
10163/// (b) if neither X nor Y exist, return std::nullopt,
10164/// (c) if exactly one of X and Y exists, return that value.
10165static std::optional<APInt> MinOptional(std::optional<APInt> X,
10166 std::optional<APInt> Y) {
10167 if (X && Y) {
10168 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10169 APInt XW = X->sext(W);
10170 APInt YW = Y->sext(W);
10171 return XW.slt(YW) ? *X : *Y;
10172 }
10173 if (!X && !Y)
10174 return std::nullopt;
10175 return X ? *X : *Y;
10176}
10177
10178/// Helper function to truncate an optional APInt to a given BitWidth.
10179/// When solving addrec-related equations, it is preferable to return a value
10180/// that has the same bit width as the original addrec's coefficients. If the
10181/// solution fits in the original bit width, truncate it (except for i1).
10182/// Returning a value of a different bit width may inhibit some optimizations.
10183///
10184/// In general, a solution to a quadratic equation generated from an addrec
10185/// may require BW+1 bits, where BW is the bit width of the addrec's
10186/// coefficients. The reason is that the coefficients of the quadratic
10187/// equation are BW+1 bits wide (to avoid truncation when converting from
10188/// the addrec to the equation).
10189static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10190 unsigned BitWidth) {
10191 if (!X)
10192 return std::nullopt;
10193 unsigned W = X->getBitWidth();
10194 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10195 return X->trunc(BitWidth);
10196 return X;
10197}
10198
10199/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10200/// iterations. The values L, M, N are assumed to be signed, and they
10201/// should all have the same bit widths.
10202/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10203/// where BW is the bit width of the addrec's coefficients.
10204/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10205/// returned as such, otherwise the bit width of the returned value may
10206/// be greater than BW.
10207///
10208/// This function returns std::nullopt if
10209/// (a) the addrec coefficients are not constant, or
10210/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10211/// like x^2 = 5, no integer solutions exist, in other cases an integer
10212/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
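///
/// For example (hypothetical chrec), {-6,+,2,+,2} takes the values -6, -4, 0,
/// ..., so the least n with c(n) = 0 is 2.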
10213static std::optional<APInt>
10214 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10215 APInt A, B, C, M;
10216 unsigned BitWidth;
10217 auto T = GetQuadraticEquation(AddRec);
10218 if (!T)
10219 return std::nullopt;
10220
10221 std::tie(A, B, C, M, BitWidth) = *T;
10222 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10223 std::optional<APInt> X =
10224 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10225 if (!X)
10226 return std::nullopt;
10227
10228 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10229 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10230 if (!V->isZero())
10231 return std::nullopt;
10232
10233 return TruncIfPossible(X, BitWidth);
10234}
10235
10236/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10237/// iterations. The values M, N are assumed to be signed, and they
10238/// should all have the same bit widths.
10239/// Find the least n such that c(n) does not belong to the given range,
10240/// while c(n-1) does.
10241///
10242/// This function returns std::nullopt if
10243/// (a) the addrec coefficients are not constant, or
10244/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10245/// bounds of the range.
10246static std::optional<APInt>
10247 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10248 const ConstantRange &Range, ScalarEvolution &SE) {
10249 assert(AddRec->getOperand(0)->isZero() &&
10250 "Starting value of addrec should be 0");
10251 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10252 << Range << ", addrec " << *AddRec << '\n');
10253 // This case is handled in getNumIterationsInRange. Here we can assume that
10254 // we start in the range.
10255 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10256 "Addrec's initial value should be in range");
10257
10258 APInt A, B, C, M;
10259 unsigned BitWidth;
10260 auto T = GetQuadraticEquation(AddRec);
10261 if (!T)
10262 return std::nullopt;
10263
10264 // Be careful about the return value: there can be two reasons for not
10265 // returning an actual number. First, if no solutions to the equations
10266 // were found, and second, if the solutions don't leave the given range.
10267 // The first case means that the actual solution is "unknown", the second
10268 // means that it's known, but not valid. If the solution is unknown, we
10269 // cannot make any conclusions.
10270 // Return a pair: the optional solution and a flag indicating if the
10271 // solution was found.
10272 auto SolveForBoundary =
10273 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10274 // Solve for signed overflow and unsigned overflow, pick the lower
10275 // solution.
10276 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10277 << Bound << " (before multiplying by " << M << ")\n");
10278 Bound *= M; // The quadratic equation multiplier.
10279
10280 std::optional<APInt> SO;
10281 if (BitWidth > 1) {
10282 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10283 "signed overflow\n");
10284 SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10285 }
10286 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10287 "unsigned overflow\n");
10288 std::optional<APInt> UO =
10289 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10290
10291 auto LeavesRange = [&] (const APInt &X) {
10292 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10293 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10294 if (Range.contains(V0->getValue()))
10295 return false;
10296 // X should be at least 1, so X-1 is non-negative.
10297 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10298 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10299 if (Range.contains(V1->getValue()))
10300 return true;
10301 return false;
10302 };
10303
10304 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10305 // can be a solution, but the function failed to find it. We cannot treat it
10306 // as "no solution".
10307 if (!SO || !UO)
10308 return {std::nullopt, false};
10309
10310 // Check the smaller value first to see if it leaves the range.
10311 // At this point, both SO and UO must have values.
10312 std::optional<APInt> Min = MinOptional(SO, UO);
10313 if (LeavesRange(*Min))
10314 return { Min, true };
10315 std::optional<APInt> Max = Min == SO ? UO : SO;
10316 if (LeavesRange(*Max))
10317 return { Max, true };
10318
10319 // Solutions were found, but were eliminated, hence the "true".
10320 return {std::nullopt, true};
10321 };
10322
10323 std::tie(A, B, C, M, BitWidth) = *T;
10324 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10325 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10326 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10327 auto SL = SolveForBoundary(Lower);
10328 auto SU = SolveForBoundary(Upper);
10329 // If any of the solutions was unknown, no meaningful conclusions can
10330 // be made.
10331 if (!SL.second || !SU.second)
10332 return std::nullopt;
10333
10334 // Claim: The correct solution is not some value between Min and Max.
10335 //
10336 // Justification: Assuming that Min and Max are different values, one of
10337 // them is when the first signed overflow happens, the other is when the
10338 // first unsigned overflow happens. Crossing the range boundary is only
10339 // possible via an overflow (treating 0 as a special case of it, modeling
10340 // an overflow as crossing k*2^W for some k).
10341 //
10342 // The interesting case here is when Min was eliminated as an invalid
10343 // solution, but Max was not. The argument is that if there was another
10344 // overflow between Min and Max, it would also have been eliminated if
10345 // it was considered.
10346 //
10347 // For a given boundary, it is possible to have two overflows of the same
10348 // type (signed/unsigned) without having the other type in between: this
10349 // can happen when the vertex of the parabola is between the iterations
10350 // corresponding to the overflows. This is only possible when the two
10351 // overflows cross k*2^W for the same k. In such case, if the second one
10352 // left the range (and was the first one to do so), the first overflow
10353 // would have to enter the range, which would mean that either we had left
10354 // the range before or that we started outside of it. Both of these cases
10355 // are contradictions.
10356 //
10357 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10358 // solution is not some value between the Max for this boundary and the
10359 // Min of the other boundary.
10360 //
10361 // Justification: Assume that we had such Max_A and Min_B corresponding
10362 // to range boundaries A and B and such that Max_A < Min_B. If there was
10363 // a solution between Max_A and Min_B, it would have to be caused by an
10364 // overflow corresponding to either A or B. It cannot correspond to B,
10365 // since Min_B is the first occurrence of such an overflow. If it
10366 // corresponded to A, it would have to be either a signed or an unsigned
10367 // overflow that is larger than both eliminated overflows for A. But
10368 // between the eliminated overflows and this overflow, the values would
10369 // cover the entire value space, thus crossing the other boundary, which
10370 // is a contradiction.
10371
10372 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10373}
10374
10375ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10376 const Loop *L,
10377 bool ControlsOnlyExit,
10378 bool AllowPredicates) {
10379
10380 // This is only used for loops with an "x != y" exit test. The exit condition
10381 // is now expressed as a single expression, V = x-y. So the exit test is
10382 // effectively V != 0. We know and take advantage of the fact that this
10383 // expression is only used in a comparison-with-zero context.
10384
10385 SmallVector<const SCEVPredicate *> Predicates;
10386 // If the value is a constant
10387 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10388 // If the value is already zero, the branch will execute zero times.
10389 if (C->getValue()->isZero()) return C;
10390 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10391 }
10392
10393 const SCEVAddRecExpr *AddRec =
10394 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10395
10396 if (!AddRec && AllowPredicates)
10397 // Try to make this an AddRec using runtime tests, in the first X
10398 // iterations of this loop, where X is the SCEV expression found by the
10399 // algorithm below.
10400 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10401
10402 if (!AddRec || AddRec->getLoop() != L)
10403 return getCouldNotCompute();
10404
10405 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10406 // the quadratic equation to solve it.
10407 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10408 // We can only use this value if the chrec ends up with an exact zero
10409 // value at this index. When solving for "X*X != 5", for example, we
10410 // should not accept a root of 2.
10411 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10412 const auto *R = cast<SCEVConstant>(getConstant(*S));
10413 return ExitLimit(R, R, R, false, Predicates);
10414 }
10415 return getCouldNotCompute();
10416 }
10417
10418 // Otherwise we can only handle this if it is affine.
10419 if (!AddRec->isAffine())
10420 return getCouldNotCompute();
10421
10422 // If this is an affine expression, the execution count of this branch is
10423 // the minimum unsigned root of the following equation:
10424 //
10425 // Start + Step*N = 0 (mod 2^BW)
10426 //
10427 // equivalent to:
10428 //
10429 // Step*N = -Start (mod 2^BW)
10430 //
10431 // where BW is the common bit width of Start and Step.
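//
// For instance (hypothetical affine addrec): for {10,+,-2} the equation
// -2 * N = -10 (mod 2^BW) has minimum unsigned root N = 5; the recurrence
// walks 10, 8, 6, 4, 2 and reaches 0 after the backedge is taken 5 times.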
10432
10433 // Get the initial value for the loop.
10434 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10435 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10436
10437 // For now we handle only constant steps.
10438 //
10439 // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
10440 // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
10441 // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
10442 // We have not yet seen any such cases.
10443 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10444 if (!StepC || StepC->getValue()->isZero())
10445 return getCouldNotCompute();
10446
10447 // For positive steps (counting up until unsigned overflow):
10448 // N = -Start/Step (as unsigned)
10449 // For negative steps (counting down to zero):
10450 // N = Start/-Step
10451 // First compute the unsigned distance from zero in the direction of Step.
10452 bool CountDown = StepC->getAPInt().isNegative();
10453 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10454
10455 // Handle unitary steps, which cannot wraparound.
10456 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10457 // N = Distance (as unsigned)
10458 if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
10459 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
10460 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10461
10462 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10463 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10464 // case, and see if we can improve the bound.
10465 //
10466 // Explicitly handling this here is necessary because getUnsignedRange
10467 // isn't context-sensitive; it doesn't know that we only care about the
10468 // range inside the loop.
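//
// E.g. (hypothetical rotated loop): if Distance is n - 1 and the entry guard
// proves n != 0, i.e. Distance + 1 != 0, the maximum backedge-taken count can
// be tightened to unsigned_max(n) - 1 instead of UINT_MAX.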
10469 const SCEV *Zero = getZero(Distance->getType());
10470 const SCEV *One = getOne(Distance->getType());
10471 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10472 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10473 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10474 // as "unsigned_max(Distance + 1) - 1".
10475 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10476 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10477 }
10478 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10479 Predicates);
10480 }
10481
10482 // If the condition controls loop exit (the loop exits only if the expression
10483 // is true) and the addition is no-wrap we can use unsigned divide to
10484 // compute the backedge count. In this case, the step may not divide the
10485 // distance, but we don't care because if the condition is "missed" the loop
10486 // will have undefined behavior due to wrapping.
10487 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10488 loopHasNoAbnormalExits(AddRec->getLoop())) {
10489 const SCEV *Exact =
10490 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10491 const SCEV *ConstantMax = getCouldNotCompute();
10492 if (Exact != getCouldNotCompute()) {
10493 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
10494 ConstantMax =
10495 getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10496 }
10497 const SCEV *SymbolicMax =
10498 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10499 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10500 }
10501
10502 // Solve the general equation.
10503 const SCEV *E = SolveLinEquationWithOverflow(StepC->getAPInt(),
10504 getNegativeSCEV(Start), *this);
10505
10506 const SCEV *M = E;
10507 if (E != getCouldNotCompute()) {
10508 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
10509 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10510 }
10511 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10512 return ExitLimit(E, M, S, false, Predicates);
10513}
10514
10515 ScalarEvolution::ExitLimit
10516 ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10517 // Loops that look like: while (X == 0) are very strange indeed. We don't
10518 // handle them yet except for the trivial case. This could be expanded in the
10519 // future as needed.
10520
10521 // If the value is a constant, check to see if it is known to be non-zero
10522 // already. If so, the backedge will execute zero times.
10523 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10524 if (!C->getValue()->isZero())
10525 return getZero(C->getType());
10526 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10527 }
10528
10529 // We could implement others, but I really doubt anyone writes loops like
10530 // this, and if they did, they would already be constant folded.
10531 return getCouldNotCompute();
10532}
10533
10534std::pair<const BasicBlock *, const BasicBlock *>
10535ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10536 const {
10537 // If the block has a unique predecessor, then there is no path from the
10538 // predecessor to the block that does not go through the direct edge
10539 // from the predecessor to the block.
10540 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10541 return {Pred, BB};
10542
10543 // A loop's header is defined to be a block that dominates the loop.
10544 // If the header has a unique predecessor outside the loop, it must be
10545 // a block that has exactly one successor that can reach the loop.
10546 if (const Loop *L = LI.getLoopFor(BB))
10547 return {L->getLoopPredecessor(), L->getHeader()};
10548
10549 return {nullptr, nullptr};
10550}
10551
10552/// SCEV structural equivalence is usually sufficient for testing whether two
10553/// expressions are equal, however for the purposes of looking for a condition
10554/// guarding a loop, it can be useful to be a little more general, since a
10555/// front-end may have replicated the controlling expression.
10556static bool HasSameValue(const SCEV *A, const SCEV *B) {
10557 // Quick check to see if they are the same SCEV.
10558 if (A == B) return true;
10559
10560 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10561 // Not all instructions that are "identical" compute the same value. For
10562 // instance, two distinct alloca instructions allocating the same type are
10563 // identical and do not read memory; but compute distinct values.
10564 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10565 };
10566
10567 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10568 // two different instructions with the same value. Check for this case.
10569 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10570 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10571 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10572 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10573 if (ComputesEqualValues(AI, BI))
10574 return true;
10575
10576 // Otherwise assume they may have a different value.
10577 return false;
10578}
10579
10580static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10581 const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
10582 if (!Add || Add->getNumOperands() != 2)
10583 return false;
10584 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
10585 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10586 LHS = Add->getOperand(1);
10587 RHS = ME->getOperand(1);
10588 return true;
10589 }
10590 if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
10591 ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
10592 LHS = Add->getOperand(0);
10593 RHS = ME->getOperand(1);
10594 return true;
10595 }
10596 return false;
10597}
10598
10599 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
10600 const SCEV *&LHS, const SCEV *&RHS,
10601 unsigned Depth) {
10602 bool Changed = false;
10603 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10604 // '0 != 0'.
10605 auto TrivialCase = [&](bool TriviallyTrue) {
10606 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10607 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10608 return true;
10609 };
10610 // If we hit the max recursion limit bail out.
10611 if (Depth >= 3)
10612 return false;
10613
10614 // Canonicalize a constant to the right side.
10615 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10616 // Check for both operands constant.
10617 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10618 if (ConstantExpr::getICmp(Pred,
10619 LHSC->getValue(),
10620 RHSC->getValue())->isNullValue())
10621 return TrivialCase(false);
10622 return TrivialCase(true);
10623 }
10624 // Otherwise swap the operands to put the constant on the right.
10625 std::swap(LHS, RHS);
10626 Pred = ICmpInst::getSwappedPredicate(Pred);
10627 Changed = true;
10628 }
10629
10630 // If we're comparing an addrec with a value which is loop-invariant in the
10631 // addrec's loop, put the addrec on the left. Also make a dominance check,
10632 // as both operands could be addrecs loop-invariant in each other's loop.
10633 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10634 const Loop *L = AR->getLoop();
10635 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10636 std::swap(LHS, RHS);
10637 Pred = ICmpInst::getSwappedPredicate(Pred);
10638 Changed = true;
10639 }
10640 }
10641
10642 // If there's a constant operand, canonicalize comparisons with boundary
10643 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
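// For example (values here are illustrative), "x uge 5" becomes "x ugt 4" and
// "x sle 7" becomes "x slt 8" below.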
10644 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10645 const APInt &RA = RC->getAPInt();
10646
10647 bool SimplifiedByConstantRange = false;
10648
10649 if (!ICmpInst::isEquality(Pred)) {
10650 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10651 if (ExactCR.isFullSet())
10652 return TrivialCase(true);
10653 if (ExactCR.isEmptySet())
10654 return TrivialCase(false);
10655
10656 APInt NewRHS;
10657 CmpInst::Predicate NewPred;
10658 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10659 ICmpInst::isEquality(NewPred)) {
10660 // We were able to convert an inequality to an equality.
10661 Pred = NewPred;
10662 RHS = getConstant(NewRHS);
10663 Changed = SimplifiedByConstantRange = true;
10664 }
10665 }
10666
10667 if (!SimplifiedByConstantRange) {
10668 switch (Pred) {
10669 default:
10670 break;
10671 case ICmpInst::ICMP_EQ:
10672 case ICmpInst::ICMP_NE:
10673 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10674 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10675 Changed = true;
10676 break;
10677
10678 // The "Should have been caught earlier!" messages refer to the fact
10679 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10680 // should have fired on the corresponding cases, and canonicalized the
10681 // check to trivial case.
10682
10683 case ICmpInst::ICMP_UGE:
10684 assert(!RA.isMinValue() && "Should have been caught earlier!");
10685 Pred = ICmpInst::ICMP_UGT;
10686 RHS = getConstant(RA - 1);
10687 Changed = true;
10688 break;
10689 case ICmpInst::ICMP_ULE:
10690 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10691 Pred = ICmpInst::ICMP_ULT;
10692 RHS = getConstant(RA + 1);
10693 Changed = true;
10694 break;
10695 case ICmpInst::ICMP_SGE:
10696 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10697 Pred = ICmpInst::ICMP_SGT;
10698 RHS = getConstant(RA - 1);
10699 Changed = true;
10700 break;
10701 case ICmpInst::ICMP_SLE:
10702 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10703 Pred = ICmpInst::ICMP_SLT;
10704 RHS = getConstant(RA + 1);
10705 Changed = true;
10706 break;
10707 }
10708 }
10709 }
10710
10711 // Check for obvious equality.
10712 if (HasSameValue(LHS, RHS)) {
10713 if (ICmpInst::isTrueWhenEqual(Pred))
10714 return TrivialCase(true);
10715 if (ICmpInst::isFalseWhenEqual(Pred))
10716 return TrivialCase(false);
10717 }
10718
10719 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10720 // adding or subtracting 1 from one of the operands.
10721 switch (Pred) {
10722 case ICmpInst::ICMP_SLE:
10723 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10724 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10725 SCEV::FlagNSW);
10726 Pred = ICmpInst::ICMP_SLT;
10727 Changed = true;
10728 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10729 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10730 SCEV::FlagNSW);
10731 Pred = ICmpInst::ICMP_SLT;
10732 Changed = true;
10733 }
10734 break;
10735 case ICmpInst::ICMP_SGE:
10736 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10737 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10738 SCEV::FlagNSW);
10739 Pred = ICmpInst::ICMP_SGT;
10740 Changed = true;
10741 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10742 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10743 SCEV::FlagNSW);
10744 Pred = ICmpInst::ICMP_SGT;
10745 Changed = true;
10746 }
10747 break;
10748 case ICmpInst::ICMP_ULE:
10749 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10750 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10751 SCEV::FlagNUW);
10752 Pred = ICmpInst::ICMP_ULT;
10753 Changed = true;
10754 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10755 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
10756 Pred = ICmpInst::ICMP_ULT;
10757 Changed = true;
10758 }
10759 break;
10760 case ICmpInst::ICMP_UGE:
10761 if (!getUnsignedRangeMin(RHS).isMinValue()) {
10762 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
10763 Pred = ICmpInst::ICMP_UGT;
10764 Changed = true;
10765 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
10766 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10767 SCEV::FlagNUW);
10768 Pred = ICmpInst::ICMP_UGT;
10769 Changed = true;
10770 }
10771 break;
10772 default:
10773 break;
10774 }
10775
10776 // TODO: More simplifications are possible here.
10777
10778 // Recursively simplify until we either hit a recursion limit or nothing
10779 // changes.
10780 if (Changed)
10781 return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
10782
10783 return Changed;
10784}
10785
10786 bool ScalarEvolution::isKnownNegative(const SCEV *S) {
10787 return getSignedRangeMax(S).isNegative();
10788}
10789
10790 bool ScalarEvolution::isKnownPositive(const SCEV *S) {
10791 return getSignedRangeMin(S).isStrictlyPositive();
10792}
10793
10794 bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
10795 return !getSignedRangeMin(S).isNegative();
10796}
10797
10798 bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
10799 return !getSignedRangeMax(S).isStrictlyPositive();
10800}
10801
10802 bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
10803 // Query push down for cases where the unsigned range is
10804 // less than sufficient.
10805 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10806 return isKnownNonZero(SExt->getOperand(0));
10807 return getUnsignedRangeMin(S) != 0;
10808}
10809
10810std::pair<const SCEV *, const SCEV *>
10811 ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
10812 // Compute SCEV on entry of loop L.
10813 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
10814 if (Start == getCouldNotCompute())
10815 return { Start, Start };
10816 // Compute post increment SCEV for loop L.
10817 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
10818 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
10819 return { Start, PostInc };
10820}
10821
10822 bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
10823 const SCEV *LHS, const SCEV *RHS) {
10824 // First collect all loops.
10825 SmallPtrSet<const Loop *, 8> LoopsUsed;
10826 getUsedLoops(LHS, LoopsUsed);
10827 getUsedLoops(RHS, LoopsUsed);
10828
10829 if (LoopsUsed.empty())
10830 return false;
10831
10832 // Domination relationship must be a linear order on collected loops.
10833#ifndef NDEBUG
10834 for (const auto *L1 : LoopsUsed)
10835 for (const auto *L2 : LoopsUsed)
10836 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
10837 DT.dominates(L2->getHeader(), L1->getHeader())) &&
10838 "Domination relationship is not a linear order");
10839#endif
10840
10841 const Loop *MDL =
10842 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
10843 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
10844 });
10845
10846 // Get init and post increment value for LHS.
10847 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
10848 // if LHS contains unknown non-invariant SCEV then bail out.
10849 if (SplitLHS.first == getCouldNotCompute())
10850 return false;
10851 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
10852 // Get init and post increment value for RHS.
10853 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
10854 // if RHS contains unknown non-invariant SCEV then bail out.
10855 if (SplitRHS.first == getCouldNotCompute())
10856 return false;
10857 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
10858 // It is possible that init SCEV contains an invariant load but it does
10859 // not dominate MDL and is not available at MDL loop entry, so we should
10860 // check it here.
10861 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
10862 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
10863 return false;
10864
10865 // The backedge guard check seems to be faster than the entry one, so in some
10866 // cases checking it first can short-circuit and speed up the whole estimation.
10867 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
10868 SplitRHS.second) &&
10869 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
10870}
10871
10872 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
10873 const SCEV *LHS, const SCEV *RHS) {
10874 // Canonicalize the inputs first.
10875 (void)SimplifyICmpOperands(Pred, LHS, RHS);
10876
10877 if (isKnownViaInduction(Pred, LHS, RHS))
10878 return true;
10879
10880 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
10881 return true;
10882
10883 // Otherwise see what can be done with some simple reasoning.
10884 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
10885}
10886
10887 std::optional<bool> ScalarEvolution::evaluatePredicate(ICmpInst::Predicate Pred,
10888 const SCEV *LHS,
10889 const SCEV *RHS) {
10890 if (isKnownPredicate(Pred, LHS, RHS))
10891 return true;
10892 if (isKnownPredicate(ICmpInst::getInversePredicate(Pred), LHS, RHS))
10893 return false;
10894 return std::nullopt;
10895}
10896
10897bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
10898 const SCEV *LHS, const SCEV *RHS,
10899 const Instruction *CtxI) {
10900 // TODO: Analyze guards and assumes from Context's block.
10901 return isKnownPredicate(Pred, LHS, RHS) ||
10902 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
10903}
10904
10905std::optional<bool>
10906ScalarEvolution::evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS,
10907 const SCEV *RHS, const Instruction *CtxI) {
10908 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
10909 if (KnownWithoutContext)
10910 return KnownWithoutContext;
10911
10912 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
10913 return true;
10914 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(),
10915 ICmpInst::getInversePredicate(Pred),
10916 LHS, RHS))
10917 return false;
10918 return std::nullopt;
10919}
10920
10921bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
10922 const SCEVAddRecExpr *LHS,
10923 const SCEV *RHS) {
10924 const Loop *L = LHS->getLoop();
10925 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
10926 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
10927}
10928
10929std::optional<ScalarEvolution::MonotonicPredicateType>
10930ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
10931 ICmpInst::Predicate Pred) {
10932 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
10933
10934#ifndef NDEBUG
10935 // Verify an invariant: inverting the predicate should turn a monotonically
10936 // increasing change to a monotonically decreasing one, and vice versa.
10937 if (Result) {
10938 auto ResultSwapped =
10939 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
10940
10941 assert(*ResultSwapped != *Result &&
10942 "monotonicity should flip as we flip the predicate");
10943 }
10944#endif
10945
10946 return Result;
10947}
10948
10949std::optional<ScalarEvolution::MonotonicPredicateType>
10950ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
10951 ICmpInst::Predicate Pred) {
10952 // A zero step value for LHS means the induction variable is essentially a
10953 // loop invariant value. We don't really depend on the predicate actually
10954 // flipping from false to true (for increasing predicates, and the other way
10955 // around for decreasing predicates), all we care about is that *if* the
10956 // predicate changes then it only changes from false to true.
10957 //
10958 // A zero step value in itself is not very useful, but there may be places
10959 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
10960 // as general as possible.
10961
10962 // Only handle LE/LT/GE/GT predicates.
10963 if (!ICmpInst::isRelational(Pred))
10964 return std::nullopt;
10965
10966 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
10967 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
10968 "Should be greater or less!");
10969
10970 // Check that AR does not wrap.
10971 if (ICmpInst::isUnsigned(Pred)) {
10972 if (!LHS->hasNoUnsignedWrap())
10973 return std::nullopt;
10974 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10975 }
10976 assert(ICmpInst::isSigned(Pred) &&
10977 "Relational predicate is either signed or unsigned!");
10978 if (!LHS->hasNoSignedWrap())
10979 return std::nullopt;
10980
10981 const SCEV *Step = LHS->getStepRecurrence(*this);
10982
10983 if (isKnownNonNegative(Step))
10984 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10985
10986 if (isKnownNonPositive(Step))
10987 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
10988
10989 return std::nullopt;
10990}
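// Example (a sketch, not in the upstream source): for LHS = {0,+,1}<nuw>, the
// predicate "LHS u< N" can only flip from true to false as the loop iterates,
// so the result is MonotonicallyDecreasing for ICMP_ULT and
// MonotonicallyIncreasing for ICMP_UGE.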
10991
10992std::optional<ScalarEvolution::LoopInvariantPredicate>
10993ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
10994 const SCEV *LHS, const SCEV *RHS,
10995 const Loop *L,
10996 const Instruction *CtxI) {
10997 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
10998 if (!isLoopInvariant(RHS, L)) {
10999 if (!isLoopInvariant(LHS, L))
11000 return std::nullopt;
11001
11002 std::swap(LHS, RHS);
11003 Pred = ICmpInst::getSwappedPredicate(Pred);
11004 }
11005
11006 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11007 if (!ArLHS || ArLHS->getLoop() != L)
11008 return std::nullopt;
11009
11010 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11011 if (!MonotonicType)
11012 return std::nullopt;
11013 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11014 // true as the loop iterates, and the backedge is control dependent on
11015 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11016 //
11017 // * if the predicate was false in the first iteration then the predicate
11018 // is never evaluated again, since the loop exits without taking the
11019 // backedge.
11020 // * if the predicate was true in the first iteration then it will
11021 // continue to be true for all future iterations since it is
11022 // monotonically increasing.
11023 //
11024 // For both the above possibilities, we can replace the loop varying
11025 // predicate with its value on the first iteration of the loop (which is
11026 // loop invariant).
11027 //
11028 // A similar reasoning applies for a monotonically decreasing predicate, by
11029 // replacing true with false and false with true in the above two bullets.
11030 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11031 auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
11032
11033 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11034 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11035 RHS);
11036
11037 if (!CtxI)
11038 return std::nullopt;
11039 // Try to prove via context.
11040 // TODO: Support other cases.
11041 switch (Pred) {
11042 default:
11043 break;
11044 case ICmpInst::ICMP_ULE:
11045 case ICmpInst::ICMP_ULT: {
11046 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11047 // Given preconditions
11048 // (1) ArLHS does not cross the border of positive and negative parts of
11049 // range because of:
11050 // - Positive step; (TODO: lift this limitation)
11051 // - nuw - does not cross zero boundary;
11052 // - nsw - does not cross SINT_MAX boundary;
11053 // (2) ArLHS <s RHS
11054 // (3) RHS >=s 0
11055 // we can replace the loop variant ArLHS <u RHS condition with loop
11056 // invariant Start(ArLHS) <u RHS.
11057 //
11058 // Because of (1) there are two options:
11059 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11060 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11061 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11062 // Because of (2) ArLHS <u RHS is trivially true.
11063 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11064 // We can strengthen this to Start(ArLHS) <u RHS.
11065 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11066 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11067 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11068 isKnownNonNegative(RHS) &&
11069 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11070 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11071 RHS);
11072 }
11073 }
11074
11075 return std::nullopt;
11076}
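// Usage note (a sketch, not in the upstream source): when a
// LoopInvariantPredicate {P, LHS', RHS'} is returned, the loop-varying check
// "ArLHS Pred RHS" may be evaluated once outside the loop as "LHS' P RHS'",
// with LHS' being the start value of the recurrence and RHS' the
// loop-invariant right-hand side.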
11077
11078std::optional<ScalarEvolution::LoopInvariantPredicate>
11079ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11080 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11081 const Instruction *CtxI, const SCEV *MaxIter) {
11082 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11083 Pred, LHS, RHS, L, CtxI, MaxIter))
11084 return LIP;
11085 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11086 // Number of iterations expressed as UMIN isn't always great for expressing
11087 // the value on the last iteration. If the straightforward approach didn't
11088 // work, try the following trick: if a predicate is invariant for X, it
11089 // is also invariant for umin(X, ...). So try to find something that works
11090 // among subexpressions of MaxIter expressed as umin.
11091 for (auto *Op : UMin->operands())
11092 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11093 Pred, LHS, RHS, L, CtxI, Op))
11094 return LIP;
11095 return std::nullopt;
11096}
11097
11098std::optional<ScalarEvolution::LoopInvariantPredicate>
11099ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11100 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11101 const Instruction *CtxI, const SCEV *MaxIter) {
11102 // Try to prove the following set of facts:
11103 // - The predicate is monotonic in the iteration space.
11104 // - If the check does not fail on the 1st iteration:
11105 // - No overflow will happen during first MaxIter iterations;
11106 // - It will not fail on the MaxIter'th iteration.
11107 // If the check does fail on the 1st iteration, we leave the loop and no
11108 // other checks matter.
11109
11110 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11111 if (!isLoopInvariant(RHS, L)) {
11112 if (!isLoopInvariant(LHS, L))
11113 return std::nullopt;
11114
11115 std::swap(LHS, RHS);
11116 Pred = ICmpInst::getSwappedPredicate(Pred);
11117 }
11118
11119 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11120 if (!AR || AR->getLoop() != L)
11121 return std::nullopt;
11122
11123 // The predicate must be relational (i.e. <, <=, >=, >).
11124 if (!ICmpInst::isRelational(Pred))
11125 return std::nullopt;
11126
11127 // TODO: Support steps other than +/- 1.
11128 const SCEV *Step = AR->getStepRecurrence(*this);
11129 auto *One = getOne(Step->getType());
11130 auto *MinusOne = getNegativeSCEV(One);
11131 if (Step != One && Step != MinusOne)
11132 return std::nullopt;
11133
11134 // Type mismatch here means that MaxIter is potentially larger than the max
11135 // unsigned value in the start type, which means we cannot prove no-wrap for
11136 // the indvar.
11137 if (AR->getType() != MaxIter->getType())
11138 return std::nullopt;
11139
11140 // Value of IV on suggested last iteration.
11141 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11142 // Does it still meet the requirement?
11143 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11144 return std::nullopt;
11145 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11146 // not exceed max unsigned value of this type), this effectively proves
11147 // that there is no wrap during the iteration. To prove that there is no
11148 // signed/unsigned wrap, we need to check that
11149 // Start <= Last for step = 1 or Start >= Last for step = -1.
11150 ICmpInst::Predicate NoOverflowPred =
11151 CmpInst::isSigned(Pred) ? CmpInst::ICMP_SLE : CmpInst::ICMP_ULE;
11152 if (Step == MinusOne)
11153 NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
11154 const SCEV *Start = AR->getStart();
11155 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11156 return std::nullopt;
11157
11158 // Everything is fine.
11159 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11160}
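// Worked sketch (not in the upstream source; %start, %n, %len are
// placeholders): for AR = {%start,+,1}<%L>, MaxIter = %n of the same type, and
// the exit test "AR u< %len", Last = %start + %n is the value of AR on the
// suggested last iteration. If the backedge is guarded by "Last u< %len" and
// "%start u<= Last" is known at the context, the loop-varying test can be
// replaced by the invariant "%start u< %len".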
11161
11162bool ScalarEvolution::isKnownPredicateViaConstantRanges(
11163 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) {
11164 if (HasSameValue(LHS, RHS))
11165 return ICmpInst::isTrueWhenEqual(Pred);
11166
11167 // This code is split out from isKnownPredicate because it is called from
11168 // within isLoopEntryGuardedByCond.
11169
11170 auto CheckRanges = [&](const ConstantRange &RangeLHS,
11171 const ConstantRange &RangeRHS) {
11172 return RangeLHS.icmp(Pred, RangeRHS);
11173 };
11174
11175 // The check at the top of the function catches the case where the values are
11176 // known to be equal.
11177 if (Pred == CmpInst::ICMP_EQ)
11178 return false;
11179
11180 if (Pred == CmpInst::ICMP_NE) {
11181 auto SL = getSignedRange(LHS);
11182 auto SR = getSignedRange(RHS);
11183 if (CheckRanges(SL, SR))
11184 return true;
11185 auto UL = getUnsignedRange(LHS);
11186 auto UR = getUnsignedRange(RHS);
11187 if (CheckRanges(UL, UR))
11188 return true;
11189 auto *Diff = getMinusSCEV(LHS, RHS);
11190 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11191 }
11192
11193 if (CmpInst::isSigned(Pred)) {
11194 auto SL = getSignedRange(LHS);
11195 auto SR = getSignedRange(RHS);
11196 return CheckRanges(SL, SR);
11197 }
11198
11199 auto UL = getUnsignedRange(LHS);
11200 auto UR = getUnsignedRange(RHS);
11201 return CheckRanges(UL, UR);
11202}
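// Example (a sketch, not in the upstream source): if the unsigned range of LHS
// is [0, 10) and the unsigned range of RHS is [10, 20), every LHS value is
// u< every RHS value, so CheckRanges proves ICMP_ULT. ICMP_EQ is never proven
// here; equal values are caught by the HasSameValue check at the top.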
11203
11204bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
11205 const SCEV *LHS,
11206 const SCEV *RHS) {
11207 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11208 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11209 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11210 // OutC1 and OutC2.
11211 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11212 APInt &OutC1, APInt &OutC2,
11213 SCEV::NoWrapFlags ExpectedFlags) {
11214 const SCEV *XNonConstOp, *XConstOp;
11215 const SCEV *YNonConstOp, *YConstOp;
11216 SCEV::NoWrapFlags XFlagsPresent;
11217 SCEV::NoWrapFlags YFlagsPresent;
11218
11219 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11220 XConstOp = getZero(X->getType());
11221 XNonConstOp = X;
11222 XFlagsPresent = ExpectedFlags;
11223 }
11224 if (!isa<SCEVConstant>(XConstOp) ||
11225 (XFlagsPresent & ExpectedFlags) != ExpectedFlags)
11226 return false;
11227
11228 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11229 YConstOp = getZero(Y->getType());
11230 YNonConstOp = Y;
11231 YFlagsPresent = ExpectedFlags;
11232 }
11233
11234 if (!isa<SCEVConstant>(YConstOp) ||
11235 (YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11236 return false;
11237
11238 if (YNonConstOp != XNonConstOp)
11239 return false;
11240
11241 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11242 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11243
11244 return true;
11245 };
11246
11247 APInt C1;
11248 APInt C2;
11249
11250 switch (Pred) {
11251 default:
11252 break;
11253
11254 case ICmpInst::ICMP_SGE:
11255 std::swap(LHS, RHS);
11256 [[fallthrough]];
11257 case ICmpInst::ICMP_SLE:
11258 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11259 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11260 return true;
11261
11262 break;
11263
11264 case ICmpInst::ICMP_SGT:
11265 std::swap(LHS, RHS);
11266 [[fallthrough]];
11267 case ICmpInst::ICMP_SLT:
11268 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11269 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11270 return true;
11271
11272 break;
11273
11274 case ICmpInst::ICMP_UGE:
11275 std::swap(LHS, RHS);
11276 [[fallthrough]];
11277 case ICmpInst::ICMP_ULE:
11278 // (X + C1)<nuw> u<= (X + C2)<nuw> for C1 u<= C2.
11279 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ule(C2))
11280 return true;
11281
11282 break;
11283
11284 case ICmpInst::ICMP_UGT:
11285 std::swap(LHS, RHS);
11286 [[fallthrough]];
11287 case ICmpInst::ICMP_ULT:
11288 // (X + C1)<nuw> u< (X + C2)<nuw> if C1 u< C2.
11289 if (MatchBinaryAddToConst(RHS, LHS, C2, C1, SCEV::FlagNUW) && C1.ult(C2))
11290 return true;
11291 break;
11292 }
11293
11294 return false;
11295}
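// Example (a sketch, not in the upstream source): for any SCEV X,
//   (X + 1)<nsw> s< (X + 3)<nsw>
// is known because both sides match "same operand plus a constant" with the
// required nsw flag and C1 = 1 s< C2 = 3; a bare X is treated as X + 0.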
11296
11297bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
11298 const SCEV *LHS,
11299 const SCEV *RHS) {
11300 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11301 return false;
11302
11303 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11304 // the stack can result in exponential time complexity.
11305 SaveAndRestore Restore(ProvingSplitPredicate, true);
11306
11307 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11308 //
11309 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11310 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11311 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11312 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11313 // use isKnownPredicate later if needed.
11314 return isKnownNonNegative(RHS) &&
11315 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11316 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11317}
11318
11319bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
11320 ICmpInst::Predicate Pred,
11321 const SCEV *LHS, const SCEV *RHS) {
11322 // No need to even try if we know the module has no guards.
11323 if (!HasGuards)
11324 return false;
11325
11326 return any_of(*BB, [&](const Instruction &I) {
11327 using namespace llvm::PatternMatch;
11328
11329 Value *Condition;
11330 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11331 m_Value(Condition))) &&
11332 isImpliedCond(Pred, LHS, RHS, Condition, false);
11333 });
11334}
11335
11336/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11337/// protected by a conditional between LHS and RHS. This is used to
11338/// eliminate casts.
11339bool
11340ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11341 ICmpInst::Predicate Pred,
11342 const SCEV *LHS, const SCEV *RHS) {
11343 // Interpret a null as meaning no loop, where there is obviously no guard
11344 // (interprocedural conditions notwithstanding). Do not bother about
11345 // unreachable loops.
11346 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11347 return true;
11348
11349 if (VerifyIR)
11350 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11351 "This cannot be done on broken IR!");
11352
11353
11354 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11355 return true;
11356
11357 BasicBlock *Latch = L->getLoopLatch();
11358 if (!Latch)
11359 return false;
11360
11361 BranchInst *LoopContinuePredicate =
11362 dyn_cast<BranchInst>(Latch->getTerminator());
11363 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11364 isImpliedCond(Pred, LHS, RHS,
11365 LoopContinuePredicate->getCondition(),
11366 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11367 return true;
11368
11369 // We don't want more than one activation of the following loops on the stack
11370 // -- that can lead to O(n!) time complexity.
11371 if (WalkingBEDominatingConds)
11372 return false;
11373
11374 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11375
11376 // See if we can exploit a trip count to prove the predicate.
11377 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11378 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11379 if (LatchBECount != getCouldNotCompute()) {
11380 // We know that Latch branches back to the loop header exactly
11381 // LatchBECount times. This means the backedge condition at Latch is
11382 // equivalent to "{0,+,1} u< LatchBECount".
11383 Type *Ty = LatchBECount->getType();
11384 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11385 const SCEV *LoopCounter =
11386 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11387 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11388 LatchBECount))
11389 return true;
11390 }
11391
11392 // Check conditions due to any @llvm.assume intrinsics.
11393 for (auto &AssumeVH : AC.assumptions()) {
11394 if (!AssumeVH)
11395 continue;
11396 auto *CI = cast<CallInst>(AssumeVH);
11397 if (!DT.dominates(CI, Latch->getTerminator()))
11398 continue;
11399
11400 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11401 return true;
11402 }
11403
11404 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11405 return true;
11406
11407 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11408 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11409 assert(DTN && "should reach the loop header before reaching the root!");
11410
11411 BasicBlock *BB = DTN->getBlock();
11412 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11413 return true;
11414
11415 BasicBlock *PBB = BB->getSinglePredecessor();
11416 if (!PBB)
11417 continue;
11418
11419 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11420 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11421 continue;
11422
11423 Value *Condition = ContinuePredicate->getCondition();
11424
11425 // If we have an edge `E` within the loop body that dominates the only
11426 // latch, the condition guarding `E` also guards the backedge. This
11427 // reasoning works only for loops with a single latch.
11428
11429 BasicBlockEdge DominatingEdge(PBB, BB);
11430 if (DominatingEdge.isSingleEdge()) {
11431 // We're constructively (and conservatively) enumerating edges within the
11432 // loop body that dominate the latch. The dominator tree better agree
11433 // with us on this:
11434 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11435
11436 if (isImpliedCond(Pred, LHS, RHS, Condition,
11437 BB != ContinuePredicate->getSuccessor(0)))
11438 return true;
11439 }
11440 }
11441
11442 return false;
11443}
11444
11445bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11446 ICmpInst::Predicate Pred,
11447 const SCEV *LHS,
11448 const SCEV *RHS) {
11449 // Do not bother proving facts for unreachable code.
11450 if (!DT.isReachableFromEntry(BB))
11451 return true;
11452 if (VerifyIR)
11453 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11454 "This cannot be done on broken IR!");
11455
11456 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11457 // the facts (a >= b && a != b) separately. A typical situation is when the
11458 // non-strict comparison is known from ranges and non-equality is known from
11459 // dominating predicates. If we are proving strict comparison, we always try
11460 // to prove non-equality and non-strict comparison separately.
11461 auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
11462 const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
11463 bool ProvedNonStrictComparison = false;
11464 bool ProvedNonEquality = false;
11465
11466 auto SplitAndProve =
11467 [&](std::function<bool(ICmpInst::Predicate)> Fn) -> bool {
11468 if (!ProvedNonStrictComparison)
11469 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11470 if (!ProvedNonEquality)
11471 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11472 if (ProvedNonStrictComparison && ProvedNonEquality)
11473 return true;
11474 return false;
11475 };
11476
11477 if (ProvingStrictComparison) {
11478 auto ProofFn = [&](ICmpInst::Predicate P) {
11479 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11480 };
11481 if (SplitAndProve(ProofFn))
11482 return true;
11483 }
11484
11485 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11486 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11487 const Instruction *CtxI = &BB->front();
11488 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11489 return true;
11490 if (ProvingStrictComparison) {
11491 auto ProofFn = [&](ICmpInst::Predicate P) {
11492 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11493 };
11494 if (SplitAndProve(ProofFn))
11495 return true;
11496 }
11497 return false;
11498 };
11499
11500 // Starting at the block's predecessor, climb up the predecessor chain, as long
11501 // as there are predecessors that can be found that have unique successors
11502 // leading to the original block.
11503 const Loop *ContainingLoop = LI.getLoopFor(BB);
11504 const BasicBlock *PredBB;
11505 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11506 PredBB = ContainingLoop->getLoopPredecessor();
11507 else
11508 PredBB = BB->getSinglePredecessor();
11509 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11510 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11511 const BranchInst *BlockEntryPredicate =
11512 dyn_cast<BranchInst>(Pair.first->getTerminator());
11513 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11514 continue;
11515
11516 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11517 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11518 return true;
11519 }
11520
11521 // Check conditions due to any @llvm.assume intrinsics.
11522 for (auto &AssumeVH : AC.assumptions()) {
11523 if (!AssumeVH)
11524 continue;
11525 auto *CI = cast<CallInst>(AssumeVH);
11526 if (!DT.dominates(CI, BB))
11527 continue;
11528
11529 if (ProveViaCond(CI->getArgOperand(0), false))
11530 return true;
11531 }
11532
11533 // Check conditions due to any @llvm.experimental.guard intrinsics.
11534 auto *GuardDecl = F.getParent()->getFunction(
11535 Intrinsic::getName(Intrinsic::experimental_guard));
11536 if (GuardDecl)
11537 for (const auto *GU : GuardDecl->users())
11538 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11539 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11540 if (ProveViaCond(Guard->getArgOperand(0), false))
11541 return true;
11542 return false;
11543}
11544
11545bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
11546 ICmpInst::Predicate Pred,
11547 const SCEV *LHS,
11548 const SCEV *RHS) {
11549 // Interpret a null as meaning no loop, where there is obviously no guard
11550 // (interprocedural conditions notwithstanding).
11551 if (!L)
11552 return false;
11553
11554 // Both LHS and RHS must be available at loop entry.
11555 assert(isAvailableAtLoopEntry(LHS, L) &&
11556 "LHS is not available at Loop Entry");
11557 assert(isAvailableAtLoopEntry(RHS, L) &&
11558 "RHS is not available at Loop Entry");
11559
11560 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11561 return true;
11562
11563 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11564}
11565
11566bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11567 const SCEV *RHS,
11568 const Value *FoundCondValue, bool Inverse,
11569 const Instruction *CtxI) {
11570 // A false condition implies anything. Do not bother analyzing it further.
11571 if (FoundCondValue ==
11572 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11573 return true;
11574
11575 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11576 return false;
11577
11578 auto ClearOnExit =
11579 make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
11580
11581 // Recursively handle And and Or conditions.
11582 const Value *Op0, *Op1;
11583 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11584 if (!Inverse)
11585 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11586 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11587 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11588 if (Inverse)
11589 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11590 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11591 }
11592
11593 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11594 if (!ICI) return false;
11595
11596 // We have now found a conditional branch that dominates the loop or controls
11597 // the loop latch. Check to see if it is the comparison we are looking for.
11598 ICmpInst::Predicate FoundPred;
11599 if (Inverse)
11600 FoundPred = ICI->getInversePredicate();
11601 else
11602 FoundPred = ICI->getPredicate();
11603
11604 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11605 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11606
11607 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11608}
11609
11610bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
11611 const SCEV *RHS,
11612 ICmpInst::Predicate FoundPred,
11613 const SCEV *FoundLHS, const SCEV *FoundRHS,
11614 const Instruction *CtxI) {
11615 // Balance the types.
11616 if (getTypeSizeInBits(LHS->getType()) <
11617 getTypeSizeInBits(FoundLHS->getType())) {
11618 // For unsigned and equality predicates, try to prove that both found
11619 // operands fit into narrow unsigned range. If so, try to prove facts in
11620 // narrow types.
11621 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11622 !FoundRHS->getType()->isPointerTy()) {
11623 auto *NarrowType = LHS->getType();
11624 auto *WideType = FoundLHS->getType();
11625 auto BitWidth = getTypeSizeInBits(NarrowType);
11626 const SCEV *MaxValue = getZeroExtendExpr(
11627 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11628 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11629 MaxValue) &&
11630 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11631 MaxValue)) {
11632 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11633 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11634 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
11635 TruncFoundRHS, CtxI))
11636 return true;
11637 }
11638 }
11639
11640 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11641 return false;
11642 if (CmpInst::isSigned(Pred)) {
11643 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11644 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11645 } else {
11646 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11647 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11648 }
11649 } else if (getTypeSizeInBits(LHS->getType()) >
11650 getTypeSizeInBits(FoundLHS->getType())) {
11651 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11652 return false;
11653 if (CmpInst::isSigned(FoundPred)) {
11654 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11655 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11656 } else {
11657 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11658 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11659 }
11660 }
11661 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11662 FoundRHS, CtxI);
11663}
11664
11665bool ScalarEvolution::isImpliedCondBalancedTypes(
11666 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11667 ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
11668 const Instruction *CtxI) {
11669 assert(getTypeSizeInBits(LHS->getType()) ==
11670 getTypeSizeInBits(FoundLHS->getType()) &&
11671 "Types should be balanced!");
11672 // Canonicalize the query to match the way instcombine will have
11673 // canonicalized the comparison.
11674 if (SimplifyICmpOperands(Pred, LHS, RHS))
11675 if (LHS == RHS)
11676 return CmpInst::isTrueWhenEqual(Pred);
11677 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11678 if (FoundLHS == FoundRHS)
11679 return CmpInst::isFalseWhenEqual(FoundPred);
11680
11681 // Check to see if we can make the LHS or RHS match.
11682 if (LHS == FoundRHS || RHS == FoundLHS) {
11683 if (isa<SCEVConstant>(RHS)) {
11684 std::swap(FoundLHS, FoundRHS);
11685 FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
11686 } else {
11687 std::swap(LHS, RHS);
11688 Pred = ICmpInst::getSwappedPredicate(Pred);
11689 }
11690 }
11691
11692 // Check whether the found predicate is the same as the desired predicate.
11693 if (FoundPred == Pred)
11694 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11695
11696 // Check whether swapping the found predicate makes it the same as the
11697 // desired predicate.
11698 if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
11699 // We can write the implication
11700 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
11701 // using one of the following ways:
11702 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
11703 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
11704 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
11705 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
11706 // Forms 1. and 2. require swapping the operands of one condition. Don't
11707 // do this if it would break canonical constant/addrec ordering.
11708 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
11709 return isImpliedCondOperands(FoundPred, RHS, LHS, FoundLHS, FoundRHS,
11710 CtxI);
11711 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
11712 return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, CtxI);
11713
11714 // There's no clear preference between forms 3. and 4., try both. Avoid
11715 // forming getNotSCEV of pointer values as the resulting subtract is
11716 // not legal.
11717 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
11718 isImpliedCondOperands(FoundPred, getNotSCEV(LHS), getNotSCEV(RHS),
11719 FoundLHS, FoundRHS, CtxI))
11720 return true;
11721
11722 if (!FoundLHS->getType()->isPointerTy() &&
11723 !FoundRHS->getType()->isPointerTy() &&
11724 isImpliedCondOperands(Pred, LHS, RHS, getNotSCEV(FoundLHS),
11725 getNotSCEV(FoundRHS), CtxI))
11726 return true;
11727
11728 return false;
11729 }
11730
11731 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
11732 CmpInst::Predicate P2) {
11733 assert(P1 != P2 && "Handled earlier!");
11734 return CmpInst::isRelational(P2) &&
11735 P1 == CmpInst::getFlippedSignednessPredicate(P2);
11736 };
11737 if (IsSignFlippedPredicate(Pred, FoundPred)) {
11738 // Unsigned comparison is the same as signed comparison when both the
11739 // operands are non-negative or negative.
11740 if ((isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) ||
11741 (isKnownNegative(FoundLHS) && isKnownNegative(FoundRHS)))
11742 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
11743 // Create local copies that we can freely swap and canonicalize our
11744 // conditions to "le/lt".
11745 ICmpInst::Predicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
11746 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
11747 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
11748 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
11749 CanonicalPred = ICmpInst::getSwappedPredicate(CanonicalPred);
11750 CanonicalFoundPred = ICmpInst::getSwappedPredicate(CanonicalFoundPred);
11751 std::swap(CanonicalLHS, CanonicalRHS);
11752 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
11753 }
11754 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
11755 "Must be!");
11756 assert((ICmpInst::isLT(CanonicalFoundPred) ||
11757 ICmpInst::isLE(CanonicalFoundPred)) &&
11758 "Must be!");
11759 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
11760 // Use implication:
11761 // x <u y && y >=s 0 --> x <s y.
11762 // If we can prove the left part, the right part is also proven.
11763 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11764 CanonicalRHS, CanonicalFoundLHS,
11765 CanonicalFoundRHS);
11766 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
11767 // Use implication:
11768 // x <s y && y <s 0 --> x <u y.
11769 // If we can prove the left part, the right part is also proven.
11770 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
11771 CanonicalRHS, CanonicalFoundLHS,
11772 CanonicalFoundRHS);
11773 }
11774
11775 // Check if we can make progress by sharpening ranges.
11776 if (FoundPred == ICmpInst::ICMP_NE &&
11777 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
11778
11779 const SCEVConstant *C = nullptr;
11780 const SCEV *V = nullptr;
11781
11782 if (isa<SCEVConstant>(FoundLHS)) {
11783 C = cast<SCEVConstant>(FoundLHS);
11784 V = FoundRHS;
11785 } else {
11786 C = cast<SCEVConstant>(FoundRHS);
11787 V = FoundLHS;
11788 }
11789
11790 // The guarding predicate tells us that C != V. If the known range
11791 // of V is [C, t), we can sharpen the range to [C + 1, t). The
11792 // range we consider has to correspond to the same signedness as the
11793 // predicate we're interested in folding.
11794
11795 APInt Min = ICmpInst::isSigned(Pred) ?
11796 getSignedRangeMin(V) : getUnsignedRangeMin(V);
11797
11798 if (Min == C->getAPInt()) {
11799 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
11800 // This is true even if (Min + 1) wraps around -- in case of
11801 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
11802
11803 APInt SharperMin = Min + 1;
11804
11805 switch (Pred) {
11806 case ICmpInst::ICMP_SGE:
11807 case ICmpInst::ICMP_UGE:
11808 // We know V `Pred` SharperMin. If this implies LHS `Pred`
11809 // RHS, we're done.
11810 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
11811 CtxI))
11812 return true;
11813 [[fallthrough]];
11814
11815 case ICmpInst::ICMP_SGT:
11816 case ICmpInst::ICMP_UGT:
11817 // We know from the range information that (V `Pred` Min ||
11818 // V == Min). We know from the guarding condition that !(V
11819 // == Min). This gives us
11820 //
11821 // V `Pred` Min || V == Min && !(V == Min)
11822 // => V `Pred` Min
11823 //
11824 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
11825
11826 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
11827 return true;
11828 break;
11829
11830 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
11831 case ICmpInst::ICMP_SLE:
11832 case ICmpInst::ICMP_ULE:
11833 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11834 LHS, V, getConstant(SharperMin), CtxI))
11835 return true;
11836 [[fallthrough]];
11837
11838 case ICmpInst::ICMP_SLT:
11839 case ICmpInst::ICMP_ULT:
11840 if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
11841 LHS, V, getConstant(Min), CtxI))
11842 return true;
11843 break;
11844
11845 default:
11846 // No change
11847 break;
11848 }
11849 }
11850 }
11851
11852 // Check whether the actual condition is beyond sufficient.
11853 if (FoundPred == ICmpInst::ICMP_EQ)
11854 if (ICmpInst::isTrueWhenEqual(Pred))
11855 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11856 return true;
11857 if (Pred == ICmpInst::ICMP_NE)
11858 if (!ICmpInst::isTrueWhenEqual(FoundPred))
11859 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
11860 return true;
11861
11862 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
11863 return true;
11864
11865 // Otherwise assume the worst.
11866 return false;
11867}
11868
11869bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
11870 const SCEV *&L, const SCEV *&R,
11871 SCEV::NoWrapFlags &Flags) {
11872 const auto *AE = dyn_cast<SCEVAddExpr>(Expr);
11873 if (!AE || AE->getNumOperands() != 2)
11874 return false;
11875
11876 L = AE->getOperand(0);
11877 R = AE->getOperand(1);
11878 Flags = AE->getNoWrapFlags();
11879 return true;
11880}
11881
11882std::optional<APInt>
11883ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
11884 // We avoid subtracting expressions here because this function is usually
11885 // fairly deep in the call stack (i.e. is called many times).
11886
11887 // X - X = 0.
11888 if (More == Less)
11889 return APInt(getTypeSizeInBits(More->getType()), 0);
11890
11891 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
11892 const auto *LAR = cast<SCEVAddRecExpr>(Less);
11893 const auto *MAR = cast<SCEVAddRecExpr>(More);
11894
11895 if (LAR->getLoop() != MAR->getLoop())
11896 return std::nullopt;
11897
11898 // We look at affine expressions only; not for correctness but to keep
11899 // getStepRecurrence cheap.
11900 if (!LAR->isAffine() || !MAR->isAffine())
11901 return std::nullopt;
11902
11903 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
11904 return std::nullopt;
11905
11906 Less = LAR->getStart();
11907 More = MAR->getStart();
11908
11909 // fall through
11910 }
11911
11912 if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {
11913 const auto &M = cast<SCEVConstant>(More)->getAPInt();
11914 const auto &L = cast<SCEVConstant>(Less)->getAPInt();
11915 return M - L;
11916 }
11917
11918 SCEV::NoWrapFlags Flags;
11919 const SCEV *LLess = nullptr, *RLess = nullptr;
11920 const SCEV *LMore = nullptr, *RMore = nullptr;
11921 const SCEVConstant *C1 = nullptr, *C2 = nullptr;
11922 // Compare (X + C1) vs X.
11923 if (splitBinaryAdd(Less, LLess, RLess, Flags))
11924 if ((C1 = dyn_cast<SCEVConstant>(LLess)))
11925 if (RLess == More)
11926 return -(C1->getAPInt());
11927
11928 // Compare X vs (X + C2).
11929 if (splitBinaryAdd(More, LMore, RMore, Flags))
11930 if ((C2 = dyn_cast<SCEVConstant>(LMore)))
11931 if (RMore == Less)
11932 return C2->getAPInt();
11933
11934 // Compare (X + C1) vs (X + C2).
11935 if (C1 && C2 && RLess == RMore)
11936 return C2->getAPInt() - C1->getAPInt();
11937
11938 return std::nullopt;
11939}
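// Example (a sketch, not in the upstream source; %x, %s, %L are placeholders):
//   computeConstantDifference(%x + 5, %x + 2) == 3
//   computeConstantDifference({(%x + 5),+,%s}<%L>, {(%x + 2),+,%s}<%L>) == 3
// Affine AddRecs over the same loop with equal steps are reduced to their
// start values before the constant offsets are compared.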
11940
11941bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
11942 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11943 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11944 // Try to recognize the following pattern:
11945 //
11946 // FoundRHS = ...
11947 // ...
11948 // loop:
11949 // FoundLHS = {Start,+,W}
11950 // context_bb: // Basic block from the same loop
11951 // known(Pred, FoundLHS, FoundRHS)
11952 //
11953 // If some predicate is known in the context of a loop, it is also known on
11954 // each iteration of this loop, including the first iteration. Therefore, in
11955 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
11956 // prove the original pred using this fact.
11957 if (!CtxI)
11958 return false;
11959 const BasicBlock *ContextBB = CtxI->getParent();
11960 // Make sure AR varies in the context block.
11961 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
11962 const Loop *L = AR->getLoop();
11963 // Make sure that context belongs to the loop and executes on 1st iteration
11964 // (if it ever executes at all).
11965 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11966 return false;
11967 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
11968 return false;
11969 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
11970 }
11971
11972 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
11973 const Loop *L = AR->getLoop();
11974 // Make sure that context belongs to the loop and executes on 1st iteration
11975 // (if it ever executes at all).
11976 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
11977 return false;
11978 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
11979 return false;
11980 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
11981 }
11982
11983 return false;
11984}
11985
11986bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
11987 ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
11988 const SCEV *FoundLHS, const SCEV *FoundRHS) {
11989 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
11990 return false;
11991
11992 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11993 if (!AddRecLHS)
11994 return false;
11995
11996 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
11997 if (!AddRecFoundLHS)
11998 return false;
11999
12000 // We'd like to let SCEV reason about control dependencies, so we constrain
12001 // both the inequalities to be about add recurrences on the same loop. This
12002 // way we can use isLoopEntryGuardedByCond later.
12003
12004 const Loop *L = AddRecFoundLHS->getLoop();
12005 if (L != AddRecLHS->getLoop())
12006 return false;
12007
12008 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12009 //
12010 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12011 // ... (2)
12012 //
12013 // Informal proof for (2), assuming (1) [*]:
12014 //
12015 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12016 //
12017 // Then
12018 //
12019 // FoundLHS s< FoundRHS s< INT_MIN - C
12020 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12021 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12022 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12023 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12024 // <=> FoundLHS + C s< FoundRHS + C
12025 //
12026 // [*]: (1) can be proved by ruling out overflow.
12027 //
12028 // [**]: This can be proved by analyzing all the four possibilities:
12029 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12030 // (A s>= 0, B s>= 0).
12031 //
12032 // Note:
12033 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12034 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12035 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12036 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12037 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12038 // C)".
12039
12040 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12041 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12042 if (!LDiff || !RDiff || *LDiff != *RDiff)
12043 return false;
12044
12045 if (LDiff->isMinValue())
12046 return true;
12047
12048 APInt FoundRHSLimit;
12049
12050 if (Pred == CmpInst::ICMP_ULT) {
12051 FoundRHSLimit = -(*RDiff);
12052 } else {
12053 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12054 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12055 }
12056
12057 // Try to prove (1) or (2), as needed.
12058 return isAvailableAtLoopEntry(FoundRHS, L) &&
12059 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12060 getConstant(FoundRHSLimit));
12061}
12062
12063bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
12064 const SCEV *LHS, const SCEV *RHS,
12065 const SCEV *FoundLHS,
12066 const SCEV *FoundRHS, unsigned Depth) {
12067 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12068
12069 auto ClearOnExit = make_scope_exit([&]() {
12070 if (LPhi) {
12071 bool Erased = PendingMerges.erase(LPhi);
12072 assert(Erased && "Failed to erase LPhi!");
12073 (void)Erased;
12074 }
12075 if (RPhi) {
12076 bool Erased = PendingMerges.erase(RPhi);
12077 assert(Erased && "Failed to erase RPhi!");
12078 (void)Erased;
12079 }
12080 });
12081
12082 // Find the respective Phis and check that they are not already pending.
12083 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12084 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12085 if (!PendingMerges.insert(Phi).second)
12086 return false;
12087 LPhi = Phi;
12088 }
12089 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12090 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12091 // If we detect a loop of Phi nodes being processed by this method, for
12092 // example:
12093 //
12094 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12095 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12096 //
12097 // we don't want to deal with a case that complex, so return conservative
12098 // answer false.
12099 if (!PendingMerges.insert(Phi).second)
12100 return false;
12101 RPhi = Phi;
12102 }
12103
12104 // If none of LHS, RHS is a Phi, nothing to do here.
12105 if (!LPhi && !RPhi)
12106 return false;
12107
12108 // If there is a SCEVUnknown Phi we are interested in, make it left.
12109 if (!LPhi) {
12110 std::swap(LHS, RHS);
12111 std::swap(FoundLHS, FoundRHS);
12112 std::swap(LPhi, RPhi);
12113 Pred = ICmpInst::getSwappedPredicate(Pred);
12114 }
12115
12116 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12117 const BasicBlock *LBB = LPhi->getParent();
12118 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12119
12120 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12121 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12122 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12123 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12124 };
12125
12126 if (RPhi && RPhi->getParent() == LBB) {
12127 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12128 // If we compare two Phis from the same block, and for each entry block
12129 // the predicate is true for incoming values from this block, then the
12130 // predicate is also true for the Phis.
12131 for (const BasicBlock *IncBB : predecessors(LBB)) {
12132 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12133 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12134 if (!ProvedEasily(L, R))
12135 return false;
12136 }
12137 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12138 // Case two: RHS is also a Phi from the same basic block, and it is an
12139 // AddRec. It means that there is a loop which has both AddRec and Unknown
12140 // PHIs, so we can compare the incoming values of the AddRec from above the
12141 // loop and from the latch with the respective incoming values of LPhi.
12142 // TODO: Generalize to handle loops with many inputs in a header.
12143 if (LPhi->getNumIncomingValues() != 2) return false;
12144
12145 auto *RLoop = RAR->getLoop();
12146 auto *Predecessor = RLoop->getLoopPredecessor();
12147 assert(Predecessor && "Loop with AddRec with no predecessor?");
12148 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12149 if (!ProvedEasily(L1, RAR->getStart()))
12150 return false;
12151 auto *Latch = RLoop->getLoopLatch();
12152 assert(Latch && "Loop with AddRec with no latch?");
12153 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12154 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12155 return false;
12156 } else {
12157 // In all other cases go over the inputs of LHS and compare each of them to
12158 // RHS; the predicate is true for (LHS, RHS) if it is true for all such pairs.
12159 // At this point RHS is either a non-Phi, or it is a Phi from some block
12160 // different from LBB.
12161 for (const BasicBlock *IncBB : predecessors(LBB)) {
12162 // Check that RHS is available in this block.
12163 if (!dominates(RHS, IncBB))
12164 return false;
12165 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12166 // Make sure L does not refer to a value from a potentially previous
12167 // iteration of a loop.
12168 if (!properlyDominates(L, LBB))
12169 return false;
12170 if (!ProvedEasily(L, RHS))
12171 return false;
12172 }
12173 }
12174 return true;
12175}
12176
12177bool ScalarEvolution::isImpliedCondOperandsViaShift(ICmpInst::Predicate Pred,
12178 const SCEV *LHS,
12179 const SCEV *RHS,
12180 const SCEV *FoundLHS,
12181 const SCEV *FoundRHS) {
12182 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12183 // sure that we are dealing with the same LHS.
12184 if (RHS == FoundRHS) {
12185 std::swap(LHS, RHS);
12186 std::swap(FoundLHS, FoundRHS);
12187 Pred = ICmpInst::getSwappedPredicate(Pred);
12188 }
12189 if (LHS != FoundLHS)
12190 return false;
12191
12192 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12193 if (!SUFoundRHS)
12194 return false;
12195
12196 Value *Shiftee, *ShiftValue;
12197
12198 using namespace PatternMatch;
12199 if (match(SUFoundRHS->getValue(),
12200 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12201 auto *ShifteeS = getSCEV(Shiftee);
12202 // Prove one of the following:
12203 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12204 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12205 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12206 // ---> LHS <s RHS
12207 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12208 // ---> LHS <=s RHS
12209 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12210 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12211 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12212 if (isKnownNonNegative(ShifteeS))
12213 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12214 }
12215
12216 return false;
12217}
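// Example (a sketch, not in the upstream source; %i, %len, %n are
// placeholders): from the dominating fact "%i u< (%len >> 2)" and the known
// fact "%len u<= %n", this routine concludes "%i u< %n", because a logical
// shift right never increases an unsigned value, so
// (%len >> 2) u<= %len u<= %n.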
12218
12219bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
12220 const SCEV *LHS, const SCEV *RHS,
12221 const SCEV *FoundLHS,
12222 const SCEV *FoundRHS,
12223 const Instruction *CtxI) {
12224 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS, FoundRHS))
12225 return true;
12226
12227 if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
12228 return true;
12229
12230 if (isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS))
12231 return true;
12232
12233 if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12234 CtxI))
12235 return true;
12236
12237 return isImpliedCondOperandsHelper(Pred, LHS, RHS,
12238 FoundLHS, FoundRHS);
12239}
12240
12241/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12242template <typename MinMaxExprType>
12243static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12244 const SCEV *Candidate) {
12245 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12246 if (!MinMaxExpr)
12247 return false;
12248
12249 return is_contained(MinMaxExpr->operands(), Candidate);
12250}
12251
12252static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12253 ICmpInst::Predicate Pred,
12254 const SCEV *LHS, const SCEV *RHS) {
12255 // If both sides are affine addrecs for the same loop, with equal
12256 // steps, and we know the recurrences don't wrap, then we only
12257 // need to check the predicate on the starting values.
12258
12259 if (!ICmpInst::isRelational(Pred))
12260 return false;
12261
12262 const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12263 if (!LAR)
12264 return false;
12265 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12266 if (!RAR)
12267 return false;
12268 if (LAR->getLoop() != RAR->getLoop())
12269 return false;
12270 if (!LAR->isAffine() || !RAR->isAffine())
12271 return false;
12272
12273 if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12274 return false;
12275
12276 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
12277 SCEV::FlagNSW : SCEV::FlagNUW;
12278 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12279 return false;
12280
12281 return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12282}
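// Example (a sketch, not in the upstream source; %a, %b, %L are placeholders):
// for two affine AddRecs over the same loop with equal steps, such as
//   {%a,+,4}<nsw><%L>  s<  {%b,+,4}<nsw><%L>,
// the recurrences advance in lockstep and cannot wrap, so the predicate
// reduces to the check on the start values, %a s< %b.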
12283
12284/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12285/// expression?
12286static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
12287 ICmpInst::Predicate Pred,
12288 const SCEV *LHS, const SCEV *RHS) {
12289 switch (Pred) {
12290 default:
12291 return false;
12292
12293 case ICmpInst::ICMP_SGE:
12294 std::swap(LHS, RHS);
12295 [[fallthrough]];
12296 case ICmpInst::ICMP_SLE:
12297 return
12298 // min(A, ...) <= A
12299 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12300 // A <= max(A, ...)
12301 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12302
12303 case ICmpInst::ICMP_UGE:
12304 std::swap(LHS, RHS);
12305 [[fallthrough]];
12306 case ICmpInst::ICMP_ULE:
12307 return
12308 // min(A, ...) <= A
12309 // FIXME: what about umin_seq?
12310 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12311 // A <= max(A, ...)
12312 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12313 }
12314
12315 llvm_unreachable("covered switch fell through?!");
12316}
12317
12318bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
12319 const SCEV *LHS, const SCEV *RHS,
12320 const SCEV *FoundLHS,
12321 const SCEV *FoundRHS,
12322 unsigned Depth) {
12323 assert(getTypeSizeInBits(LHS->getType()) ==
12324 getTypeSizeInBits(RHS->getType()) &&
12325 "LHS and RHS have different sizes?");
12326 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12327 getTypeSizeInBits(FoundRHS->getType()) &&
12328 "FoundLHS and FoundRHS have different sizes?");
12329 // We want to avoid hurting the compile time with analysis of too big trees.
12330 if (Depth > MaxSCEVOperationsImplicationDepth)
12331 return false;
12332
12333 // We only want to work with GT comparison so far.
12334 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
12335 Pred = CmpInst::getSwappedPredicate(Pred);
12336 std::swap(LHS, RHS);
12337 std::swap(FoundLHS, FoundRHS);
12338 }
12339
12340 // For unsigned, try to reduce it to corresponding signed comparison.
12341 if (Pred == ICmpInst::ICMP_UGT)
12342 // We can replace unsigned predicate with its signed counterpart if all
12343 // involved values are non-negative.
12344 // TODO: We could have better support for unsigned.
12345 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12346 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12347 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12348 // use this fact to prove that LHS and RHS are non-negative.
12349 const SCEV *MinusOne = getMinusOne(LHS->getType());
12350 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12351 FoundRHS) &&
12352 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12353 FoundRHS))
12354 Pred = ICmpInst::ICMP_SGT;
12355 }
12356
12357 if (Pred != ICmpInst::ICMP_SGT)
12358 return false;
12359
12360 auto GetOpFromSExt = [&](const SCEV *S) {
12361 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12362 return Ext->getOperand();
12363 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12364 // the constant in some cases.
12365 return S;
12366 };
12367
12368 // Acquire values from extensions.
12369 auto *OrigLHS = LHS;
12370 auto *OrigFoundLHS = FoundLHS;
12371 LHS = GetOpFromSExt(LHS);
12372 FoundLHS = GetOpFromSExt(FoundLHS);
12373
12374 // Check whether the SGT predicate can be proved trivially or using the found context.
12375 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12376 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12377 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12378 FoundRHS, Depth + 1);
12379 };
12380
12381 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12382 // We want to avoid creation of any new non-constant SCEV. Since we are
12383 // going to compare the operands to RHS, we should be certain that we don't
12384 // need any size extensions for this. So let's decline all cases when the
12385 // sizes of types of LHS and RHS do not match.
12386 // TODO: Maybe try to get RHS from sext to catch more cases?
12387 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12388 return false;
12389
12390 // Should not overflow.
12391 if (!LHSAddExpr->hasNoSignedWrap())
12392 return false;
12393
12394 auto *LL = LHSAddExpr->getOperand(0);
12395 auto *LR = LHSAddExpr->getOperand(1);
12396 auto *MinusOne = getMinusOne(RHS->getType());
12397
12398 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12399 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12400 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12401 };
12402 // Try to prove the following rule:
12403 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12404 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12405 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12406 return true;
12407 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12408 Value *LL, *LR;
12409 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12410
12411 using namespace llvm::PatternMatch;
12412
12413 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12414 // Rules for division.
12415 // We are going to perform some comparisons with Denominator and its
12416 // derivative expressions. In the general case, creating a SCEV for it may
12417 // lead to a complex analysis of the entire graph, and in particular it
12418 // can request trip count recalculation for the same loop, whose result
12419 // would be cached as SCEVCouldNotCompute to break the infinite recursion.
12420 // To avoid this, we only create SCEVs that are constants in this section.
12421 // So we bail if Denominator is not a constant.
12422 if (!isa<ConstantInt>(LR))
12423 return false;
12424
12425 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12426
12427 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12428 // then a SCEV for the numerator already exists and matches with FoundLHS.
12429 auto *Numerator = getExistingSCEV(LL);
12430 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12431 return false;
12432
12433 // Make sure that the numerator matches with FoundLHS and the denominator
12434 // is positive.
12435 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12436 return false;
12437
12438 auto *DTy = Denominator->getType();
12439 auto *FRHSTy = FoundRHS->getType();
12440 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12441 // One of types is a pointer and another one is not. We cannot extend
12442 // them properly to a wider type, so let us just reject this case.
12443 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12444 // to avoid this check.
12445 return false;
12446
12447 // Given that:
12448 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12449 auto *WTy = getWiderType(DTy, FRHSTy);
12450 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12451 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12452
12453 // Try to prove the following rule:
12454 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12455 // For example, if FoundLHS > 2, then FoundLHS is at least 3; dividing it
12456 // by any Denominator < 4 still gives at least 1.
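      // Illustration (hypothetical values): Denominator = 4 and FoundRHS = 3
      // satisfy FoundRHS > Denominator - 2, so FoundLHS >= 4 and
      // LHS = FoundLHS / 4 >= 1, which exceeds any non-positive RHS.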
12457 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12458 if (isKnownNonPositive(RHS) &&
12459 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12460 return true;
12461
12462 // Try to prove the following rule:
12463 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12464 // For example, if FoundLHS > -3, then FoundLHS is at least -2.
12465 // If we divide it by Denominator > 2, then:
12466 // 1. If FoundLHS is negative, then the result is 0.
12467 // 2. If FoundLHS is non-negative, then the result is non-negative.
12468 // Either way, the result is non-negative.
12469 auto *MinusOne = getMinusOne(WTy);
12470 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12471 if (isKnownNegative(RHS) &&
12472 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12473 return true;
12474 }
12475 }
12476
12477 // If our expression contained SCEVUnknown Phis, and we split it down and now
12478 // need to prove something for them, try to prove the predicate for every
12479 // possible incoming value of those Phis.
12480 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12481 return true;
12482
12483 return false;
12484}
12485
12486bool ScalarEvolution::isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred,
12487 const SCEV *LHS, const SCEV *RHS) {
12488 // zext x u<= sext x, sext x s<= zext x
12489 switch (Pred) {
12490 case ICmpInst::ICMP_SGE:
12491 std::swap(LHS, RHS);
12492 [[fallthrough]];
12493 case ICmpInst::ICMP_SLE: {
12494 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12495 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(LHS);
12496 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(RHS);
12497 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12498 return true;
12499 break;
12500 }
12501 case ICmpInst::ICMP_UGE:
12502 std::swap(LHS, RHS);
12503 [[fallthrough]];
12504 case ICmpInst::ICMP_ULE: {
12505 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
12506 const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS);
12507 const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(RHS);
12508 if (SExt && ZExt && SExt->getOperand() == ZExt->getOperand())
12509 return true;
12510 break;
12511 }
12512 default:
12513 break;
12514 };
12515 return false;
12516}
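// Illustrative example (hypothetical i8/i16 values): for an i8 x = -1,
// zext x to i16 is 255 while sext x is 0xFFFF. Unsigned, 255 <=u 0xFFFF;
// signed, -1 <=s 255. For non-negative x both extensions agree, so the
// idiom above holds for every x.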
12517
12518bool
12519ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
12520 const SCEV *LHS, const SCEV *RHS) {
12521 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12522 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12523 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12524 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12525 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12526}
12527
12528bool
12529ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
12530 const SCEV *LHS, const SCEV *RHS,
12531 const SCEV *FoundLHS,
12532 const SCEV *FoundRHS) {
12533 switch (Pred) {
12534 default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
12535 case ICmpInst::ICMP_EQ:
12536 case ICmpInst::ICMP_NE:
12537 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12538 return true;
12539 break;
12540 case ICmpInst::ICMP_SLT:
12541 case ICmpInst::ICMP_SLE:
12542 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12543 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12544 return true;
12545 break;
12546 case ICmpInst::ICMP_SGT:
12547 case ICmpInst::ICMP_SGE:
12548 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12549 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12550 return true;
12551 break;
12552 case ICmpInst::ICMP_ULT:
12553 case ICmpInst::ICMP_ULE:
12554 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12555 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12556 return true;
12557 break;
12558 case ICmpInst::ICMP_UGT:
12559 case ICmpInst::ICMP_UGE:
12560 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12561 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12562 return true;
12563 break;
12564 }
12565
12566 // Maybe it can be proved via operations?
12567 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12568 return true;
12569
12570 return false;
12571}
12572
12573bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,
12574 const SCEV *LHS,
12575 const SCEV *RHS,
12576 ICmpInst::Predicate FoundPred,
12577 const SCEV *FoundLHS,
12578 const SCEV *FoundRHS) {
12579 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12580 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12581 // reduce the compile time impact of this optimization.
12582 return false;
12583
12584 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12585 if (!Addend)
12586 return false;
12587
12588 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12589
12590 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12591 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12592 ConstantRange FoundLHSRange =
12593 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12594
12595 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12596 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12597
12598 // We can also compute the range of values for `LHS` that satisfy the
12599 // consequent, "`LHS` `Pred` `RHS`":
12600 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12601 // The antecedent implies the consequent if every value of `LHS` that
12602 // satisfies the antecedent also satisfies the consequent.
12603 return LHSRange.icmp(Pred, ConstRHS);
12604}
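// Illustrative example (hypothetical values): from "FoundLHS u< 8" we get
// FoundLHSRange = [0,8). If LHS = FoundLHS + 2, the addend is 2 and
// LHSRange = [2,10), so a consequent such as "LHS u< 10" holds for every
// value in that range and the implication succeeds.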
12605
12606bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12607 bool IsSigned) {
12608 assert(isKnownPositive(Stride) && "Positive stride expected!");
12609
12610 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12611 const SCEV *One = getOne(Stride->getType());
12612
12613 if (IsSigned) {
12614 APInt MaxRHS = getSignedRangeMax(RHS);
12615 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12616 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12617
12618 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12619 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12620 }
12621
12622 APInt MaxRHS = getUnsignedRangeMax(RHS);
12623 APInt MaxValue = APInt::getMaxValue(BitWidth);
12624 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12625
12626 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12627 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12628}
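// Illustrative example (hypothetical i8 values): if the unsigned max of RHS
// is 200 and the max of Stride - 1 is 59, then 255 - 59 = 196 <u 200, so the
// IV may step past RHS and wrap; the check conservatively reports true.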
12629
12630bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12631 bool IsSigned) {
12632
12633 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12634 const SCEV *One = getOne(Stride->getType());
12635
12636 if (IsSigned) {
12637 APInt MinRHS = getSignedRangeMin(RHS);
12638 APInt MinValue = APInt::getSignedMinValue(BitWidth);
12639 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12640
12641 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
12642 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
12643 }
12644
12645 APInt MinRHS = getUnsignedRangeMin(RHS);
12646 APInt MinValue = APInt::getMinValue(BitWidth);
12647 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12648
12649 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
12650 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
12651}
12652
12653const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
12654 // umin(N, 1) + floor((N - umin(N, 1)) / D)
12655 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
12656 // expression fixes the case of N=0.
12657 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
12658 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
12659 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
12660}
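// Illustrative example (hypothetical values): for N = 7, D = 3 this yields
// umin(7,1) + (7-1)/3 = 1 + 2 = 3 = ceil(7/3); for N = 0 it yields
// 0 + 0/3 = 0, which is why the umin form is used instead of 1 + (N-1)/D.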
12661
12662const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
12663 const SCEV *Stride,
12664 const SCEV *End,
12665 unsigned BitWidth,
12666 bool IsSigned) {
12667 // The logic in this function assumes we can represent a positive stride.
12668 // If we can't, the backedge-taken count must be zero.
12669 if (IsSigned && BitWidth == 1)
12670 return getZero(Stride->getType());
12671
12672 // The code below has only been closely audited for negative strides in the
12673 // unsigned comparison case; it may be correct for signed comparison, but
12674 // that needs to be established.
12675 if (IsSigned && isKnownNegative(Stride))
12676 return getCouldNotCompute();
12677
12678 // Calculate the maximum backedge count based on the range of values
12679 // permitted by Start, End, and Stride.
12680 APInt MinStart =
12681 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
12682
12683 APInt MinStride =
12684 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
12685
12686 // We assume either the stride is positive, or the backedge-taken count
12687 // is zero. So force StrideForMaxBECount to be at least one.
12688 APInt One(BitWidth, 1);
12689 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
12690 : APIntOps::umax(One, MinStride);
12691
12692 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
12693 : APInt::getMaxValue(BitWidth);
12694 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
12695
12696 // Although End can be a MAX expression we estimate MaxEnd considering only
12697 // the case End = RHS of the loop termination condition. This is safe because
12698 // in the other case (End - Start) is zero, leading to a zero maximum backedge
12699 // taken count.
12700 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
12701 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
12702
12703 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
12704 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
12705 : APIntOps::umax(MaxEnd, MinStart);
12706
12707 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
12708 getConstant(StrideForMaxBECount) /* Step */);
12709}
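// Illustrative example (hypothetical unsigned i8 values): with MinStart = 0,
// MinStride = 4 and an unsigned max of 100 for End, Limit = 255 - 3 = 252,
// MaxEnd = min(100, 252) = 100, and the maximum backedge-taken count is
// ceil((100 - 0) / 4) = 25.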
12710
12711ScalarEvolution::ExitLimit
12712ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
12713 const Loop *L, bool IsSigned,
12714 bool ControlsOnlyExit, bool AllowPredicates) {
12715 SmallVector<const SCEVPredicate *> Predicates;
12716
12717 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
12718 bool PredicatedIV = false;
12719
12720 auto canAssumeNoSelfWrap = [&](const SCEVAddRecExpr *AR) {
12721 // Can we prove this loop *must* be UB if overflow of IV occurs?
12722 // Reasoning goes as follows:
12723 // * Suppose the IV did self wrap.
12724 // * If Stride evenly divides the iteration space, then once wrap
12725 // occurs, the loop must revisit the same values.
12726 // * We know that RHS is invariant, and that none of those values
12727 // caused this exit to be taken previously. Thus, this exit is
12728 // dynamically dead.
12729 // * If this is the sole exit, then a dead exit implies the loop
12730 // must be infinite if there are no abnormal exits.
12731 // * If the loop were infinite, then it must either not be mustprogress
12732 // or have side effects. Otherwise, it must be UB.
12733 // * It can't (by assumption), be UB so we have contradicted our
12734 // premise and can conclude the IV did not in fact self-wrap.
12735 if (!isLoopInvariant(RHS, L))
12736 return false;
12737
12738 auto *StrideC = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this));
12739 if (!StrideC || !StrideC->getAPInt().isPowerOf2())
12740 return false;
12741
12742 if (!ControlsOnlyExit || !loopHasNoAbnormalExits(L))
12743 return false;
12744
12745 return loopIsFiniteByAssumption(L);
12746 };
12747
12748 if (!IV) {
12749 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
12750 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
12751 if (AR && AR->getLoop() == L && AR->isAffine()) {
12752 auto canProveNUW = [&]() {
12753 // We can use the comparison to infer no-wrap flags only if it fully
12754 // controls the loop exit.
12755 if (!ControlsOnlyExit)
12756 return false;
12757
12758 if (!isLoopInvariant(RHS, L))
12759 return false;
12760
12761 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
12762 // We need the sequence defined by AR to strictly increase in the
12763 // unsigned integer domain for the logic below to hold.
12764 return false;
12765
12766 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
12767 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
12768 // If RHS <=u Limit, then there must exist a value V in the sequence
12769 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
12770 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
12771 // overflow occurs. This limit also implies that a signed comparison
12772 // (in the wide bitwidth) is equivalent to an unsigned comparison as
12773 // the high bits on both sides must be zero.
12774 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
12775 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
12776 Limit = Limit.zext(OuterBitWidth);
12777 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
12778 };
12779 auto Flags = AR->getNoWrapFlags();
12780 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
12781 Flags = setFlags(Flags, SCEV::FlagNUW);
12782
12783 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
12784 if (AR->hasNoUnsignedWrap()) {
12785 // Emulate what getZeroExtendExpr would have done during construction
12786 // if we'd been able to infer the fact just above at that time.
12787 const SCEV *Step = AR->getStepRecurrence(*this);
12788 Type *Ty = ZExt->getType();
12789 auto *S = getAddRecExpr(
12790 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
12791 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
12792 IV = dyn_cast<SCEVAddRecExpr>(S);
12793 }
12794 }
12795 }
12796 }
12797
12798
12799 if (!IV && AllowPredicates) {
12800 // Try to make this an AddRec using runtime tests, in the first X
12801 // iterations of this loop, where X is the SCEV expression found by the
12802 // algorithm below.
12803 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
12804 PredicatedIV = true;
12805 }
12806
12807 // Avoid weird loops
12808 if (!IV || IV->getLoop() != L || !IV->isAffine())
12809 return getCouldNotCompute();
12810
12811 // A precondition of this method is that the condition being analyzed
12812 // reaches an exiting branch which dominates the latch. Given that, we can
12813 // assume that an increment which violates the nowrap specification and
12814 // produces poison must cause undefined behavior when the resulting poison
12815 // value is branched upon and thus we can conclude that the backedge is
12816 // taken no more often than would be required to produce that poison value.
12817 // Note that a well defined loop can exit on the iteration which violates
12818 // the nowrap specification if there is another exit (either explicit or
12819 // implicit/exceptional) which causes the loop to execute before the
12820 // exiting instruction we're analyzing would trigger UB.
12821 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
12822 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
12823 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
12824
12825 const SCEV *Stride = IV->getStepRecurrence(*this);
12826
12827 bool PositiveStride = isKnownPositive(Stride);
12828
12829 // Avoid negative or zero stride values.
12830 if (!PositiveStride) {
12831 // We can compute the correct backedge taken count for loops with unknown
12832 // strides if we can prove that the loop is not an infinite loop with side
12833 // effects. Here's the loop structure we are trying to handle -
12834 //
12835 // i = start
12836 // do {
12837 // A[i] = i;
12838 // i += s;
12839 // } while (i < end);
12840 //
12841 // The backedge taken count for such loops is evaluated as -
12842 // (max(end, start + stride) - start - 1) /u stride
12843 //
12844 // The additional preconditions that we need to check to prove correctness
12845 // of the above formula are as follows -
12846 //
12847 // a) IV is either nuw or nsw depending upon signedness (indicated by the
12848 // NoWrap flag).
12849 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
12850 // no side effects within the loop)
12851 // c) loop has a single static exit (with no abnormal exits)
12852 //
12853 // Precondition a) implies that if the stride is negative, this is a single
12854 // trip loop. The backedge taken count formula reduces to zero in this case.
12855 //
12856 // Precondition b) and c) combine to imply that if rhs is invariant in L,
12857 // then a zero stride means the backedge can't be taken without executing
12858 // undefined behavior.
12859 //
12860 // The positive stride case is the same as isKnownPositive(Stride) returning
12861 // true (original behavior of the function).
12862 //
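  // Illustrative example (hypothetical values): start = 0, end = 7,
  // stride = 3 gives (max(7, 0 + 3) - 0 - 1) /u 3 = 6 /u 3 = 2 backedges
  // (i = 0, 3, 6 execute the body; the exit is taken once i reaches 9).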
12863 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
12864 !loopHasNoAbnormalExits(L))
12865 return getCouldNotCompute();
12866
12867 if (!isKnownNonZero(Stride)) {
12868 // If we have a step of zero, and RHS isn't invariant in L, we don't know
12869 // if it might eventually be greater than start and if so, on which
12870 // iteration. We can't even produce a useful upper bound.
12871 if (!isLoopInvariant(RHS, L))
12872 return getCouldNotCompute();
12873
12874 // We allow a potentially zero stride, but we need to divide by stride
12875 // below. Since the loop can't be infinite and this check must control
12876 // the sole exit, we can infer the exit must be taken on the first
12877 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
12878 // we know the numerator in the divides below must be zero, so we can
12879 // pick an arbitrary non-zero value for the denominator (e.g. stride)
12880 // and produce the right result.
12881 // FIXME: Handle the case where Stride is poison?
12882 auto wouldZeroStrideBeUB = [&]() {
12883 // Proof by contradiction. Suppose the stride were zero. If we can
12884 // prove that the backedge *is* taken on the first iteration, then since
12885 // we know this condition controls the sole exit, we must have an
12886 // infinite loop. We can't have a (well defined) infinite loop per
12887 // check just above.
12888 // Note: The (Start - Stride) term is used to get the start' term from
12889 // (start' + stride,+,stride). Remember that we only care about the
12890 // result of this expression when stride == 0 at runtime.
12891 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
12892 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
12893 };
12894 if (!wouldZeroStrideBeUB()) {
12895 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
12896 }
12897 }
12898 } else if (!Stride->isOne() && !NoWrap) {
12899 auto isUBOnWrap = [&]() {
12900 // From no-self-wrap, we then need to prove no-(un)signed-wrap. This
12901 // follows from the fact that every value which (un)signed-wraps, but
12902 // does not self-wrap, must be less than the last value before the
12903 // (un)signed wrap. We know that last value did not cause an exit, and
12904 // neither will any smaller one.
12905 return canAssumeNoSelfWrap(IV);
12906 };
12907
12908 // Avoid proven overflow cases: this will ensure that the backedge taken
12909 // count will not generate any unsigned overflow. Relaxed no-overflow
12910 // conditions exploit NoWrapFlags, allowing us to optimize in the presence
12911 // of undefined behavior, as in the C language.
12912 if (canIVOverflowOnLT(RHS, Stride, IsSigned) && !isUBOnWrap())
12913 return getCouldNotCompute();
12914 }
12915
12916 // On all paths just preceding, we established the following invariant:
12917 // IV can be assumed not to overflow up to and including the exiting
12918 // iteration. We proved this in one of two ways:
12919 // 1) We can show overflow doesn't occur before the exiting iteration
12920 // 1a) via the canIVOverflowOnLT check, or 1b) because the step is one
12921 // 2) We can show that if overflow occurs, the loop must execute UB
12922 // before any possible exit.
12923 // Note that we have not yet proved RHS invariant (in general).
12924
12925 const SCEV *Start = IV->getStart();
12926
12927 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
12928 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
12929 // Use integer-typed versions for actual computation; we can't subtract
12930 // pointers in general.
12931 const SCEV *OrigStart = Start;
12932 const SCEV *OrigRHS = RHS;
12933 if (Start->getType()->isPointerTy()) {
12934 Start = getLosslessPtrToIntExpr(Start);
12935 if (isa<SCEVCouldNotCompute>(Start))
12936 return Start;
12937 }
12938 if (RHS->getType()->isPointerTy()) {
12939 RHS = getLosslessPtrToIntExpr(RHS);
12940 if (isa<SCEVCouldNotCompute>(RHS))
12941 return RHS;
12942 }
12943
12944 // When the RHS is not invariant, we do not know the end bound of the loop and
12945 // cannot calculate the ExactBECount needed by ExitLimit. However, we can
12946 // calculate the MaxBECount, given the start, stride and max value for the end
12947 // bound of the loop (RHS), and the fact that IV does not overflow (which is
12948 // checked above).
12949 if (!isLoopInvariant(RHS, L)) {
12950 const SCEV *MaxBECount = computeMaxBECountForLT(
12951 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
12952 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
12953 MaxBECount, false /*MaxOrZero*/, Predicates);
12954 }
12955
12956 // We use the expression (max(End,Start)-Start)/Stride to describe the
12957 // backedge count: if the backedge is taken at least once, max(End,Start)
12958 // is End and the result is as above; if not, max(End,Start) is Start,
12959 // so we get a backedge count of zero.
12960 const SCEV *BECount = nullptr;
12961 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
12962 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
12963 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
12964 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
12965 // Can we prove max(RHS,Start) > Start - Stride?
12966 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
12967 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
12968 // In this case, we can use a refined formula for computing backedge taken
12969 // count. The general formula remains:
12970 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
12971 // We want to use the alternate formula:
12972 // "((End - 1) - (Start - Stride)) /u Stride"
12973 // Let's do a quick case analysis to show these are equivalent under
12974 // our precondition that max(RHS,Start) > Start - Stride.
12975 // * For RHS <= Start, the backedge-taken count must be zero.
12976 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12977 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
12978 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
12979 // of Stride. For 0 stride, we've use umin(1,Stride) above, reducing
12980 // this to the stride of 1 case.
12981 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil Stride".
12982 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
12983 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
12984 // "((RHS - (Start - Stride) - 1) /u Stride".
12985 // Our preconditions trivially imply no overflow in that form.
12986 const SCEV *MinusOne = getMinusOne(Stride->getType());
12987 const SCEV *Numerator =
12988 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
12989 BECount = getUDivExpr(Numerator, Stride);
12990 }
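  // Illustrative example (hypothetical values): Start = 0, Stride = 4,
  // RHS = 10 gives ((10 - 1) - (0 - 4)) /u 4 = 13 /u 4 = 3, matching
  // ceil((10 - 0) / 4) = 3 from the general formula.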
12991
12992 const SCEV *BECountIfBackedgeTaken = nullptr;
12993 if (!BECount) {
12994 auto canProveRHSGreaterThanEqualStart = [&]() {
12995 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
12996 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
12997 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
12998
12999 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13000 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13001 return true;
13002
13003 // (RHS > Start - 1) implies RHS >= Start.
13004 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13005 // "Start - 1" doesn't overflow.
13006 // * For signed comparison, if Start - 1 does overflow, it's equal
13007 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13008 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13009 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13010 //
13011 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13012 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13013 auto *StartMinusOne = getAddExpr(OrigStart,
13014 getMinusOne(OrigStart->getType()));
13015 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13016 };
13017
13018 // If we know that RHS >= Start in the context of loop, then we know that
13019 // max(RHS, Start) = RHS at this point.
13020 const SCEV *End;
13021 if (canProveRHSGreaterThanEqualStart()) {
13022 End = RHS;
13023 } else {
13024 // If RHS < Start, the backedge will be taken zero times. So in
13025 // general, we can write the backedge-taken count as:
13026 //
13027 // RHS >= Start ? ceil((RHS - Start) / Stride) : 0
13028 //
13029 // We convert it to the following to make it more convenient for SCEV:
13030 //
13031 // ceil((max(RHS, Start) - Start) / Stride)
13032 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13033
13034 // See what would happen if we assume the backedge is taken. This is
13035 // used to compute MaxBECount.
13036 BECountIfBackedgeTaken = getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13037 }
13038
13039 // At this point, we know:
13040 //
13041 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13042 // 2. The index variable doesn't overflow.
13043 //
13044 // Therefore, we know N exists such that
13045 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13046 // doesn't overflow.
13047 //
13048 // Using this information, try to prove whether the addition in
13049 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13050 const SCEV *One = getOne(Stride->getType());
13051 bool MayAddOverflow = [&] {
13052 if (auto *StrideC = dyn_cast<SCEVConstant>(Stride)) {
13053 if (StrideC->getAPInt().isPowerOf2()) {
13054 // Suppose Stride is a power of two, and Start/End are unsigned
13055 // integers. Let UMAX be the largest representable unsigned
13056 // integer.
13057 //
13058 // By the preconditions of this function, we know
13059 // "(Start + Stride * N) >= End", and this doesn't overflow.
13060 // As a formula:
13061 //
13062 // End <= (Start + Stride * N) <= UMAX
13063 //
13064 // Subtracting Start from all the terms:
13065 //
13066 // End - Start <= Stride * N <= UMAX - Start
13067 //
13068 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13069 //
13070 // End - Start <= Stride * N <= UMAX
13071 //
13072 // Stride * N is a multiple of Stride. Therefore,
13073 //
13074 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13075 //
13076 // Since Stride is a power of two, UMAX + 1 is divisible by Stride.
13077 // Therefore, UMAX mod Stride == Stride - 1. So we can write:
13078 //
13079 // End - Start <= Stride * N <= UMAX - Stride + 1
13080 //
13081 // Dropping the middle term:
13082 //
13083 // End - Start <= UMAX - Stride + 1
13084 //
13085 // Adding Stride - 1 to both sides:
13086 //
13087 // (End - Start) + (Stride - 1) <= UMAX
13088 //
13089 // In other words, the addition doesn't have unsigned overflow.
13090 //
13091 // A similar proof works if we treat Start/End as signed values.
13092 // Just rewrite steps before "End - Start <= Stride * N <= UMAX" to
13093 // use signed max instead of unsigned max. Note that we're trying
13094 // to prove a lack of unsigned overflow in either case.
13095 return false;
13096 }
13097 }
13098 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13099 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End - 1.
13100 // If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1 <u End.
13101 // If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End - 1 <s End.
13102 //
13103 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 == End.
13104 return false;
13105 }
13106 return true;
13107 }();
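  // Illustrative example (hypothetical i8 values): with Stride = 4, any
  // multiple of 4 that fits in i8 is at most 252, so End - Start <= 252 and
  // (End - Start) + 3 <= 255; the unsigned addition cannot wrap.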
13108
13109 const SCEV *Delta = getMinusSCEV(End, Start);
13110 if (!MayAddOverflow) {
13111 // floor((D + (S - 1)) / S)
13112 // We prefer this formulation if it's legal because it's fewer operations.
13113 BECount =
13114 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13115 } else {
13116 BECount = getUDivCeilSCEV(Delta, Stride);
13117 }
13118 }
13119
13120 const SCEV *ConstantMaxBECount;
13121 bool MaxOrZero = false;
13122 if (isa<SCEVConstant>(BECount)) {
13123 ConstantMaxBECount = BECount;
13124 } else if (BECountIfBackedgeTaken &&
13125 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13126 // If we know exactly how many times the backedge will be taken if it's
13127 // taken at least once, then the backedge count will either be that or
13128 // zero.
13129 ConstantMaxBECount = BECountIfBackedgeTaken;
13130 MaxOrZero = true;
13131 } else {
13132 ConstantMaxBECount = computeMaxBECountForLT(
13133 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13134 }
13135
13136 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13137 !isa<SCEVCouldNotCompute>(BECount))
13138 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13139
13140 const SCEV *SymbolicMaxBECount =
13141 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13142 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13143 Predicates);
13144}
13145
13146ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13147 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13148 bool ControlsOnlyExit, bool AllowPredicates) {
13149 SmallVector<const SCEVPredicate *> Predicates;
13150 // We handle only IV > Invariant
13151 if (!isLoopInvariant(RHS, L))
13152 return getCouldNotCompute();
13153
13154 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13155 if (!IV && AllowPredicates)
13156 // Try to make this an AddRec using runtime tests, in the first X
13157 // iterations of this loop, where X is the SCEV expression found by the
13158 // algorithm below.
13159 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13160
13161 // Avoid weird loops
13162 if (!IV || IV->getLoop() != L || !IV->isAffine())
13163 return getCouldNotCompute();
13164
13165 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13166 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13167 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13168
13169 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13170
13171 // Avoid negative or zero stride values
13172 if (!isKnownPositive(Stride))
13173 return getCouldNotCompute();
13174
13175 // Avoid proven overflow cases: this will ensure that the backedge taken count
13176 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13177 // exploit NoWrapFlags, allowing us to optimize in the presence of undefined
13178 // behavior, as in the C language.
13179 if (!Stride->isOne() && !NoWrap)
13180 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13181 return getCouldNotCompute();
13182
13183 const SCEV *Start = IV->getStart();
13184 const SCEV *End = RHS;
13185 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13186 // If we know that Start >= RHS in the context of loop, then we know that
13187 // min(RHS, Start) = RHS at this point.
13188 if (isLoopEntryGuardedByCond(
13189 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13190 End = RHS;
13191 else
13192 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13193 }
13194
13195 if (Start->getType()->isPointerTy()) {
13196 Start = getLosslessPtrToIntExpr(Start);
13197 if (isa<SCEVCouldNotCompute>(Start))
13198 return Start;
13199 }
13200 if (End->getType()->isPointerTy()) {
13201 End = getLosslessPtrToIntExpr(End);
13202 if (isa<SCEVCouldNotCompute>(End))
13203 return End;
13204 }
13205
13206 // Compute ((Start - End) + (Stride - 1)) / Stride.
13207 // FIXME: This can overflow. Holding off on fixing this for now;
13208 // howManyGreaterThans will hopefully be gone soon.
13209 const SCEV *One = getOne(Stride->getType());
13210 const SCEV *BECount = getUDivExpr(
13211 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13212
13213 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13214 : getUnsignedRangeMax(Start);
13215
13216 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13217 : getUnsignedRangeMin(Stride);
13218
13219 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13220 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13221 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13222
13223 // Although End can be a MIN expression we estimate MinEnd considering only
13224 // the case End = RHS. This is safe because in the other case (Start - End)
13225 // is zero, leading to a zero maximum backedge taken count.
13226 APInt MinEnd =
13227 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13228 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13229
13230 const SCEV *ConstantMaxBECount =
13231 isa<SCEVConstant>(BECount)
13232 ? BECount
13233 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13234 getConstant(MinStride));
13235
13236 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13237 ConstantMaxBECount = BECount;
13238 const SCEV *SymbolicMaxBECount =
13239 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13240
13241 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13242 Predicates);
13243}
13244
13245const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13246 ScalarEvolution &SE) const {
13247 if (Range.isFullSet()) // Infinite loop.
13248 return SE.getCouldNotCompute();
13249
13250 // If the start is a non-zero constant, shift the range to simplify things.
13251 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13252 if (!SC->getValue()->isZero()) {
13253 SmallVector<const SCEV *, 4> Operands(operands());
13254 Operands[0] = SE.getZero(SC->getType());
13255 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13256 getNoWrapFlags(FlagNW));
13257 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13258 return ShiftedAddRec->getNumIterationsInRange(
13259 Range.subtract(SC->getAPInt()), SE);
13260 // This is strange and shouldn't happen.
13261 return SE.getCouldNotCompute();
13262 }
13263
13264 // The only time we can solve this is when we have all constant indices.
13265 // Otherwise, we cannot determine the overflow conditions.
13266 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13267 return SE.getCouldNotCompute();
13268
13269 // Okay at this point we know that all elements of the chrec are constants and
13270 // that the start element is zero.
13271
13272 // First check to see if the range contains zero. If not, the first
13273 // iteration exits.
13274 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13275 if (!Range.contains(APInt(BitWidth, 0)))
13276 return SE.getZero(getType());
13277
13278 if (isAffine()) {
13279 // If this is an affine expression then we have this situation:
13280 // Solve {0,+,A} in Range === Ax in Range
13281
13282 // We know that zero is in the range. If A is positive then we know that
13283 // the upper value of the range must be the first possible exit value.
13284 // If A is negative then the lower of the range is the last possible loop
13285 // value. Also note that we already checked for a full range.
13286 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13287 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13288
13289 // The exit value should be (End+A)/A.
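    // Illustrative example (hypothetical values): for {0,+,3} and the range
    // [0,10), A = 3 and End = 9, so ExitVal = (9 + 3) / 3 = 4; the value at
    // iteration 4 is 12 (outside the range) while iteration 3 yields 9, so 4
    // is the returned iteration count.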
13290 APInt ExitVal = (End + A).udiv(A);
13291 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13292
13293 // Evaluate at the exit value. If we really did fall out of the valid
13294 // range, then we computed our trip count, otherwise wrap around or other
13295 // things must have happened.
13296 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13297 if (Range.contains(Val->getValue()))
13298 return SE.getCouldNotCompute(); // Something strange happened
13299
13300 // Ensure that the previous value is in the range.
13301 assert(Range.contains(
13302 EvaluateConstantChrecAtConstant(this,
13303 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13304 "Linear scev computation is off in a bad way!");
13305 return SE.getConstant(ExitValue);
13306 }
13307
13308 if (isQuadratic()) {
13309 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13310 return SE.getConstant(*S);
13311 }
13312
13313 return SE.getCouldNotCompute();
13314}
13315
13316const SCEVAddRecExpr *
13317SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13318 assert(getNumOperands() > 1 && "AddRec with zero step?");
13319 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13320 // but in this case we cannot guarantee that the value returned will be an
13321 // AddRec because SCEV does not have a fixed point where it stops
13322 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13323 // may happen if we reach the arithmetic depth limit while simplifying. So we
13324 // construct the returned value explicitly.
13325 SmallVector<const SCEV *, 3> Ops;
13326 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13327 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13328 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13329 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13330 // We know that the last operand is not a constant zero (otherwise it would
13331 // have been popped out earlier). This guarantees us that if the result has
13332 // the same last operand, then it will also not be popped out, meaning that
13333 // the returned value will be an AddRec.
13334 const SCEV *Last = getOperand(getNumOperands() - 1);
13335 assert(!Last->isZero() && "Recurrency with zero step?");
13336 Ops.push_back(Last);
13337 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13338 SCEV::FlagAnyWrap));
13339}
13340
13341// Return true when S contains at least an undef value.
13342bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13343 return SCEVExprContains(S, [](const SCEV *S) {
13344 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13345 return isa<UndefValue>(SU->getValue());
13346 return false;
13347 });
13348}
13349
13350// Return true when S contains a value that is a nullptr.
13351bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13352 return SCEVExprContains(S, [](const SCEV *S) {
13353 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13354 return SU->getValue() == nullptr;
13355 return false;
13356 });
13357}
13358
13359/// Return the size of an element read or written by Inst.
13360const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13361 Type *Ty;
13362 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13363 Ty = Store->getValueOperand()->getType();
13364 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13365 Ty = Load->getType();
13366 else
13367 return nullptr;
13368
13369 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13370 return getSizeOfExpr(ETy, Ty);
13371}
13372
13373//===----------------------------------------------------------------------===//
13374// SCEVCallbackVH Class Implementation
13375//===----------------------------------------------------------------------===//
13376
13377void ScalarEvolution::SCEVCallbackVH::deleted() {
13378 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13379 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13380 SE->ConstantEvolutionLoopExitValue.erase(PN);
13381 SE->eraseValueFromMap(getValPtr());
13382 // this now dangles!
13383}
13384
13385void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13386 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13387
13388 // Forget all the expressions associated with users of the old value,
13389 // so that future queries will recompute the expressions using the new
13390 // value.
13391 SE->forgetValue(getValPtr());
13392 // this now dangles!
13393}
13394
13395ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13396 : CallbackVH(V), SE(se) {}
13397
13398//===----------------------------------------------------------------------===//
13399// ScalarEvolution Class Implementation
13400//===----------------------------------------------------------------------===//
13401
13402ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13403 AssumptionCache &AC, DominatorTree &DT,
13404 LoopInfo &LI)
13405 : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
13406 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13407 LoopDispositions(64), BlockDispositions(64) {
13408 // To use guards for proving predicates, we need to scan every instruction in
13409 // relevant basic blocks, and not just terminators. Doing this is a waste of
13410 // time if the IR does not actually contain any calls to
13411 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13412 //
13413 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13414 // to _add_ guards to the module when there weren't any before, and wants
13415 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13416 // efficient in lieu of being smart in that rather obscure case.
13417
13418 auto *GuardDecl = F.getParent()->getFunction(
13419 Intrinsic::getName(Intrinsic::experimental_guard));
13420 HasGuards = GuardDecl && !GuardDecl->use_empty();
13421}
13422
13423ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13424 : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
13425 LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13426 ValueExprMap(std::move(Arg.ValueExprMap)),
13427 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13428 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13429 PendingMerges(std::move(Arg.PendingMerges)),
13430 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13431 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13432 PredicatedBackedgeTakenCounts(
13433 std::move(Arg.PredicatedBackedgeTakenCounts)),
13434 BECountUsers(std::move(Arg.BECountUsers)),
13435 ConstantEvolutionLoopExitValue(
13436 std::move(Arg.ConstantEvolutionLoopExitValue)),
13437 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13438 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13439 LoopDispositions(std::move(Arg.LoopDispositions)),
13440 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13441 BlockDispositions(std::move(Arg.BlockDispositions)),
13442 SCEVUsers(std::move(Arg.SCEVUsers)),
13443 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13444 SignedRanges(std::move(Arg.SignedRanges)),
13445 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13446 UniquePreds(std::move(Arg.UniquePreds)),
13447 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13448 LoopUsers(std::move(Arg.LoopUsers)),
13449 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13450 FirstUnknown(Arg.FirstUnknown) {
13451 Arg.FirstUnknown = nullptr;
13452}
13453
13454ScalarEvolution::~ScalarEvolution() {
13455 // Iterate through all the SCEVUnknown instances and call their
13456 // destructors, so that they release their references to their values.
13457 for (SCEVUnknown *U = FirstUnknown; U;) {
13458 SCEVUnknown *Tmp = U;
13459 U = U->Next;
13460 Tmp->~SCEVUnknown();
13461 }
13462 FirstUnknown = nullptr;
13463
13464 ExprValueMap.clear();
13465 ValueExprMap.clear();
13466 HasRecMap.clear();
13467 BackedgeTakenCounts.clear();
13468 PredicatedBackedgeTakenCounts.clear();
13469
13470 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13471 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13472 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13473 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13474 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13475}
13476
13477bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13478 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13479}
13480
13481/// When printing a top-level SCEV for trip counts, it's helpful to include
13482/// a type for constants which are otherwise hard to disambiguate.
13483static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13484 if (isa<SCEVConstant>(S))
13485 OS << *S->getType() << " ";
13486 OS << *S;
13487}
13488
13489static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13490 const Loop *L) {
13491 // Print all inner loops first
13492 for (Loop *I : *L)
13493 PrintLoopInfo(OS, SE, I);
13494
13495 OS << "Loop ";
13496 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13497 OS << ": ";
13498
13499 SmallVector<BasicBlock *, 8> ExitingBlocks;
13500 L->getExitingBlocks(ExitingBlocks);
13501 if (ExitingBlocks.size() != 1)
13502 OS << "<multiple exits> ";
13503
13504 auto *BTC = SE->getBackedgeTakenCount(L);
13505 if (!isa<SCEVCouldNotCompute>(BTC)) {
13506 OS << "backedge-taken count is ";
13508 } else
13509 OS << "Unpredictable backedge-taken count.";
13510 OS << "\n";
13511
13512 if (ExitingBlocks.size() > 1)
13513 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13514 OS << " exit count for " << ExitingBlock->getName() << ": ";
13515 PrintSCEVWithTypeHint(OS, SE->getExitCount(L, ExitingBlock));
13516 OS << "\n";
13517 }
13518
13519 OS << "Loop ";
13520 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13521 OS << ": ";
13522
13523 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13524 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13525 OS << "constant max backedge-taken count is ";
13526 PrintSCEVWithTypeHint(OS, ConstantBTC);
13527 if (SE->isBackedgeTakenCountMaxOrZero(L))
13528 OS << ", actual taken count either this or zero.";
13529 } else {
13530 OS << "Unpredictable constant max backedge-taken count. ";
13531 }
13532
13533 OS << "\n"
13534 "Loop ";
13535 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13536 OS << ": ";
13537
13538 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13539 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13540 OS << "symbolic max backedge-taken count is ";
13541 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13542 if (SE->isBackedgeTakenCountMaxOrZero(L))
13543 OS << ", actual taken count either this or zero.";
13544 } else {
13545 OS << "Unpredictable symbolic max backedge-taken count. ";
13546 }
13547 OS << "\n";
13548
13549 if (ExitingBlocks.size() > 1)
13550 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13551 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13552 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13553 ScalarEvolution::SymbolicMaximum);
13554 PrintSCEVWithTypeHint(OS, ExitBTC);
13555 OS << "\n";
13556 }
13557
13558 SmallVector<const SCEVPredicate *, 4> Preds;
13559 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13560 if (PBT != BTC || !Preds.empty()) {
13561 OS << "Loop ";
13562 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13563 OS << ": ";
13564 if (!isa<SCEVCouldNotCompute>(PBT)) {
13565 OS << "Predicated backedge-taken count is ";
13567 } else
13568 OS << "Unpredictable predicated backedge-taken count.";
13569 OS << "\n";
13570 OS << " Predicates:\n";
13571 for (const auto *P : Preds)
13572 P->print(OS, 4);
13573 }
13574
13575 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
13576 OS << "Loop ";
13577 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13578 OS << ": ";
13579 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
13580 }
13581}
13582
13583namespace llvm {
13584raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::LoopDisposition LD) {
13585 switch (LD) {
13586 case ScalarEvolution::LoopVariant:
13587 OS << "Variant";
13588 break;
13589 case ScalarEvolution::LoopInvariant:
13590 OS << "Invariant";
13591 break;
13592 case ScalarEvolution::LoopComputable:
13593 OS << "Computable";
13594 break;
13595 }
13596 return OS;
13597}
13598
13599raw_ostream &operator<<(raw_ostream &OS, ScalarEvolution::BlockDisposition BD) {
13600 switch (BD) {
13601 case ScalarEvolution::DoesNotDominateBlock:
13602 OS << "DoesNotDominate";
13603 break;
13604 case ScalarEvolution::DominatesBlock:
13605 OS << "Dominates";
13606 break;
13607 case ScalarEvolution::ProperlyDominatesBlock:
13608 OS << "ProperlyDominates";
13609 break;
13610 }
13611 return OS;
13612}
13613}
13614
13615void ScalarEvolution::print(raw_ostream &OS) const {
13616 // ScalarEvolution's implementation of the print method is to print
13617 // out SCEV values of all instructions that are interesting. Doing
13618 // this potentially causes it to create new SCEV objects though,
13619 // which technically conflicts with the const qualifier. This isn't
13620 // observable from outside the class though, so casting away the
13621 // const isn't dangerous.
13622 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
13623
13624 if (ClassifyExpressions) {
13625 OS << "Classifying expressions for: ";
13626 F.printAsOperand(OS, /*PrintType=*/false);
13627 OS << "\n";
13628 for (Instruction &I : instructions(F))
13629 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
13630 OS << I << '\n';
13631 OS << " --> ";
13632 const SCEV *SV = SE.getSCEV(&I);
13633 SV->print(OS);
13634 if (!isa<SCEVCouldNotCompute>(SV)) {
13635 OS << " U: ";
13636 SE.getUnsignedRange(SV).print(OS);
13637 OS << " S: ";
13638 SE.getSignedRange(SV).print(OS);
13639 }
13640
13641 const Loop *L = LI.getLoopFor(I.getParent());
13642
13643 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
13644 if (AtUse != SV) {
13645 OS << " --> ";
13646 AtUse->print(OS);
13647 if (!isa<SCEVCouldNotCompute>(AtUse)) {
13648 OS << " U: ";
13649 SE.getUnsignedRange(AtUse).print(OS);
13650 OS << " S: ";
13651 SE.getSignedRange(AtUse).print(OS);
13652 }
13653 }
13654
13655 if (L) {
13656 OS << "\t\t" "Exits: ";
13657 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
13658 if (!SE.isLoopInvariant(ExitValue, L)) {
13659 OS << "<<Unknown>>";
13660 } else {
13661 OS << *ExitValue;
13662 }
13663
13664 bool First = true;
13665 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
13666 if (First) {
13667 OS << "\t\t" "LoopDispositions: { ";
13668 First = false;
13669 } else {
13670 OS << ", ";
13671 }
13672
13673 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13674 OS << ": " << SE.getLoopDisposition(SV, Iter);
13675 }
13676
13677 for (const auto *InnerL : depth_first(L)) {
13678 if (InnerL == L)
13679 continue;
13680 if (First) {
13681 OS << "\t\t" "LoopDispositions: { ";
13682 First = false;
13683 } else {
13684 OS << ", ";
13685 }
13686
13687 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13688 OS << ": " << SE.getLoopDisposition(SV, InnerL);
13689 }
13690
13691 OS << " }";
13692 }
13693
13694 OS << "\n";
13695 }
13696 }
13697
13698 OS << "Determining loop execution counts for: ";
13699 F.printAsOperand(OS, /*PrintType=*/false);
13700 OS << "\n";
13701 for (Loop *I : LI)
13702 PrintLoopInfo(OS, &SE, I);
13703}
13704
13705ScalarEvolution::LoopDisposition
13706ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
13707 auto &Values = LoopDispositions[S];
13708 for (auto &V : Values) {
13709 if (V.getPointer() == L)
13710 return V.getInt();
13711 }
13712 Values.emplace_back(L, LoopVariant);
13713 LoopDisposition D = computeLoopDisposition(S, L);
13714 auto &Values2 = LoopDispositions[S];
13715 for (auto &V : llvm::reverse(Values2)) {
13716 if (V.getPointer() == L) {
13717 V.setInt(D);
13718 break;
13719 }
13720 }
13721 return D;
13722}
13723
13724ScalarEvolution::LoopDisposition
13725ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
13726 switch (S->getSCEVType()) {
13727 case scConstant:
13728 case scVScale:
13729 return LoopInvariant;
13730 case scAddRecExpr: {
13731 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13732
13733 // If L is the addrec's loop, it's computable.
13734 if (AR->getLoop() == L)
13735 return LoopComputable;
13736
13737 // Add recurrences are never invariant in the function-body (null loop).
13738 if (!L)
13739 return LoopVariant;
13740
13741 // Everything that is not defined at loop entry is variant.
13742 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
13743 return LoopVariant;
13744 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
13745 " dominate the contained loop's header?");
13746
13747 // This recurrence is invariant w.r.t. L if AR's loop contains L.
13748 if (AR->getLoop()->contains(L))
13749 return LoopInvariant;
13750
13751 // This recurrence is variant w.r.t. L if any of its operands
13752 // are variant.
13753 for (const auto *Op : AR->operands())
13754 if (!isLoopInvariant(Op, L))
13755 return LoopVariant;
13756
13757 // Otherwise it's loop-invariant.
13758 return LoopInvariant;
13759 }
13760 case scTruncate:
13761 case scZeroExtend:
13762 case scSignExtend:
13763 case scPtrToInt:
13764 case scAddExpr:
13765 case scMulExpr:
13766 case scUDivExpr:
13767 case scUMaxExpr:
13768 case scSMaxExpr:
13769 case scUMinExpr:
13770 case scSMinExpr:
13771 case scSequentialUMinExpr: {
13772 bool HasVarying = false;
13773 for (const auto *Op : S->operands()) {
13774 LoopDisposition D = getLoopDisposition(Op, L);
13775 if (D == LoopVariant)
13776 return LoopVariant;
13777 if (D == LoopComputable)
13778 HasVarying = true;
13779 }
13780 return HasVarying ? LoopComputable : LoopInvariant;
13781 }
13782 case scUnknown:
13783 // All non-instruction values are loop invariant. All instructions are loop
13784 // invariant if they are not contained in the specified loop.
13785 // Instructions are never considered invariant in the function body
13786 // (null loop) because they are defined within the "loop".
13787 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
13788 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
13789 return LoopInvariant;
13790 case scCouldNotCompute:
13791 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13792 }
13793 llvm_unreachable("Unknown SCEV kind!");
13794}
13795
13796bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
13797 return getLoopDisposition(S, L) == LoopInvariant;
13798}
13799
13800bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
13801 return getLoopDisposition(S, L) == LoopComputable;
13802}
13803
13804ScalarEvolution::BlockDisposition
13805ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13806 auto &Values = BlockDispositions[S];
13807 for (auto &V : Values) {
13808 if (V.getPointer() == BB)
13809 return V.getInt();
13810 }
13811 Values.emplace_back(BB, DoesNotDominateBlock);
13812 BlockDisposition D = computeBlockDisposition(S, BB);
13813 auto &Values2 = BlockDispositions[S];
13814 for (auto &V : llvm::reverse(Values2)) {
13815 if (V.getPointer() == BB) {
13816 V.setInt(D);
13817 break;
13818 }
13819 }
13820 return D;
13821}
13822
13823ScalarEvolution::BlockDisposition
13824ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
13825 switch (S->getSCEVType()) {
13826 case scConstant:
13827 case scVScale:
13828 return ProperlyDominatesBlock;
13829 case scAddRecExpr: {
13830 // This uses a "dominates" query instead of "properly dominates" query
13831 // to test for proper dominance too, because the instruction which
13832 // produces the addrec's value is a PHI, and a PHI effectively properly
13833 // dominates its entire containing block.
13834 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
13835 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
13836 return DoesNotDominateBlock;
13837
13838 // Fall through into SCEVNAryExpr handling.
13839 [[fallthrough]];
13840 }
13841 case scTruncate:
13842 case scZeroExtend:
13843 case scSignExtend:
13844 case scPtrToInt:
13845 case scAddExpr:
13846 case scMulExpr:
13847 case scUDivExpr:
13848 case scUMaxExpr:
13849 case scSMaxExpr:
13850 case scUMinExpr:
13851 case scSMinExpr:
13852 case scSequentialUMinExpr: {
13853 bool Proper = true;
13854 for (const SCEV *NAryOp : S->operands()) {
13855 BlockDisposition D = getBlockDisposition(NAryOp, BB);
13856 if (D == DoesNotDominateBlock)
13857 return DoesNotDominateBlock;
13858 if (D == DominatesBlock)
13859 Proper = false;
13860 }
13861 return Proper ? ProperlyDominatesBlock : DominatesBlock;
13862 }
13863 case scUnknown:
13864 if (Instruction *I =
13865 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
13866 if (I->getParent() == BB)
13867 return DominatesBlock;
13868 if (DT.properlyDominates(I->getParent(), BB))
13869 return ProperlyDominatesBlock;
13870 return DoesNotDominateBlock;
13871 }
13872 return ProperlyDominatesBlock;
13873 case scCouldNotCompute:
13874 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
13875 }
13876 llvm_unreachable("Unknown SCEV kind!");
13877}
13878
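// The BlockDisposition enumerators are ordered (DoesNotDominateBlock <
// DominatesBlock < ProperlyDominatesBlock), so the >= comparison in
// dominates() below is sufficient.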
13879bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
13880 return getBlockDisposition(S, BB) >= DominatesBlock;
13881}
13882
13883bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
13884 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
13885}
13886
13887bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
13888 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
13889}
13890
13891void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
13892 bool Predicated) {
13893 auto &BECounts =
13894 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
13895 auto It = BECounts.find(L);
13896 if (It != BECounts.end()) {
13897 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
13898 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
13899 if (!isa<SCEVConstant>(S)) {
13900 auto UserIt = BECountUsers.find(S);
13901 assert(UserIt != BECountUsers.end());
13902 UserIt->second.erase({L, Predicated});
13903 }
13904 }
13905 }
13906 BECounts.erase(It);
13907 }
13908}
13909
13910void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
13911 SmallPtrSet<const SCEV *, 8> ToForget(SCEVs.begin(), SCEVs.end());
13912 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
13913
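// Transitively extend the set: any expression that uses a forgotten SCEV
// must be forgotten as well, so walk the SCEVUsers graph from the initial set.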
13914 while (!Worklist.empty()) {
13915 const SCEV *Curr = Worklist.pop_back_val();
13916 auto Users = SCEVUsers.find(Curr);
13917 if (Users != SCEVUsers.end())
13918 for (const auto *User : Users->second)
13919 if (ToForget.insert(User).second)
13920 Worklist.push_back(User);
13921 }
13922
13923 for (const auto *S : ToForget)
13924 forgetMemoizedResultsImpl(S);
13925
13926 for (auto I = PredicatedSCEVRewrites.begin();
13927 I != PredicatedSCEVRewrites.end();) {
13928 std::pair<const SCEV *, const Loop *> Entry = I->first;
13929 if (ToForget.count(Entry.first))
13930 PredicatedSCEVRewrites.erase(I++);
13931 else
13932 ++I;
13933 }
13934}
13935
13936void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
13937 LoopDispositions.erase(S);
13938 BlockDispositions.erase(S);
13939 UnsignedRanges.erase(S);
13940 SignedRanges.erase(S);
13941 HasRecMap.erase(S);
13942 ConstantMultipleCache.erase(S);
13943
13944 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
13945 UnsignedWrapViaInductionTried.erase(AR);
13946 SignedWrapViaInductionTried.erase(AR);
13947 }
13948
13949 auto ExprIt = ExprValueMap.find(S);
13950 if (ExprIt != ExprValueMap.end()) {
13951 for (Value *V : ExprIt->second) {
13952 auto ValueIt = ValueExprMap.find_as(V);
13953 if (ValueIt != ValueExprMap.end())
13954 ValueExprMap.erase(ValueIt);
13955 }
13956 ExprValueMap.erase(ExprIt);
13957 }
13958
13959 auto ScopeIt = ValuesAtScopes.find(S);
13960 if (ScopeIt != ValuesAtScopes.end()) {
13961 for (const auto &Pair : ScopeIt->second)
13962 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
13963 llvm::erase(ValuesAtScopesUsers[Pair.second],
13964 std::make_pair(Pair.first, S));
13965 ValuesAtScopes.erase(ScopeIt);
13966 }
13967
13968 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
13969 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
13970 for (const auto &Pair : ScopeUserIt->second)
13971 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
13972 ValuesAtScopesUsers.erase(ScopeUserIt);
13973 }
13974
13975 auto BEUsersIt = BECountUsers.find(S);
13976 if (BEUsersIt != BECountUsers.end()) {
13977 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
13978 auto Copy = BEUsersIt->second;
13979 for (const auto &Pair : Copy)
13980 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
13981 BECountUsers.erase(BEUsersIt);
13982 }
13983
13984 auto FoldUser = FoldCacheUser.find(S);
13985 if (FoldUser != FoldCacheUser.end())
13986 for (auto &KV : FoldUser->second)
13987 FoldCache.erase(KV);
13988 FoldCacheUser.erase(S);
13989}
13990
13991void
13992ScalarEvolution::getUsedLoops(const SCEV *S,
13993 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
13994 struct FindUsedLoops {
13995 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
13996 : LoopsUsed(LoopsUsed) {}
13997 SmallPtrSetImpl<const Loop *> &LoopsUsed;
13998 bool follow(const SCEV *S) {
13999 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14000 LoopsUsed.insert(AR->getLoop());
14001 return true;
14002 }
14003
14004 bool isDone() const { return false; }
14005 };
14006
14007 FindUsedLoops F(LoopsUsed);
14008 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14009}
14010
14011void ScalarEvolution::getReachableBlocks(
14012 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14013 SmallVector<BasicBlock *> Worklist;
14014 Worklist.push_back(&F.getEntryBlock());
14015 while (!Worklist.empty()) {
14016 BasicBlock *BB = Worklist.pop_back_val();
14017 if (!Reachable.insert(BB).second)
14018 continue;
14019
14020 Value *Cond;
14021 BasicBlock *TrueBB, *FalseBB;
14022 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14023 m_BasicBlock(FalseBB)))) {
14024 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14025 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14026 continue;
14027 }
14028
14029 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14030 const SCEV *L = getSCEV(Cmp->getOperand(0));
14031 const SCEV *R = getSCEV(Cmp->getOperand(1));
14032 if (isKnownPredicateViaConstantRanges(Cmp->getPredicate(), L, R)) {
14033 Worklist.push_back(TrueBB);
14034 continue;
14035 }
14036 if (isKnownPredicateViaConstantRanges(Cmp->getInversePredicate(), L,
14037 R)) {
14038 Worklist.push_back(FalseBB);
14039 continue;
14040 }
14041 }
14042 }
14043
14044 append_range(Worklist, successors(BB));
14045 }
14046}
14047
14048void ScalarEvolution::verify() const {
14049 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14050 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14051
14052 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14053
14054 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14055 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14056 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14057
14058 const SCEV *visitConstant(const SCEVConstant *Constant) {
14059 return SE.getConstant(Constant->getAPInt());
14060 }
14061
14062 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14063 return SE.getUnknown(Expr->getValue());
14064 }
14065
14066 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14067 return SE.getCouldNotCompute();
14068 }
14069 };
14070
14071 SCEVMapper SCM(SE2);
14072 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14073 SE2.getReachableBlocks(ReachableBlocks, F);
14074
14075 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14076 if (containsUndefs(Old) || containsUndefs(New)) {
14077 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14078 // not propagate undef aggressively). This means we can (and do) fail
14079 // verification in cases where a transform makes a value go from "undef"
14080 // to "undef+1" (say). The transform is fine, since in both cases the
14081 // result is "undef", but SCEV thinks the value increased by 1.
14082 return nullptr;
14083 }
14084
14085 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14086 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14087 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14088 return nullptr;
14089
14090 return Delta;
14091 };
14092
14093 while (!LoopStack.empty()) {
14094 auto *L = LoopStack.pop_back_val();
14095 llvm::append_range(LoopStack, *L);
14096
14097 // Only verify BECounts in reachable loops. For an unreachable loop,
14098 // any BECount is legal.
14099 if (!ReachableBlocks.contains(L->getHeader()))
14100 continue;
14101
14102 // Only verify cached BECounts. Computing new BECounts may change the
14103 // results of subsequent SCEV uses.
14104 auto It = BackedgeTakenCounts.find(L);
14105 if (It == BackedgeTakenCounts.end())
14106 continue;
14107
14108 auto *CurBECount =
14109 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14110 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14111
14112 if (CurBECount == SE2.getCouldNotCompute() ||
14113 NewBECount == SE2.getCouldNotCompute()) {
14114 // NB! This situation is legal, but is very suspicious -- whatever pass
14115 // changed the loop to make a trip count go from could not compute to
14116 // computable or vice-versa *should have* invalidated SCEV. However, we
14117 // choose not to assert here (for now) since we don't want false
14118 // positives.
14119 continue;
14120 }
14121
14122 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14123 SE.getTypeSizeInBits(NewBECount->getType()))
14124 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14125 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14126 SE.getTypeSizeInBits(NewBECount->getType()))
14127 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14128
14129 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14130 if (Delta && !Delta->isZero()) {
14131 dbgs() << "Trip Count for " << *L << " Changed!\n";
14132 dbgs() << "Old: " << *CurBECount << "\n";
14133 dbgs() << "New: " << *NewBECount << "\n";
14134 dbgs() << "Delta: " << *Delta << "\n";
14135 std::abort();
14136 }
14137 }
14138
14139 // Collect all valid loops currently in LoopInfo.
14140 SmallPtrSet<Loop *, 32> ValidLoops;
14141 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14142 while (!Worklist.empty()) {
14143 Loop *L = Worklist.pop_back_val();
14144 if (ValidLoops.insert(L).second)
14145 Worklist.append(L->begin(), L->end());
14146 }
14147 for (const auto &KV : ValueExprMap) {
14148#ifndef NDEBUG
14149 // Check for SCEV expressions referencing invalid/deleted loops.
14150 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14151 assert(ValidLoops.contains(AR->getLoop()) &&
14152 "AddRec references invalid loop");
14153 }
14154#endif
14155
14156 // Check that the value is also part of the reverse map.
14157 auto It = ExprValueMap.find(KV.second);
14158 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14159 dbgs() << "Value " << *KV.first
14160 << " is in ValueExprMap but not in ExprValueMap\n";
14161 std::abort();
14162 }
14163
14164 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14165 if (!ReachableBlocks.contains(I->getParent()))
14166 continue;
14167 const SCEV *OldSCEV = SCM.visit(KV.second);
14168 const SCEV *NewSCEV = SE2.getSCEV(I);
14169 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14170 if (Delta && !Delta->isZero()) {
14171 dbgs() << "SCEV for value " << *I << " changed!\n"
14172 << "Old: " << *OldSCEV << "\n"
14173 << "New: " << *NewSCEV << "\n"
14174 << "Delta: " << *Delta << "\n";
14175 std::abort();
14176 }
14177 }
14178 }
14179
14180 for (const auto &KV : ExprValueMap) {
14181 for (Value *V : KV.second) {
14182 auto It = ValueExprMap.find_as(V);
14183 if (It == ValueExprMap.end()) {
14184 dbgs() << "Value " << *V
14185 << " is in ExprValueMap but not in ValueExprMap\n";
14186 std::abort();
14187 }
14188 if (It->second != KV.first) {
14189 dbgs() << "Value " << *V << " mapped to " << *It->second
14190 << " rather than " << *KV.first << "\n";
14191 std::abort();
14192 }
14193 }
14194 }
14195
14196 // Verify integrity of SCEV users.
14197 for (const auto &S : UniqueSCEVs) {
14198 for (const auto *Op : S.operands()) {
14199 // We do not store dependencies of constants.
14200 if (isa<SCEVConstant>(Op))
14201 continue;
14202 auto It = SCEVUsers.find(Op);
14203 if (It != SCEVUsers.end() && It->second.count(&S))
14204 continue;
14205 dbgs() << "Use of operand " << *Op << " by user " << S
14206 << " is not being tracked!\n";
14207 std::abort();
14208 }
14209 }
14210
14211 // Verify integrity of ValuesAtScopes users.
14212 for (const auto &ValueAndVec : ValuesAtScopes) {
14213 const SCEV *Value = ValueAndVec.first;
14214 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14215 const Loop *L = LoopAndValueAtScope.first;
14216 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14217 if (!isa<SCEVConstant>(ValueAtScope)) {
14218 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14219 if (It != ValuesAtScopesUsers.end() &&
14220 is_contained(It->second, std::make_pair(L, Value)))
14221 continue;
14222 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14223 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14224 std::abort();
14225 }
14226 }
14227 }
14228
14229 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14230 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14231 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14232 const Loop *L = LoopAndValue.first;
14233 const SCEV *Value = LoopAndValue.second;
14234 assert(!isa<SCEVConstant>(Value));
14235 auto It = ValuesAtScopes.find(Value);
14236 if (It != ValuesAtScopes.end() &&
14237 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14238 continue;
14239 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14240 << *ValueAtScope << " missing in ValuesAtScopes\n";
14241 std::abort();
14242 }
14243 }
14244
14245 // Verify integrity of BECountUsers.
14246 auto VerifyBECountUsers = [&](bool Predicated) {
14247 auto &BECounts =
14248 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14249 for (const auto &LoopAndBEInfo : BECounts) {
14250 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14251 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14252 if (!isa<SCEVConstant>(S)) {
14253 auto UserIt = BECountUsers.find(S);
14254 if (UserIt != BECountUsers.end() &&
14255 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14256 continue;
14257 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14258 << " missing from BECountUsers\n";
14259 std::abort();
14260 }
14261 }
14262 }
14263 }
14264 };
14265 VerifyBECountUsers(/* Predicated */ false);
14266 VerifyBECountUsers(/* Predicated */ true);
14267
14268 // Verify integrity of the loop disposition cache.
14269 for (auto &[S, Values] : LoopDispositions) {
14270 for (auto [Loop, CachedDisposition] : Values) {
14271 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14272 if (CachedDisposition != RecomputedDisposition) {
14273 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14274 << " is incorrect: cached " << CachedDisposition << ", actual "
14275 << RecomputedDisposition << "\n";
14276 std::abort();
14277 }
14278 }
14279 }
14280
14281 // Verify integrity of the block disposition cache.
14282 for (auto &[S, Values] : BlockDispositions) {
14283 for (auto [BB, CachedDisposition] : Values) {
14284 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14285 if (CachedDisposition != RecomputedDisposition) {
14286 dbgs() << "Cached disposition of " << *S << " for block %"
14287 << BB->getName() << " is incorrect: cached " << CachedDisposition
14288 << ", actual " << RecomputedDisposition << "\n";
14289 std::abort();
14290 }
14291 }
14292 }
14293
14294 // Verify FoldCache/FoldCacheUser caches.
14295 for (auto [FoldID, Expr] : FoldCache) {
14296 auto I = FoldCacheUser.find(Expr);
14297 if (I == FoldCacheUser.end()) {
14298 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14299 << "!\n";
14300 std::abort();
14301 }
14302 if (!is_contained(I->second, FoldID)) {
14303 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14304 std::abort();
14305 }
14306 }
14307 for (auto [Expr, IDs] : FoldCacheUser) {
14308 for (auto &FoldID : IDs) {
14309 auto I = FoldCache.find(FoldID);
14310 if (I == FoldCache.end()) {
14311 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14312 << "!\n";
14313 std::abort();
14314 }
14315 if (I->second != Expr) {
14316 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: "
14317 << *I->second << " != " << *Expr << "!\n";
14318 std::abort();
14319 }
14320 }
14321 }
14322
14323 // Verify that ConstantMultipleCache computations are correct. We check that
14324 // cached multiples and recomputed multiples are multiples of each other to
14325 // verify correctness. It is possible that a recomputed multiple is different
14326 // from the cached multiple due to strengthened no wrap flags or changes in
14327 // KnownBits computations.
14328 for (auto [S, Multiple] : ConstantMultipleCache) {
14329 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14330 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14331 Multiple.urem(RecomputedMultiple) != 0 &&
14332 RecomputedMultiple.urem(Multiple) != 0)) {
14333 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14334 << *S << " : Computed " << RecomputedMultiple
14335 << " but cache contains " << Multiple << "!\n";
14336 std::abort();
14337 }
14338 }
14339}
14340
14341bool ScalarEvolution::invalidate(
14342 Function &F, const PreservedAnalyses &PA,
14343 FunctionAnalysisManager::Invalidator &Inv) {
14344 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14345 // of its dependencies is invalidated.
14346 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14347 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14348 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14349 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14350 Inv.invalidate<LoopAnalysis>(F, PA);
14351}
14352
14353AnalysisKey ScalarEvolutionAnalysis::Key;
14354
14355ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14356 FunctionAnalysisManager &AM) {
14357 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14358 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14359 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14360 auto &LI = AM.getResult<LoopAnalysis>(F);
14361 return ScalarEvolution(F, TLI, AC, DT, LI);
14362}
14363
14367 return PreservedAnalyses::all();
14368}
14369
14372 // For compatibility with opt's -analyze feature under legacy pass manager
14373 // which was not ported to NPM. This keeps tests using
14374 // update_analyze_test_checks.py working.
14375 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14376 << F.getName() << "':\n";
14377 AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14378 return PreservedAnalyses::all();
14379}
14380
14382 "Scalar Evolution Analysis", false, true)
14388 "Scalar Evolution Analysis", false, true)
14389
14391
14394}
14395
14396bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14397 SE.reset(new ScalarEvolution(
14398 F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14399 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14400 getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14401 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14402 return false;
14403}
14404
14406
14408 SE->print(OS);
14409}
14410
14412 if (!VerifySCEV)
14413 return;
14414
14415 SE->verify();
14416}
14417
14419 AU.setPreservesAll();
14424}
14425
14427 const SCEV *RHS) {
14429}
14430
14431const SCEVPredicate *
14433 const SCEV *LHS, const SCEV *RHS) {
14434 FoldingSetNodeID ID;
14435 assert(LHS->getType() == RHS->getType() &&
14436 "Type mismatch between LHS and RHS");
14437 // Unique this node based on the arguments
14438 ID.AddInteger(SCEVPredicate::P_Compare);
14439 ID.AddInteger(Pred);
14440 ID.AddPointer(LHS);
14441 ID.AddPointer(RHS);
14442 void *IP = nullptr;
14443 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14444 return S;
14445 SCEVComparePredicate *Eq = new (SCEVAllocator)
14446 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14447 UniquePreds.InsertNode(Eq, IP);
14448 return Eq;
14449}
14450
14452 const SCEVAddRecExpr *AR,
14455 // Unique this node based on the arguments
14456 ID.AddInteger(SCEVPredicate::P_Wrap);
14457 ID.AddPointer(AR);
14458 ID.AddInteger(AddedFlags);
14459 void *IP = nullptr;
14460 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14461 return S;
14462 auto *OF = new (SCEVAllocator)
14463 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14464 UniquePreds.InsertNode(OF, IP);
14465 return OF;
14466}
14467
14468namespace {
14469
14470class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14471public:
14472
14473 /// Rewrites \p S in the context of a loop L and the SCEV predication
14474 /// infrastructure.
14475 ///
14476 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14477 /// equivalences present in \p Pred.
14478 ///
14479 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14480 /// \p NewPreds such that the result will be an AddRecExpr.
14481 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14483 const SCEVPredicate *Pred) {
14484 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14485 return Rewriter.visit(S);
14486 }
14487
14488 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14489 if (Pred) {
14490 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14491 for (const auto *Pred : U->getPredicates())
14492 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14493 if (IPred->getLHS() == Expr &&
14494 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14495 return IPred->getRHS();
14496 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14497 if (IPred->getLHS() == Expr &&
14498 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14499 return IPred->getRHS();
14500 }
14501 }
14502 return convertToAddRecWithPreds(Expr);
14503 }
14504
14505 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14506 const SCEV *Operand = visit(Expr->getOperand());
14507 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14508 if (AR && AR->getLoop() == L && AR->isAffine()) {
14509 // This couldn't be folded because the operand didn't have the nuw
14510 // flag. Add the nusw flag as an assumption that we could make.
14511 const SCEV *Step = AR->getStepRecurrence(SE);
14512 Type *Ty = Expr->getType();
14513 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14514 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14515 SE.getSignExtendExpr(Step, Ty), L,
14516 AR->getNoWrapFlags());
14517 }
14518 return SE.getZeroExtendExpr(Operand, Expr->getType());
14519 }
14520
14521 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14522 const SCEV *Operand = visit(Expr->getOperand());
14523 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14524 if (AR && AR->getLoop() == L && AR->isAffine()) {
14525 // This couldn't be folded because the operand didn't have the nsw
14526 // flag. Add the nssw flag as an assumption that we could make.
14527 const SCEV *Step = AR->getStepRecurrence(SE);
14528 Type *Ty = Expr->getType();
14529 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14530 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14531 SE.getSignExtendExpr(Step, Ty), L,
14532 AR->getNoWrapFlags());
14533 }
14534 return SE.getSignExtendExpr(Operand, Expr->getType());
14535 }
14536
14537private:
14538 explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
14540 const SCEVPredicate *Pred)
14541 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14542
14543 bool addOverflowAssumption(const SCEVPredicate *P) {
14544 if (!NewPreds) {
14545 // Check if we've already made this assumption.
14546 return Pred && Pred->implies(P);
14547 }
14548 NewPreds->insert(P);
14549 return true;
14550 }
14551
14552 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14554 auto *A = SE.getWrapPredicate(AR, AddedFlags);
14555 return addOverflowAssumption(A);
14556 }
14557
14558 // If \p Expr represents a PHINode, we try to see if it can be represented
14559 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
14560 // to add this predicate as a runtime overflow check, we return the AddRec.
14561 // If \p Expr does not meet these conditions (is not a PHI node, or we
14562 // couldn't create an AddRec for it, or couldn't add the predicate), we just
14563 // return \p Expr.
14564 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
14565 if (!isa<PHINode>(Expr->getValue()))
14566 return Expr;
14567 std::optional<
14568 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
14569 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
14570 if (!PredicatedRewrite)
14571 return Expr;
14572 for (const auto *P : PredicatedRewrite->second){
14573 // Wrap predicates from outer loops are not supported.
14574 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
14575 if (L != WP->getExpr()->getLoop())
14576 return Expr;
14577 }
14578 if (!addOverflowAssumption(P))
14579 return Expr;
14580 }
14581 return PredicatedRewrite->first;
14582 }
14583
14585 const SCEVPredicate *Pred;
14586 const Loop *L;
14587};
14588
14589} // end anonymous namespace
14590
14591const SCEV *
14592ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
14593 const SCEVPredicate &Preds) {
14594 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
14595}
14596
14598 const SCEV *S, const Loop *L,
14599 SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
14600 SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
14601 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
14602 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
14603
14604 if (!AddRec)
14605 return nullptr;
14606
14607 // Since the transformation was successful, we can now transfer the SCEV
14608 // predicates.
14609 for (const auto *P : TransformPreds)
14610 Preds.insert(P);
14611
14612 return AddRec;
14613}
14614
14615/// SCEV predicates
14617 SCEVPredicateKind Kind)
14618 : FastID(ID), Kind(Kind) {}
14619
14621 const ICmpInst::Predicate Pred,
14622 const SCEV *LHS, const SCEV *RHS)
14623 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
14624 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
14625 assert(LHS != RHS && "LHS and RHS are the same SCEV");
14626}
14627
14629 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
14630
14631 if (!Op)
14632 return false;
14633
14634 if (Pred != ICmpInst::ICMP_EQ)
14635 return false;
14636
14637 return Op->LHS == LHS && Op->RHS == RHS;
14638}
14639
14640bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
14641
14643 if (Pred == ICmpInst::ICMP_EQ)
14644 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
14645 else
14646 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
14647 << *RHS << "\n";
14648
14649}
14650
14652 const SCEVAddRecExpr *AR,
14653 IncrementWrapFlags Flags)
14654 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
14655
14656const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
14657
14659 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
14660
14661 return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
14662}
14663
14665 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
14666 IncrementWrapFlags IFlags = Flags;
14667
14668 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
14669 IFlags = clearFlags(IFlags, IncrementNSSW);
14670
14671 return IFlags == IncrementAnyWrap;
14672}
14673
14675 OS.indent(Depth) << *getExpr() << " Added Flags: ";
14677 OS << "<nusw>";
14679 OS << "<nssw>";
14680 OS << "\n";
14681}
14682
14685 ScalarEvolution &SE) {
14686 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
14687 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
14688
14689 // We can safely transfer the NSW flag as NSSW.
14690 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
14691 ImpliedFlags = IncrementNSSW;
14692
14693 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
14694 // If the increment is positive, the SCEV NUW flag will also imply the
14695 // WrapPredicate NUSW flag.
14696 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
14697 if (Step->getValue()->getValue().isNonNegative())
14698 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
14699 }
14700
14701 return ImpliedFlags;
14702}
14703
14704/// Union predicates don't get cached so create a dummy set ID for it.
14706 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
14707 for (const auto *P : Preds)
14708 add(P);
14709}
14710
14712 return all_of(Preds,
14713 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
14714}
14715
14717 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
14718 return all_of(Set->Preds,
14719 [this](const SCEVPredicate *I) { return this->implies(I); });
14720
14721 return any_of(Preds,
14722 [N](const SCEVPredicate *I) { return I->implies(N); });
14723}
14724
14726 for (const auto *Pred : Preds)
14727 Pred->print(OS, Depth);
14728}
14729
14730void SCEVUnionPredicate::add(const SCEVPredicate *N) {
14731 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
14732 for (const auto *Pred : Set->Preds)
14733 add(Pred);
14734 return;
14735 }
14736
14737 Preds.push_back(N);
14738}
14739
14741 Loop &L)
14742 : SE(SE), L(L) {
14743 SmallVector<const SCEVPredicate *, 4> Empty;
14744 Preds = std::make_unique<SCEVUnionPredicate>(Empty);
14745}
14746
14749 for (const auto *Op : Ops)
14750 // We do not expect that forgetting cached data for SCEVConstants will ever
14751 // open any prospects for sharpening or introduce any correctness issues,
14752 // so we don't bother storing their dependencies.
14753 if (!isa<SCEVConstant>(Op))
14754 SCEVUsers[Op].insert(User);
14755}
14756
14758 const SCEV *Expr = SE.getSCEV(V);
14759 RewriteEntry &Entry = RewriteMap[Expr];
14760
14761 // If we already have an entry and the version matches, return it.
14762 if (Entry.second && Generation == Entry.first)
14763 return Entry.second;
14764
14765 // We found an entry but it's stale. Rewrite the stale entry
14766 // according to the current predicate.
14767 if (Entry.second)
14768 Expr = Entry.second;
14769
14770 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
14771 Entry = {Generation, NewSCEV};
14772
14773 return NewSCEV;
14774}
14775
14777 if (!BackedgeCount) {
14778 SmallVector<const SCEVPredicate *, 4> Preds;
14779 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
14780 for (const auto *P : Preds)
14781 addPredicate(*P);
14782 }
14783 return BackedgeCount;
14784}
14785
14787 if (Preds->implies(&Pred))
14788 return;
14789
14790 auto &OldPreds = Preds->getPredicates();
14791 SmallVector<const SCEVPredicate*, 4> NewPreds(OldPreds.begin(), OldPreds.end());
14792 NewPreds.push_back(&Pred);
14793 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds);
14794 updateGeneration();
14795}
14796
14798 return *Preds;
14799}
14800
14801void PredicatedScalarEvolution::updateGeneration() {
14802 // If the generation number wrapped recompute everything.
14803 if (++Generation == 0) {
14804 for (auto &II : RewriteMap) {
14805 const SCEV *Rewritten = II.second.second;
14806 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
14807 }
14808 }
14809}
14810
14813 const SCEV *Expr = getSCEV(V);
14814 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14815
14816 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
14817
14818 // Clear the statically implied flags.
14819 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
14820 addPredicate(*SE.getWrapPredicate(AR, Flags));
14821
14822 auto II = FlagsMap.insert({V, Flags});
14823 if (!II.second)
14824 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
14825}
14826
14829 const SCEV *Expr = getSCEV(V);
14830 const auto *AR = cast<SCEVAddRecExpr>(Expr);
14831
14832 Flags = SCEVWrapPredicate::clearFlags(
14833 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
14834
14835 auto II = FlagsMap.find(V);
14836
14837 if (II != FlagsMap.end())
14838 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
14839
14840 return Flags == SCEVWrapPredicate::IncrementAnyWrap;
14841}
14842
14844 const SCEV *Expr = this->getSCEV(V);
14845 SmallPtrSet<const SCEVPredicate *, 4> NewPreds;
14846 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
14847
14848 if (!New)
14849 return nullptr;
14850
14851 for (const auto *P : NewPreds)
14852 addPredicate(*P);
14853
14854 RewriteMap[SE.getSCEV(V)] = {Generation, New};
14855 return New;
14856}
14857
14860 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
14861 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates())),
14862 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
14863 for (auto I : Init.FlagsMap)
14864 FlagsMap.insert(I);
14865}
14866
14868 // For each block.
14869 for (auto *BB : L.getBlocks())
14870 for (auto &I : *BB) {
14871 if (!SE.isSCEVable(I.getType()))
14872 continue;
14873
14874 auto *Expr = SE.getSCEV(&I);
14875 auto II = RewriteMap.find(Expr);
14876
14877 if (II == RewriteMap.end())
14878 continue;
14879
14880 // Don't print things that are not interesting.
14881 if (II->second.second == Expr)
14882 continue;
14883
14884 OS.indent(Depth) << "[PSE]" << I << ":\n";
14885 OS.indent(Depth + 2) << *Expr << "\n";
14886 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
14887 }
14888}
14889
14890// Match the mathematical pattern A - (A / B) * B, where A and B can be
14891// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
14892// for URem with constant power-of-2 second operands.
14893// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
14894// 4, A / B becomes X / 8).
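// For instance, (A urem 4) on an i32 A typically reaches SCEV as
// (zext (trunc A to i2) to i32), while a urem by a non-power-of-two B shows
// up as (A + (-1 * (A /u B) * B)) or a similar add-of-multiply form.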
14895bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
14896 const SCEV *&RHS) {
14897 // Try to match 'zext (trunc A to iB) to iY', which is used
14898 // for URem with constant power-of-2 second operands. Make sure the size of
14899 // the operand A matches the size of the whole expressions.
14900 if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
14901 if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
14902 LHS = Trunc->getOperand();
14903 // Bail out if the type of the LHS is larger than the type of the
14904 // expression for now.
14905 if (getTypeSizeInBits(LHS->getType()) >
14906 getTypeSizeInBits(Expr->getType()))
14907 return false;
14908 if (LHS->getType() != Expr->getType())
14909 LHS = getZeroExtendExpr(LHS, Expr->getType());
14910 RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
14911 << getTypeSizeInBits(Trunc->getType()));
14912 return true;
14913 }
14914 const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
14915 if (Add == nullptr || Add->getNumOperands() != 2)
14916 return false;
14917
14918 const SCEV *A = Add->getOperand(1);
14919 const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
14920
14921 if (Mul == nullptr)
14922 return false;
14923
14924 const auto MatchURemWithDivisor = [&](const SCEV *B) {
14925 // (SomeExpr + (-(SomeExpr / B) * B)).
14926 if (Expr == getURemExpr(A, B)) {
14927 LHS = A;
14928 RHS = B;
14929 return true;
14930 }
14931 return false;
14932 };
14933
14934 // (SomeExpr + (-1 * (SomeExpr / B) * B)).
14935 if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
14936 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14937 MatchURemWithDivisor(Mul->getOperand(2));
14938
14939 // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
14940 if (Mul->getNumOperands() == 2)
14941 return MatchURemWithDivisor(Mul->getOperand(1)) ||
14942 MatchURemWithDivisor(Mul->getOperand(0)) ||
14943 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
14944 MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
14945 return false;
14946}
14947
14948const SCEV *
14949ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
14950 SmallVector<BasicBlock*, 16> ExitingBlocks;
14951 L->getExitingBlocks(ExitingBlocks);
14952
14953 // Form an expression for the maximum exit count possible for this loop. We
14954 // merge the max and exact information to approximate a version of
14955 // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
14956 SmallVector<const SCEV*, 4> ExitCounts;
14957 for (BasicBlock *ExitingBB : ExitingBlocks) {
14958 const SCEV *ExitCount =
14960 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
14961 assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
14962 "We should only have known counts for exiting blocks that "
14963 "dominate latch!");
14964 ExitCounts.push_back(ExitCount);
14965 }
14966 }
14967 if (ExitCounts.empty())
14968 return getCouldNotCompute();
14969 return getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
14970}
14971
14972/// A rewriter to replace SCEV expressions in Map with the corresponding entry
14973/// in the map. It skips AddRecExpr because we cannot guarantee that the
14974/// replacement is loop invariant in the loop of the AddRec.
14975class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
14977
14978public:
14981 : SCEVRewriteVisitor(SE), Map(M) {}
14982
14983 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
14984
14985 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14986 auto I = Map.find(Expr);
14987 if (I == Map.end())
14988 return Expr;
14989 return I->second;
14990 }
14991
14993 auto I = Map.find(Expr);
14994 if (I == Map.end()) {
14995 // If we didn't find the exact ZExt expr in the map, check if there's an
14996 // entry for a smaller ZExt we can use instead.
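// For example, if a guard registered a rewrite for (zext i8 %x to i16), a
// later (zext i8 %x to i32) can reuse it by zero-extending the mapped value.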
14997 Type *Ty = Expr->getType();
14998 const SCEV *Op = Expr->getOperand(0);
14999 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
15000 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
15001 Bitwidth > Op->getType()->getScalarSizeInBits()) {
15002 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
15003 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
15004 auto I = Map.find(NarrowExt);
15005 if (I != Map.end())
15006 return SE.getZeroExtendExpr(I->second, Ty);
15007 Bitwidth = Bitwidth / 2;
15008 }
15009
15010 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
15011 Expr);
15012 }
15013 return I->second;
15014 }
15015
15017 auto I = Map.find(Expr);
15018 if (I == Map.end())
15019 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
15020 Expr);
15021 return I->second;
15022 }
15023
15024 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
15025 auto I = Map.find(Expr);
15026 if (I == Map.end())
15027 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
15028 return I->second;
15029 }
15030
15031 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
15032 auto I = Map.find(Expr);
15033 if (I == Map.end())
15034 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
15035 return I->second;
15036 }
15037};
15038
15039const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
15040 SmallVector<const SCEV *> ExprsToRewrite;
15041 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15042 const SCEV *RHS,
15044 &RewriteMap) {
15045 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15046 // replacement SCEV which isn't directly implied by the structure of that
15047 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15048 // legal. See the scoping rules for flags in the header to understand why.
15049
15050 // If LHS is a constant, apply information to the other expression.
15051 if (isa<SCEVConstant>(LHS)) {
15052 std::swap(LHS, RHS);
15053 Predicate = CmpInst::getSwappedPredicate(Predicate);
15054 }
15055
15056 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15057 // create this form when combining two checks of the form (X u< C2 + C1) and
15058 // (X >=u C1).
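// For example, the pair of checks (X u>= 4) and (X u< 20) is combined into
// ((-4 + X) u< 16); from that we recover the range [4, 20) and rewrite X as
// umax(4, umin(X, 19)).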
15059 auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
15060 &ExprsToRewrite]() {
15061 auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
15062 if (!AddExpr || AddExpr->getNumOperands() != 2)
15063 return false;
15064
15065 auto *C1 = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
15066 auto *LHSUnknown = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
15067 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15068 if (!C1 || !C2 || !LHSUnknown)
15069 return false;
15070
15071 auto ExactRegion =
15072 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15073 .sub(C1->getAPInt());
15074
15075 // Bail out, unless we have a non-wrapping, monotonic range.
15076 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15077 return false;
15078 auto I = RewriteMap.find(LHSUnknown);
15079 const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
15080 RewriteMap[LHSUnknown] = getUMaxExpr(
15081 getConstant(ExactRegion.getUnsignedMin()),
15082 getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
15083 ExprsToRewrite.push_back(LHSUnknown);
15084 return true;
15085 };
15086 if (MatchRangeCheckIdiom())
15087 return;
15088
15089 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15090 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15091 // the non-constant operand and in \p LHS the constant operand.
15092 auto IsMinMaxSCEVWithNonNegativeConstant =
15093 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15094 const SCEV *&RHS) {
15095 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15096 if (MinMax->getNumOperands() != 2)
15097 return false;
15098 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15099 if (C->getAPInt().isNegative())
15100 return false;
15101 SCTy = MinMax->getSCEVType();
15102 LHS = MinMax->getOperand(0);
15103 RHS = MinMax->getOperand(1);
15104 return true;
15105 }
15106 }
15107 return false;
15108 };
15109
15110 // Checks whether Expr is a non-negative constant, and Divisor is a positive
15111 // constant, and returns their APInt in ExprVal and in DivisorVal.
15112 auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
15113 APInt &ExprVal, APInt &DivisorVal) {
15114 auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
15115 auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15116 if (!ConstExpr || !ConstDivisor)
15117 return false;
15118 ExprVal = ConstExpr->getAPInt();
15119 DivisorVal = ConstDivisor->getAPInt();
15120 return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
15121 };
15122
15123 // Return a new SCEV that modifies \p Expr to the closest number that is
15124 // divisible by \p Divisor and greater than or equal to Expr.
15125 // For now, only handle constant Expr and Divisor.
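// For example, for Expr = 10 and Divisor = 8 this produces 16; an Expr that
// is already a multiple of Divisor is returned unchanged.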
15126 auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15127 const SCEV *Divisor) {
15128 APInt ExprVal;
15129 APInt DivisorVal;
15130 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15131 return Expr;
15132 APInt Rem = ExprVal.urem(DivisorVal);
15133 if (!Rem.isZero())
15134 // return the SCEV: Expr + Divisor - Expr % Divisor
15135 return getConstant(ExprVal + DivisorVal - Rem);
15136 return Expr;
15137 };
15138
15139 // Return a new SCEV that modifies \p Expr to the closest number that is
15140 // divisible by \p Divisor and less than or equal to Expr.
15141 // For now, only handle constant Expr and Divisor.
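// For example, for Expr = 10 and Divisor = 8 this produces 8.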
15142 auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15143 const SCEV *Divisor) {
15144 APInt ExprVal;
15145 APInt DivisorVal;
15146 if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
15147 return Expr;
15148 APInt Rem = ExprVal.urem(DivisorVal);
15149 // return the SCEV: Expr - Expr % Divisor
15150 return getConstant(ExprVal - Rem);
15151 };
15152
15153 // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15154 // recursively. This is done by aligning up/down the constant value to the
15155 // Divisor.
15156 std::function<const SCEV *(const SCEV *, const SCEV *)>
15157 ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15158 const SCEV *Divisor) {
15159 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15160 SCEVTypes SCTy;
15161 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15162 MinMaxRHS))
15163 return MinMaxExpr;
15164 auto IsMin =
15165 isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15166 assert(isKnownNonNegative(MinMaxLHS) &&
15167 "Expected non-negative operand!");
15168 auto *DivisibleExpr =
15169 IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
15170 : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
15171 SmallVector<const SCEV *> Ops = {
15172 ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15173 return getMinMaxExpr(SCTy, Ops);
15174 };
15175
15176 // If we have LHS == 0, check if LHS is computing a property of some unknown
15177 // SCEV %v which we can rewrite %v to express explicitly.
15178 const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
15179 if (Predicate == CmpInst::ICMP_EQ && RHSC &&
15180 RHSC->getValue()->isNullValue()) {
15181 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15182 // explicitly express that.
15183 const SCEV *URemLHS = nullptr;
15184 const SCEV *URemRHS = nullptr;
15185 if (matchURem(LHS, URemLHS, URemRHS)) {
15186 if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
15187 auto I = RewriteMap.find(LHSUnknown);
15188 const SCEV *RewrittenLHS =
15189 I != RewriteMap.end() ? I->second : LHSUnknown;
15190 RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15191 const auto *Multiple =
15192 getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15193 RewriteMap[LHSUnknown] = Multiple;
15194 ExprsToRewrite.push_back(LHSUnknown);
15195 return;
15196 }
15197 }
15198 }
15199
15200 // Do not apply information for constants or if RHS contains an AddRec.
15201 if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
15202 return;
15203
15204 // If RHS is SCEVUnknown, make sure the information is applied to it.
15205 if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15206 std::swap(LHS, RHS);
15207 Predicate = CmpInst::getSwappedPredicate(Predicate);
15208 }
15209
15210 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15211 // and \p FromRewritten are the same (i.e. there has been no rewrite
15212 // registered for \p From), then puts this value in the list of rewritten
15213 // expressions.
15214 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15215 const SCEV *To) {
15216 if (From == FromRewritten)
15217 ExprsToRewrite.push_back(From);
15218 RewriteMap[From] = To;
15219 };
15220
15221 // Checks whether \p S has already been rewritten. In that case returns the
15222 // existing rewrite because we want to chain further rewrites onto the
15223 // already rewritten value. Otherwise returns \p S.
15224 auto GetMaybeRewritten = [&](const SCEV *S) {
15225 auto I = RewriteMap.find(S);
15226 return I != RewriteMap.end() ? I->second : S;
15227 };
15228
15229 // Check for the SCEV expression (A /u B) * B while B is a constant, inside
15230 // \p Expr. The check is done recursively on \p Expr, which is assumed to
15231 // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
15232 // /u B) * B was found, and return the divisor B in \p DividesBy. For
15233 // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
15234 // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
15235 // DividesBy.
15236 std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
15237 [&](const SCEV *Expr, const SCEV *&DividesBy) {
15238 if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
15239 if (Mul->getNumOperands() != 2)
15240 return false;
15241 auto *MulLHS = Mul->getOperand(0);
15242 auto *MulRHS = Mul->getOperand(1);
15243 if (isa<SCEVConstant>(MulLHS))
15244 std::swap(MulLHS, MulRHS);
15245 if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
15246 if (Div->getOperand(1) == MulRHS) {
15247 DividesBy = MulRHS;
15248 return true;
15249 }
15250 }
15251 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15252 return HasDivisibiltyInfo(MinMax->getOperand(0), DividesBy) ||
15253 HasDivisibiltyInfo(MinMax->getOperand(1), DividesBy);
15254 return false;
15255 };
15256
15257 // Return true if Expr is known to be divisible by \p DividesBy.
15258 std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
15259 [&](const SCEV *Expr, const SCEV *DividesBy) {
15260 if (getURemExpr(Expr, DividesBy)->isZero())
15261 return true;
15262 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
15263 return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
15264 IsKnownToDivideBy(MinMax->getOperand(1), DividesBy);
15265 return false;
15266 };
15267
15268 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15269 const SCEV *DividesBy = nullptr;
15270 if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
15271 // Check that the whole expression is divided by DividesBy
15272 DividesBy =
15273 IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
15274
15275 // Collect rewrites for LHS and its transitive operands based on the
15276 // condition.
15277 // For min/max expressions, also apply the guard to its operands:
15278 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15279 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15280 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15281 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15282
15283 // We cannot express strict predicates in SCEV, so instead we replace them
15284 // with non-strict ones against plus or minus one of RHS depending on the
15285 // predicate.
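// For example, 'x u< C' is handled as 'x u<= (C - 1)' and 'x u> C' as
// 'x u>= (C + 1)', so the min/max based rewrites below remain applicable.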
15286 const SCEV *One = getOne(RHS->getType());
15287 switch (Predicate) {
15288 case CmpInst::ICMP_ULT:
15289 if (RHS->getType()->isPointerTy())
15290 return;
15291 RHS = getUMaxExpr(RHS, One);
15292 [[fallthrough]];
15293 case CmpInst::ICMP_SLT: {
15294 RHS = getMinusSCEV(RHS, One);
15295 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15296 break;
15297 }
15298 case CmpInst::ICMP_UGT:
15299 case CmpInst::ICMP_SGT:
15300 RHS = getAddExpr(RHS, One);
15301 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15302 break;
15303 case CmpInst::ICMP_ULE:
15304 case CmpInst::ICMP_SLE:
15305 RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15306 break;
15307 case CmpInst::ICMP_UGE:
15308 case CmpInst::ICMP_SGE:
15309 RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
15310 break;
15311 default:
15312 break;
15313 }
15314
15315 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15316 SmallPtrSet<const SCEV *, 16> Visited;
15317
15318 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15319 append_range(Worklist, S->operands());
15320 };
15321
15322 while (!Worklist.empty()) {
15323 const SCEV *From = Worklist.pop_back_val();
15324 if (isa<SCEVConstant>(From))
15325 continue;
15326 if (!Visited.insert(From).second)
15327 continue;
15328 const SCEV *FromRewritten = GetMaybeRewritten(From);
15329 const SCEV *To = nullptr;
15330
15331 switch (Predicate) {
15332 case CmpInst::ICMP_ULT:
15333 case CmpInst::ICMP_ULE:
15334 To = getUMinExpr(FromRewritten, RHS);
15335 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15336 EnqueueOperands(UMax);
15337 break;
15338 case CmpInst::ICMP_SLT:
15339 case CmpInst::ICMP_SLE:
15340 To = getSMinExpr(FromRewritten, RHS);
15341 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15342 EnqueueOperands(SMax);
15343 break;
15344 case CmpInst::ICMP_UGT:
15345 case CmpInst::ICMP_UGE:
15346 To = getUMaxExpr(FromRewritten, RHS);
15347 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15348 EnqueueOperands(UMin);
15349 break;
15350 case CmpInst::ICMP_SGT:
15351 case CmpInst::ICMP_SGE:
15352 To = getSMaxExpr(FromRewritten, RHS);
15353 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15354 EnqueueOperands(SMin);
15355 break;
15356 case CmpInst::ICMP_EQ:
15357 if (isa<SCEVConstant>(RHS))
15358 To = RHS;
15359 break;
15360 case CmpInst::ICMP_NE:
15361 if (isa<SCEVConstant>(RHS) &&
15362 cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
15363 const SCEV *OneAlignedUp =
15364 DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
15365 To = getUMaxExpr(FromRewritten, OneAlignedUp);
15366 }
15367 break;
15368 default:
15369 break;
15370 }
15371
15372 if (To)
15373 AddRewrite(From, FromRewritten, To);
15374 }
15375 };
15376
15377 BasicBlock *Header = L->getHeader();
15378 SmallVector<std::pair<Value *, bool>> Terms;
15379 // First, collect information from assumptions dominating the loop.
15380 for (auto &AssumeVH : AC.assumptions()) {
15381 if (!AssumeVH)
15382 continue;
15383 auto *AssumeI = cast<CallInst>(AssumeVH);
15384 if (!DT.dominates(AssumeI, Header))
15385 continue;
15386 Terms.emplace_back(AssumeI->getOperand(0), true);
15387 }
15388
15389 // Second, collect information from llvm.experimental.guards dominating the loop.
15390 auto *GuardDecl = F.getParent()->getFunction(
15391 Intrinsic::getName(Intrinsic::experimental_guard));
15392 if (GuardDecl)
15393 for (const auto *GU : GuardDecl->users())
15394 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15395 if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
15396 Terms.emplace_back(Guard->getArgOperand(0), true);
15397
15398 // Third, collect conditions from dominating branches. Starting at the loop
15399 // predecessor, climb up the predecessor chain, as long as there are
15400 // predecessors that can be found that have unique successors leading to the
15401 // original header.
15402 // TODO: share this logic with isLoopEntryGuardedByCond.
15403 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
15404 L->getLoopPredecessor(), Header);
15405 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15406
15407 const BranchInst *LoopEntryPredicate =
15408 dyn_cast<BranchInst>(Pair.first->getTerminator());
15409 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15410 continue;
15411
15412 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15413 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15414 }
15415
15416 // Now apply the information from the collected conditions to RewriteMap.
15417 // Conditions are processed in reverse order, so the earliest conditions are
15418 // processed first. This ensures the SCEVs with the shortest dependency chains
15419 // are constructed first.
15420 DenseMap<const SCEV *, const SCEV *> RewriteMap;
15421 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15422 SmallVector<Value *, 8> Worklist;
15423 SmallPtrSet<Value *, 8> Visited;
15424 Worklist.push_back(Term);
15425 while (!Worklist.empty()) {
15426 Value *Cond = Worklist.pop_back_val();
15427 if (!Visited.insert(Cond).second)
15428 continue;
15429
15430 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15431 auto Predicate =
15432 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15433 const auto *LHS = getSCEV(Cmp->getOperand(0));
15434 const auto *RHS = getSCEV(Cmp->getOperand(1));
15435 CollectCondition(Predicate, LHS, RHS, RewriteMap);
15436 continue;
15437 }
15438
15439 Value *L, *R;
15440 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15441 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15442 Worklist.push_back(L);
15443 Worklist.push_back(R);
15444 }
15445 }
15446 }
15447
15448 if (RewriteMap.empty())
15449 return Expr;
15450
15451 // Now that all rewrite information is collected, rewrite the collected
15452 // expressions with the information in the map. This applies information to
15453 // sub-expressions.
15454 if (ExprsToRewrite.size() > 1) {
15455 for (const SCEV *Expr : ExprsToRewrite) {
15456 const SCEV *RewriteTo = RewriteMap[Expr];
15457 RewriteMap.erase(Expr);
15458 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15459 RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
15460 }
15461 }
15462
15463 SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
15464 return Rewriter.visit(Expr);
15465}
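// Illustrative example (not part of the upstream source): for a loop guarded
// by "if (n != 0)" whose backedge-taken count is (-1 + n), applyLoopGuards
// can rewrite the count to (-1 + umax(n, 1)), which lets later folds treat
// the count as non-negative.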
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static std::optional< int > CompareSCEVComplexity(EquivalenceClasses< const SCEV * > &EqCacheSCEV, EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static int CompareValueComplexity(EquivalenceClasses< const Value * > &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2)
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, ScalarEvolution &SE)
Finds the minimum unsigned root of the following equation:
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static bool CollectAddOperandsWithScales(DenseMap< const SCEV *, APInt > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static bool isKnownPredicateExtendIdiom(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
scalar evolution
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static SymbolRef::Type getType(const Symbol *Sym)
Definition: TapiFile.cpp:40
This defines the Use class.
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition: VPlanSLP.cpp:191
Virtual Register Rewriter
Definition: VirtRegMap.cpp:237
Value * RHS
Value * LHS
static const uint32_t IV[8]
Definition: blake3_impl.h:78
A rewriter to replace SCEV expressions in Map with the corresponding entry in the map.
const SCEV * visitAddRecExpr(const SCEVAddRecExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
SCEVLoopGuardRewriter(ScalarEvolution &SE, DenseMap< const SCEV *, const SCEV * > &M)
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
Class for arbitrary precision integers.
Definition: APInt.h:76
APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition: APInt.cpp:1977
APInt udiv(const APInt &RHS) const
Unsigned division operation.
Definition: APInt.cpp:1579
APInt zext(unsigned width) const
Zero extend to a new width.
Definition: APInt.cpp:981
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition: APInt.h:401
uint64_t getZExtValue() const
Get zero extended value.
Definition: APInt.h:1491
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition: APInt.h:1370
APInt multiplicativeInverse(const APInt &modulo) const
Computes the multiplicative inverse of this APInt for a given modulo.
Definition: APInt.cpp:1250
APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition: APInt.cpp:608
APInt zextOrTrunc(unsigned width) const
Zero extend or truncate to width.
Definition: APInt.cpp:1002
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition: APInt.h:1463
APInt trunc(unsigned width) const
Truncate to new width.
Definition: APInt.cpp:906
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition: APInt.h:184
APInt abs() const
Get the absolute value.
Definition: APInt.h:1737
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition: APInt.h:1160
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition: APInt.h:358
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition: APInt.h:444
APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition: APInt.cpp:1672
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition: APInt.h:1439
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition: APInt.h:1089
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition: APInt.h:187
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition: APInt.h:194
bool isNegative() const
Determine sign of this APInt.
Definition: APInt.h:307
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition: APInt.h:1144
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition: APInt.h:197
unsigned countTrailingZeros() const
Definition: APInt.h:1597
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition: APInt.h:334
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition: APInt.h:805
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition: APInt.h:1128
APInt sext(unsigned width) const
Sign extend to a new width.
Definition: APInt.cpp:954
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition: APInt.h:851
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition: APInt.h:284
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition: APInt.h:319
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition: APInt.h:1108
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition: APInt.h:178
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition: APInt.h:410
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition: APInt.h:217
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition: APInt.h:1199
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition: Analysis.h:47
API to communicate dependencies between analyses during invalidation.
Definition: PassManager.h:387
bool invalidate(IRUnitT &IR, const PreservedAnalyses &PA)
Trigger the invalidation of some other analysis pass if not already handled and return whether it was...
Definition: PassManager.h:405
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:348
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:500
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
ArrayRef< T > take_front(size_t N=1) const
Return a copy of *this with only the first N elements.
Definition: ArrayRef.h:228
iterator end() const
Definition: ArrayRef.h:154
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
iterator begin() const
Definition: ArrayRef.h:153
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< ResultElem > assumptions()
Access the list of assumption handles currently tracked for this function.
bool isSingleEdge() const
Check if this is the only edge between Start and End.
Definition: Dominators.cpp:51
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:429
const Instruction & front() const
Definition: BasicBlock.h:452
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
Definition: BasicBlock.cpp:439
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:205
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:220
Value * getRHS() const
unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
Value * getLHS() const
BinaryOps getOpcode() const
Definition: InstrTypes.h:491
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
LLVM_ATTRIBUTE_RETURNS_NONNULL void * Allocate(size_t Size, Align Alignment)
Allocate space at the specified alignment.
Definition: Allocator.h:148
This class represents a function call, abstracting a target machine's calling convention.
Value handle with callbacks on RAUW and destruction.
Definition: ValueHandle.h:383
void setValPtr(Value *P)
Definition: ValueHandle.h:390
bool isFalseWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1281
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition: InstrTypes.h:965
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:994
@ ICMP_SLE
signed less or equal
Definition: InstrTypes.h:995
@ ICMP_UGE
unsigned greater or equal
Definition: InstrTypes.h:989
@ ICMP_UGT
unsigned greater than
Definition: InstrTypes.h:988
@ ICMP_SGT
signed greater than
Definition: InstrTypes.h:992
@ ICMP_ULT
unsigned less than
Definition: InstrTypes.h:990
@ ICMP_EQ
equal
Definition: InstrTypes.h:986
@ ICMP_NE
not equal
Definition: InstrTypes.h:987
@ ICMP_SGE
signed greater or equal
Definition: InstrTypes.h:993
@ ICMP_ULE
unsigned less or equal
Definition: InstrTypes.h:991
bool isSigned() const
Definition: InstrTypes.h:1226
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition: InstrTypes.h:1128
bool isTrueWhenEqual() const
This is just a convenience.
Definition: InstrTypes.h:1275
Predicate getNonStrictPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
Definition: InstrTypes.h:1172
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition: InstrTypes.h:1090
Predicate getPredicate() const
Return the predicate for this instruction.
Definition: InstrTypes.h:1066
Predicate getFlippedSignednessPredicate()
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->Failed assert.
Definition: InstrTypes.h:1269
bool isUnsigned() const
Definition: InstrTypes.h:1232
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition: InstrTypes.h:1222
static Constant * getNot(Constant *C)
Definition: Constants.cpp:2531
static Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2112
static Constant * getICmp(unsigned short pred, Constant *LHS, Constant *RHS, bool OnlyIfReduced=false)
get* - Return some common constants without having to specify the full Instruction::OPCODE identifier...
Definition: Constants.cpp:2404
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, bool InBounds=false, std::optional< unsigned > InRangeIndex=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition: Constants.h:1201
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2537
static Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
Definition: Constants.cpp:2098
static Constant * getNeg(Constant *C, bool HasNUW=false, bool HasNSW=false)
Definition: Constants.cpp:2525
This is the shared class of boolean and integer constants.
Definition: Constants.h:79
bool isMinusOne() const
This function will return true iff every bit in this constant is set to true.
Definition: Constants.h:216
bool isOne() const
This is just a convenience method to make client code smaller for a common case.
Definition: Constants.h:210
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:204
static ConstantInt * getFalse(LLVMContext &Context)
Definition: Constants.cpp:856
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition: Constants.h:153
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition: Constants.h:144
static ConstantInt * getBool(LLVMContext &Context, bool V)
Definition: Constants.cpp:863
This class represents a range of values.
Definition: ConstantRange.h:47
ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
ConstantRange truncate(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other? NOTE: false does not mean that inverse pr...
bool isEmptySet() const
Return true if this set contains no members.
ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
void print(raw_ostream &OS) const
Print out the bounds to a stream.
ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
bool contains(const APInt &Val) const
Return true if the specified value is in the set.
APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
Definition: ConstantRange.h:84
static ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition: Constant.h:41
bool isNullValue() const
Return true if this is the value that would be returned by getNullValue.
Definition: Constants.cpp:90
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
Definition: DataLayout.cpp:720
IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
Definition: DataLayout.cpp:878
unsigned getIndexTypeSizeInBits(Type *Ty) const
Layout size of the index used in GEP calculation.
Definition: DataLayout.cpp:774
IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
Definition: DataLayout.cpp:905
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition: DataLayout.h:672
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator find(const_arg_type_t< KeyT > Val)
Definition: DenseMap.h:155
bool erase(const KeyT &Val)
Definition: DenseMap.h:329
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition: DenseMap.h:71
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition: DenseMap.h:180
bool empty() const
Definition: DenseMap.h:98
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition: DenseMap.h:151
iterator end()
Definition: DenseMap.h:84
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition: DenseMap.h:145
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
bool properlyDominates(const DomTreeNodeBase< NodeT > *A, const DomTreeNodeBase< NodeT > *B) const
properlyDominates - Returns true iff A dominates B and A != B.
Legacy analysis pass which computes a DominatorTree.
Definition: Dominators.h:317
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
Definition: Dominators.cpp:321
bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
Definition: Dominators.cpp:122
EquivalenceClasses - This represents a collection of equivalence classes and supports three efficient...
member_iterator unionSets(const ElemTy &V1, const ElemTy &V2)
union - Merge the two equivalence sets for the specified values, inserting them if they do not alread...
bool isEquivalent(const ElemTy &V1, const ElemTy &V2) const
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition: FoldingSet.h:290
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition: FoldingSet.h:320
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:311
const BasicBlock & getEntryBlock() const
Definition: Function.h:782
static Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:655
static bool isPrivateLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:405
static bool isInternalLinkage(LinkageTypes Linkage)
Definition: GlobalValue.h:402
This instruction compares its operands according to the predicate given to the constructor.
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
bool isEquality() const
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
const BasicBlock * getParent() const
Definition: Instruction.h:151
Class to represent integer types.
Definition: DerivedTypes.h:40
static IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition: Type.cpp:278
An instruction for reading from memory.
Definition: Instructions.h:184
Analysis pass that exposes the LoopInfo for a function.
Definition: LoopInfo.h:566
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
iterator end() const
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
iterator begin() const
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition: LoopInfo.h:593
Represents a single loop in the control flow graph.
Definition: LoopInfo.h:44
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition: LoopInfo.cpp:60
Metadata node.
Definition: Metadata.h:1067
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Function * getFunction(StringRef Name) const
Look up the specified function in the module symbol table.
Definition: Module.cpp:191
This is a utility class that provides an abstraction for the common functionality between Instruction...
Definition: Operator.h:31
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition: Operator.h:41
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition: Operator.h:75
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition: Operator.h:108
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition: Operator.h:102
iterator_range< const_block_iterator > blocks() const
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
Definition: DerivedTypes.h:662
static PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
Definition: Constants.cpp:1827
An interface layer with SCEV used to manage how we see SCEV expressions for values in the context of ...
void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
const SCEVPredicate & getPredicate() const
bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition: Analysis.h:264
constexpr bool isValid() const
Definition: Register.h:116
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
bool implies(const SCEVPredicate *N) const override
Implementation of the SCEVPredicate interface.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
const SCEV *const * Operands
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
virtual bool implies(const SCEVPredicate *N) const =0
Returns true if this predicate implies N.
SCEVPredicate(const SCEVPredicate &)=default
virtual void print(raw_ostream &OS, unsigned Depth=0) const =0
Prints a textual representation of this predicate with an indentation of Depth.
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed maximum selection.
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
This class represents a sequential/in-order unsigned minimum selection.
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
void visitAll(const SCEV *Root)
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
const SCEV * getLHS() const
const SCEV * getRHS() const
This class represents an unsigned maximum selection.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds)
Union predicates don't get cached so create a dummy set ID for it.
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
bool isOne() const
Return true if the expression is a constant one.
bool isZero() const
Return true if the expression is a constant zero.
void dump() const
This method is used for debugging.
bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEVTypes getSCEVType() const
Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
bool isLoopBackedgeGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
const SCEV * getGEPExpr(GEPOperator *GEP, const SmallVectorImpl< const SCEV * > &IndexExprs)
Returns an expression for a GEP.
Type * getWiderType(Type *Ty1, Type *Ty2) const
const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getConstantMultiple(const SCEV *S)
Returns the max constant multiple of S.
bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?...
ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
const SCEV * getConstant(ConstantInt *V)
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
unsigned getSmallConstantMaxTripCount(const Loop *L)
Returns the upper bound of the loop trip count as a normal unsigned value.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
bool isKnownViaInduction(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
std::optional< bool > evaluatePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
bool isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
std::optional< bool > evaluatePredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
uint32_t getMinTrailingZeros(const SCEV *S)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
void print(raw_ostream &OS) const
const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVector< const SCEVPredicate *, 4 > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
void forgetTopmostLoop(const Loop *L)
void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallPtrSetImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
const SCEV * getCouldNotCompute()
bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
const SCEV * getVScale(Type *Ty)
unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
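A small sketch of the usual query pattern (hypothetical helper; assumes a ScalarEvolution &SE and a Loop *L obtained from LoopInfo):

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"

// Hypothetical helper: a return value of 0 from getSmallConstantTripCount
// means the trip count is unknown or is not a small constant.
bool hasKnownSmallTripCount(llvm::ScalarEvolution &SE, const llvm::Loop *L) {
  return SE.getSmallConstantTripCount(L) != 0;
}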
bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
const SCEV * getUnknown(Value *V)
std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
const SCEV * getElementCount(Type *Ty, ElementCount EC)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
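A minimal sketch of the flag helpers, using only the enum values FlagAnyWrap, FlagNUW, and FlagNSW (the variable names are illustrative):

#include "llvm/Analysis/ScalarEvolution.h"

void noWrapFlagExample() {
  using llvm::SCEV;
  using llvm::ScalarEvolution;
  // Start with no known wrap behaviour, then record NUW.
  SCEV::NoWrapFlags Flags =
      ScalarEvolution::setFlags(SCEV::FlagAnyWrap, SCEV::FlagNUW);
  // Drop NUW again; Cleared is back to FlagAnyWrap.
  SCEV::NoWrapFlags Cleared = ScalarEvolution::clearFlags(Flags, SCEV::FlagNUW);
  // maskFlags keeps only the requested bits, so this tests for NSW.
  bool HasNSW =
      ScalarEvolution::maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW;
  (void)Cleared;
  (void)HasNSW;
}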
std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
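A sketch of one way a client might use the result (hypothetical helper; assumes SE already analyzes both expressions):

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include <optional>

// Hypothetical helper: true when A and B are provably a fixed distance apart.
bool differByConstant(llvm::ScalarEvolution &SE, const llvm::SCEV *A,
                      const llvm::SCEV *B) {
  std::optional<llvm::APInt> Diff = SE.computeConstantDifference(A, B);
  return Diff.has_value();
}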
bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Rewrites the SCEV according to the Predicates in A.
std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimized out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVMContext & getContext() const
This class represents the LLVM 'select' instruction.
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
bool contains(ConstPtrType Ptr) const
Definition: SmallPtrSet.h:366
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
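A common idiom with these containers is the insert().second check that drives worklist traversals; a minimal sketch (hypothetical helper, element type chosen for illustration):

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"

namespace llvm { class SCEV; }

// Hypothetical worklist helper: insert().second is true only for pointers
// that were not yet in the set, so each node is queued at most once.
void pushIfNew(llvm::SmallPtrSetImpl<const llvm::SCEV *> &Visited,
               llvm::SmallVectorImpl<const llvm::SCEV *> &Worklist,
               const llvm::SCEV *S) {
  if (Visited.insert(S).second)
    Worklist.push_back(S);
}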
SmallSet - This maintains a set of unique values, optimizing for the case when the set is small (less...
Definition: SmallSet.h:135
std::pair< const_iterator, bool > insert(const T &V)
insert - Insert an element into the set if it isn't already there.
Definition: SmallSet.h:179
size_type size() const
Definition: SmallSet.h:161
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void reserve(size_type N)
Definition: SmallVector.h:676
iterator erase(const_iterator CI)
Definition: SmallVector.h:750
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
iterator insert(iterator I, T &&Elt)
Definition: SmallVector.h:818
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
An instruction for storing to memory.
Definition: Instructions.h:317
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition: DataLayout.h:622
TypeSize getElementOffset(unsigned Idx) const
Definition: DataLayout.h:651
TypeSize getSizeInBits() const
Definition: DataLayout.h:631
Class to represent struct types.
Definition: DerivedTypes.h:216
Multiway switch.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
bool isPointerTy() const
True if this is an instance of PointerType.
Definition: Type.h:255
static IntegerType * getInt1Ty(LLVMContext &C)
static IntegerType * getIntNTy(LLVMContext &C, unsigned N)
unsigned getScalarSizeInBits() const LLVM_READONLY
If this is a vector type, return the getPrimitiveSizeInBits value for the element type.
static IntegerType * getInt8Ty(LLVMContext &C)
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition: Type.h:243
static IntegerType * getInt32Ty(LLVMContext &C)
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition: Type.h:228
TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
op_range operands()
Definition: User.h:242
Use & Op()
Definition: User.h:133
Value * getOperand(unsigned i) const
Definition: User.h:169
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition: Value.h:532
void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
Definition: AsmWriter.cpp:5043
LLVMContext & getContext() const
All values hold a context through their type.
Definition: Value.cpp:1074
iterator_range< use_iterator > uses()
Definition: Value.h:376
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
Represents an op.with.overflow intrinsic.
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition: TypeSize.h:171
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition: raw_ostream.h:52
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition: APInt.h:2178
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition: APInt.h:2183
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition: APInt.h:2188
std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition: APInt.cpp:2814
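A minimal calling sketch (coefficient values are illustrative only): q(n) = n^2 + n - 6 evaluated in 8-bit arithmetic, for which q(2) == 0.

#include "llvm/ADT/APInt.h"
#include <optional>

std::optional<llvm::APInt> solveQuadraticExample() {
  unsigned BW = 8;
  llvm::APInt A(BW, 1), B(BW, 1);
  llvm::APInt C = -llvm::APInt(BW, 6); // -6 in two's complement
  return llvm::APIntOps::SolveQuadraticEquationWrap(A, B, C, /*RangeWidth=*/BW);
}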
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition: APInt.h:2193
APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition: APInt.cpp:767
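A small sketch contrasting the signed and unsigned comparators on the same bit patterns, plus the GCD helper (the values are illustrative):

#include "llvm/ADT/APInt.h"

void apintHelpersExample() {
  llvm::APInt A(8, 200), B(8, 100); // as signed 8-bit values: A == -56, B == 100
  const llvm::APInt &UMax = llvm::APIntOps::umax(A, B); // 200 (unsigned view)
  const llvm::APInt &SMax = llvm::APIntOps::smax(A, B); // 100 (signed view)
  llvm::APInt G = llvm::APIntOps::GreatestCommonDivisor(A, B); // gcd == 100
  (void)UMax;
  (void)SMax;
  (void)G;
}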
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition: CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
StringRef getName(ID id)
Return the LLVM name for an intrinsic, such as "llvm.ppc.altivec.lvx".
Definition: Function.cpp:1017
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
bool match(Val *V, const Pattern &P)
Definition: PatternMatch.h:49
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
Definition: PatternMatch.h:163
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
Definition: PatternMatch.h:771
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
apint_match m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
Definition: PatternMatch.h:294
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
Definition: PatternMatch.h:92
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
Definition: PatternMatch.h:184
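As a minimal sketch of how these matchers compose (hypothetical helper), the snippet below recognizes a left shift by a constant amount:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Value.h"

// Hypothetical helper: on success, X is bound to the shifted value and
// ShAmt to the constant shift amount.
bool matchShlByConstant(llvm::Value *V, llvm::Value *&X,
                        const llvm::APInt *&ShAmt) {
  using namespace llvm::PatternMatch;
  return match(V, m_Shl(m_Value(X), m_APInt(ShAmt)));
}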
@ ReallyHidden
Definition: CommandLine.h:139
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
LocationClass< Ty > location(Ty &L)
Definition: CommandLine.h:470
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
constexpr double e
Definition: MathExtras.h:31
NodeAddr< PhiNode * > Phi
Definition: RDFGraph.h:390
@ FalseVal
Definition: TGLexer.h:59
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:456
bool isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return true if the given value is known to be non-zero when defined.
void stable_sort(R &&Range)
Definition: STLExtras.h:2004
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1731
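A minimal sketch of the range-based form (hypothetical helper; the predicate is illustrative):

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/Instruction.h"

// Hypothetical helper: true when every operand of I has integer type.
bool allOperandsAreIntegers(const llvm::Instruction &I) {
  return llvm::all_of(I.operands(), [](const llvm::Use &U) {
    return U->getType()->isIntegerTy();
  });
}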
bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
detail::scope_exit< std::decay_t< Callable > > make_scope_exit(Callable &&F)
Definition: ScopeExit.h:59
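A minimal sketch of the RAII cleanup pattern this helper enables (the function and its body are illustrative):

#include "llvm/ADT/ScopeExit.h"

void scopeExitExample(bool EarlyExit) {
  auto Cleanup = llvm::make_scope_exit([] { /* restore saved state here */ });
  if (EarlyExit)
    return; // Cleanup's callable still runs on this path
  // ... normal path; Cleanup also runs at the end of the scope
}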
bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
Definition: Verifier.cpp:6966
auto successors(const MachineBasicBlock *BB)
void * PointerTy
Definition: GenericValue.h:21
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition: STLExtras.h:2082
Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
bool VerifySCEV
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr)
ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition: STLExtras.h:2068
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1738
bool getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, const TargetLibraryInfo *TLI, ObjectSizeOpts Opts={})
Compute the size of the object pointed by Ptr.
void initializeScalarEvolutionWrapperPassPass(PassRegistry &)
auto reverse(ContainerTy &&C)
Definition: STLExtras.h:428
bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
Definition: LoopInfo.cpp:1118
bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
Constant * ConstantFoldInstOperands(Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
Definition: LoopInfo.cpp:1108
bool programUndefinedIfPoison(const Instruction *Inst)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool isPointerTy(const Type *T)
Definition: SPIRVUtils.h:116
ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
@ Mod
The access may modify the value stored in memory.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition: MathExtras.h:233
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1923
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
DWARFExpression::Operation Op
auto max_element(R &&Range)
Definition: STLExtras.h:1995
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
Definition: APFixedPoint.h:293
constexpr unsigned BitWidth
Definition: BitmaskEnum.h:191
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1858
bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition: STLExtras.h:1930
auto predecessors(const MachineBasicBlock *BB)
bool isAllocationFn(const Value *V, const TargetLibraryInfo *TLI)
Tests if a value is a call or invoke to a library function that allocates or reallocates memory (eith...
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition: STLExtras.h:1888
unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Return the number of times the sign bit of the register is replicated into the other bits.
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition: Sequence.h:305
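A minimal sketch (hypothetical helper) of iterating a half-open integer range without a manual counter:

#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"

// Hypothetical helper: collect the indices 0, 1, ..., N-1.
llvm::SmallVector<unsigned, 8> firstNIndices(unsigned N) {
  llvm::SmallVector<unsigned, 8> Out;
  for (unsigned I : llvm::seq(0u, N))
    Out.push_back(I);
  return Out;
}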
bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
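A minimal sketch of a predicate-driven query over an expression tree (hypothetical helper; the predicate is illustrative):

#include "llvm/Analysis/ScalarEvolutionExpressions.h"

// Hypothetical helper: true when the expression tree mentions any
// SCEVUnknown leaf (i.e. an opaque IR value).
bool mentionsUnknownValue(const llvm::SCEV *Root) {
  return llvm::SCEVExprContains(Root, [](const llvm::SCEV *S) {
    return llvm::isa<llvm::SCEVUnknown>(S);
  });
}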
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition: BitVector.h:860
#define N
#define NC
Definition: regutils.h:42
This struct is a compact representation of a valid (non-zero power of two) alignment.
Definition: Alignment.h:39
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition: Analysis.h:26
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition: KnownBits.h:297
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition: KnownBits.h:104
static KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
Definition: KnownBits.cpp:422
unsigned getBitWidth() const
Get the bit width of this value.
Definition: KnownBits.h:40
static KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
Definition: KnownBits.cpp:364
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition: KnownBits.h:192
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition: KnownBits.h:141
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition: KnownBits.h:125
bool isNegative() const
Returns true if this value is known to be negative.
Definition: KnownBits.h:101
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
Definition: KnownBits.cpp:279
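A minimal sketch combining makeConstant with shl (values are illustrative): shifting a fully known value by a fully known amount yields a fully known result.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"

// Hypothetical helper: returns known bits for (5 << 1) on 8 bits, i.e. the
// constant 10 with every bit known.
llvm::KnownBits knownShlExample() {
  llvm::KnownBits Val = llvm::KnownBits::makeConstant(llvm::APInt(8, 5));
  llvm::KnownBits Amt = llvm::KnownBits::makeConstant(llvm::APInt(8, 1));
  return llvm::KnownBits::shl(Val, Amt);
}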
Various options to control the behavior of getObjectSize.
bool NullIsUnknownSize
If this is true, null pointers in address space 0 will be treated as though they can't be evaluated.
bool RoundToAlign
Whether to round the result up to the alignment of allocas, byval arguments, and global variables.
An object of this class is returned by queries that could not be answered.
static bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
void addPredicate(const SCEVPredicate *P)