1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
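// Illustrative note (added, not part of the original header): as a small
// worked example of the representation described above, consider the loop
//
//   for (int i = 0; i != n; ++i)
//     A[3 * i + 5] = 0;
//
// The induction variable i is the chain of recurrences {0,+,1}<%loop>, and
// the index expression 3 * i + 5 folds to the single recurrence
// {5,+,3}<%loop>, which starts at 5 and advances by 3 on every backedge.
// Later analyses (trip counts, range and wrap reasoning) all operate on this
// canonical form.
//
//===----------------------------------------------------------------------===//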
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
 267 print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToAddr:
281 case scPtrToInt: {
282 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
283 const SCEV *Op = PtrCast->getOperand();
284 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
285 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
286 << *PtrCast->getType() << ")";
287 return;
288 }
289 case scTruncate: {
290 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
291 const SCEV *Op = Trunc->getOperand();
292 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
293 << *Trunc->getType() << ")";
294 return;
295 }
296 case scZeroExtend: {
 297 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
 298 const SCEV *Op = ZExt->getOperand();
299 OS << "(zext " << *Op->getType() << " " << *Op << " to "
300 << *ZExt->getType() << ")";
301 return;
302 }
303 case scSignExtend: {
 304 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
 305 const SCEV *Op = SExt->getOperand();
306 OS << "(sext " << *Op->getType() << " " << *Op << " to "
307 << *SExt->getType() << ")";
308 return;
309 }
310 case scAddRecExpr: {
311 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
312 OS << "{" << *AR->getOperand(0);
313 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
314 OS << ",+," << *AR->getOperand(i);
315 OS << "}<";
316 if (AR->hasNoUnsignedWrap())
317 OS << "nuw><";
318 if (AR->hasNoSignedWrap())
319 OS << "nsw><";
320 if (AR->hasNoSelfWrap() &&
322 OS << "nw><";
323 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
324 OS << ">";
325 return;
326 }
327 case scAddExpr:
328 case scMulExpr:
329 case scUMaxExpr:
330 case scSMaxExpr:
331 case scUMinExpr:
332 case scSMinExpr:
 333 case scSequentialUMinExpr: {
 334 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
335 const char *OpStr = nullptr;
336 switch (NAry->getSCEVType()) {
337 case scAddExpr: OpStr = " + "; break;
338 case scMulExpr: OpStr = " * "; break;
339 case scUMaxExpr: OpStr = " umax "; break;
340 case scSMaxExpr: OpStr = " smax "; break;
341 case scUMinExpr:
342 OpStr = " umin ";
343 break;
344 case scSMinExpr:
345 OpStr = " smin ";
346 break;
348 OpStr = " umin_seq ";
349 break;
350 default:
351 llvm_unreachable("There are no other nary expression types.");
352 }
353 OS << "("
355 << ")";
356 switch (NAry->getSCEVType()) {
357 case scAddExpr:
358 case scMulExpr:
359 if (NAry->hasNoUnsignedWrap())
360 OS << "<nuw>";
361 if (NAry->hasNoSignedWrap())
362 OS << "<nsw>";
363 break;
364 default:
365 // Nothing to print for other nary expressions.
366 break;
367 }
368 return;
369 }
370 case scUDivExpr: {
371 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
372 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
373 return;
374 }
375 case scUnknown:
376 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
377 return;
379 OS << "***COULDNOTCOMPUTE***";
380 return;
381 }
382 llvm_unreachable("Unknown SCEV kind!");
383}
384
386 switch (getSCEVType()) {
387 case scConstant:
388 return cast<SCEVConstant>(this)->getType();
389 case scVScale:
390 return cast<SCEVVScale>(this)->getType();
391 case scPtrToAddr:
392 case scPtrToInt:
393 case scTruncate:
394 case scZeroExtend:
395 case scSignExtend:
396 return cast<SCEVCastExpr>(this)->getType();
397 case scAddRecExpr:
398 return cast<SCEVAddRecExpr>(this)->getType();
399 case scMulExpr:
400 return cast<SCEVMulExpr>(this)->getType();
401 case scUMaxExpr:
402 case scSMaxExpr:
403 case scUMinExpr:
404 case scSMinExpr:
405 return cast<SCEVMinMaxExpr>(this)->getType();
 406 case scSequentialUMinExpr:
 407 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
408 case scAddExpr:
409 return cast<SCEVAddExpr>(this)->getType();
410 case scUDivExpr:
411 return cast<SCEVUDivExpr>(this)->getType();
412 case scUnknown:
413 return cast<SCEVUnknown>(this)->getType();
 414 case scCouldNotCompute:
 415 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
416 }
417 llvm_unreachable("Unknown SCEV kind!");
418}
419
421 switch (getSCEVType()) {
422 case scConstant:
423 case scVScale:
424 case scUnknown:
425 return {};
426 case scPtrToAddr:
427 case scPtrToInt:
428 case scTruncate:
429 case scZeroExtend:
430 case scSignExtend:
431 return cast<SCEVCastExpr>(this)->operands();
432 case scAddRecExpr:
433 case scAddExpr:
434 case scMulExpr:
435 case scUMaxExpr:
436 case scSMaxExpr:
437 case scUMinExpr:
438 case scSMinExpr:
 439 case scSequentialUMinExpr:
 440 return cast<SCEVNAryExpr>(this)->operands();
441 case scUDivExpr:
442 return cast<SCEVUDivExpr>(this)->operands();
 443 case scCouldNotCompute:
 444 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
445 }
446 llvm_unreachable("Unknown SCEV kind!");
447}
448
449bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
450
451bool SCEV::isOne() const { return match(this, m_scev_One()); }
452
453bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
454
457 if (!Mul) return false;
458
459 // If there is a constant factor, it will be first.
460 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
461 if (!SC) return false;
462
463 // Return true if the value is negative, this matches things like (-42 * V).
464 return SC->getAPInt().isNegative();
465}
466
469
471 return S->getSCEVType() == scCouldNotCompute;
472}
473
476 ID.AddInteger(scConstant);
477 ID.AddPointer(V);
478 void *IP = nullptr;
479 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
480 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
481 UniqueSCEVs.InsertNode(S, IP);
482 return S;
483}
484
486 return getConstant(ConstantInt::get(getContext(), Val));
487}
488
489const SCEV *
492 // TODO: Avoid implicit trunc?
493 // See https://github.com/llvm/llvm-project/issues/112510.
494 return getConstant(
495 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
496}
497
500 ID.AddInteger(scVScale);
501 ID.AddPointer(Ty);
502 void *IP = nullptr;
503 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
504 return S;
505 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
506 UniqueSCEVs.InsertNode(S, IP);
507 return S;
508}
509
511 SCEV::NoWrapFlags Flags) {
512 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
513 if (EC.isScalable())
514 Res = getMulExpr(Res, getVScale(Ty), Flags);
515 return Res;
516}
517
519 const SCEV *op, Type *ty)
520 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
521
522SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
523 const SCEV *Op, Type *ITy)
524 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
525 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
526 "Must be a non-bit-width-changing pointer-to-integer cast!");
527}
528
529SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
530 Type *ITy)
531 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
532 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
533 "Must be a non-bit-width-changing pointer-to-integer cast!");
534}
535
537 SCEVTypes SCEVTy, const SCEV *op,
538 Type *ty)
539 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
540
541SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
542 Type *ty)
544 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
545 "Cannot truncate non-integer value!");
546}
547
548SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
549 const SCEV *op, Type *ty)
551 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
552 "Cannot zero extend non-integer value!");
553}
554
555SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
556 const SCEV *op, Type *ty)
558 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
559 "Cannot sign extend non-integer value!");
560}
561
563 // Clear this SCEVUnknown from various maps.
564 SE->forgetMemoizedResults(this);
565
566 // Remove this SCEVUnknown from the uniquing map.
567 SE->UniqueSCEVs.RemoveNode(this);
568
569 // Release the value.
570 setValPtr(nullptr);
571}
572
573void SCEVUnknown::allUsesReplacedWith(Value *New) {
574 // Clear this SCEVUnknown from various maps.
575 SE->forgetMemoizedResults(this);
576
577 // Remove this SCEVUnknown from the uniquing map.
578 SE->UniqueSCEVs.RemoveNode(this);
579
580 // Replace the value pointer in case someone is still using this SCEVUnknown.
581 setValPtr(New);
582}
583
584//===----------------------------------------------------------------------===//
585// SCEV Utilities
586//===----------------------------------------------------------------------===//
587
588/// Compare the two values \p LV and \p RV in terms of their "complexity" where
589/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
590/// operands in SCEV expressions.
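/// For illustration (added note, not in the original comment): under this
/// ordering an integer argument such as i64 %x sorts before a pointer
/// argument such as ptr %p, and two arguments of the same kind are ordered by
/// their position in the argument list, so %arg0 precedes %arg1.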
591static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
592 Value *RV, unsigned Depth) {
 593 if (Depth > MaxValueCompareDepth)
 594 return 0;
595
596 // Order pointer values after integer values. This helps SCEVExpander form
597 // GEPs.
598 bool LIsPointer = LV->getType()->isPointerTy(),
599 RIsPointer = RV->getType()->isPointerTy();
600 if (LIsPointer != RIsPointer)
601 return (int)LIsPointer - (int)RIsPointer;
602
603 // Compare getValueID values.
604 unsigned LID = LV->getValueID(), RID = RV->getValueID();
605 if (LID != RID)
606 return (int)LID - (int)RID;
607
608 // Sort arguments by their position.
609 if (const auto *LA = dyn_cast<Argument>(LV)) {
610 const auto *RA = cast<Argument>(RV);
611 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
612 return (int)LArgNo - (int)RArgNo;
613 }
614
615 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
616 const auto *RGV = cast<GlobalValue>(RV);
617
618 if (auto L = LGV->getLinkage() - RGV->getLinkage())
619 return L;
620
621 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
622 auto LT = GV->getLinkage();
623 return !(GlobalValue::isPrivateLinkage(LT) ||
625 };
626
627 // Use the names to distinguish the two values, but only if the
628 // names are semantically important.
629 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
630 return LGV->getName().compare(RGV->getName());
631 }
632
633 // For instructions, compare their loop depth, and their operand count. This
634 // is pretty loose.
635 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
636 const auto *RInst = cast<Instruction>(RV);
637
638 // Compare loop depths.
639 const BasicBlock *LParent = LInst->getParent(),
640 *RParent = RInst->getParent();
641 if (LParent != RParent) {
642 unsigned LDepth = LI->getLoopDepth(LParent),
643 RDepth = LI->getLoopDepth(RParent);
644 if (LDepth != RDepth)
645 return (int)LDepth - (int)RDepth;
646 }
647
648 // Compare the number of operands.
649 unsigned LNumOps = LInst->getNumOperands(),
650 RNumOps = RInst->getNumOperands();
651 if (LNumOps != RNumOps)
652 return (int)LNumOps - (int)RNumOps;
653
654 for (unsigned Idx : seq(LNumOps)) {
655 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
656 RInst->getOperand(Idx), Depth + 1);
657 if (Result != 0)
658 return Result;
659 }
660 }
661
662 return 0;
663}
664
665// Return negative, zero, or positive, if LHS is less than, equal to, or greater
666// than RHS, respectively. A three-way result allows recursive comparisons to be
667// more efficient.
668// If the max analysis depth was reached, return std::nullopt, assuming we do
669// not know if they are equivalent for sure.
670static std::optional<int>
671CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
672 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
673 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
674 if (LHS == RHS)
675 return 0;
676
677 // Primarily, sort the SCEVs by their getSCEVType().
678 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
679 if (LType != RType)
680 return (int)LType - (int)RType;
681
 682 if (Depth > MaxSCEVCompareDepth)
 683 return std::nullopt;
684
685 // Aside from the getSCEVType() ordering, the particular ordering
686 // isn't very important except that it's beneficial to be consistent,
687 // so that (a + b) and (b + a) don't end up as different expressions.
688 switch (LType) {
689 case scUnknown: {
690 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
691 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
692
693 int X =
694 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
695 return X;
696 }
697
698 case scConstant: {
 699 const SCEVConstant *LC = cast<SCEVConstant>(LHS);
 700 const SCEVConstant *RC = cast<SCEVConstant>(RHS);
 701
702 // Compare constant values.
703 const APInt &LA = LC->getAPInt();
704 const APInt &RA = RC->getAPInt();
705 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
706 if (LBitWidth != RBitWidth)
707 return (int)LBitWidth - (int)RBitWidth;
708 return LA.ult(RA) ? -1 : 1;
709 }
710
711 case scVScale: {
712 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
713 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
714 return LTy->getBitWidth() - RTy->getBitWidth();
715 }
716
717 case scAddRecExpr: {
 718 const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
 719 const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
 720
721 // There is always a dominance between two recs that are used by one SCEV,
722 // so we can safely sort recs by loop header dominance. We require such
723 // order in getAddExpr.
724 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
725 if (LLoop != RLoop) {
726 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
727 assert(LHead != RHead && "Two loops share the same header?");
728 if (DT.dominates(LHead, RHead))
729 return 1;
730 assert(DT.dominates(RHead, LHead) &&
731 "No dominance between recurrences used by one SCEV?");
732 return -1;
733 }
734
735 [[fallthrough]];
736 }
737
738 case scTruncate:
739 case scZeroExtend:
740 case scSignExtend:
741 case scPtrToAddr:
742 case scPtrToInt:
743 case scAddExpr:
744 case scMulExpr:
745 case scUDivExpr:
746 case scSMaxExpr:
747 case scUMaxExpr:
748 case scSMinExpr:
749 case scUMinExpr:
 750 case scSequentialUMinExpr: {
 751 ArrayRef<const SCEV *> LOps = LHS->operands();
752 ArrayRef<const SCEV *> ROps = RHS->operands();
753
754 // Lexicographically compare n-ary-like expressions.
755 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
756 if (LNumOps != RNumOps)
757 return (int)LNumOps - (int)RNumOps;
758
759 for (unsigned i = 0; i != LNumOps; ++i) {
760 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
761 if (X != 0)
762 return X;
763 }
764 return 0;
765 }
766
 767 case scCouldNotCompute:
 768 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
769 }
770 llvm_unreachable("Unknown SCEV kind!");
771}
772
773/// Given a list of SCEV objects, order them by their complexity, and group
774/// objects of the same complexity together by value. When this routine is
775/// finished, we know that any duplicates in the vector are consecutive and that
776/// complexity is monotonically increasing.
777///
778/// Note that we take special precautions to ensure that we get deterministic
779/// results from this routine. In other words, we don't want the results of
780/// this to depend on where the addresses of various SCEV objects happened to
781/// land in memory.
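///
/// Sketch of the intended effect (illustrative, not from the original text):
/// an operand list such as (%a + 3 + %b + 3) is reordered so that constants
/// (the least complex SCEVs) come first and equal operands become adjacent,
/// i.e. (3 + 3 + %a + %b), which lets the callers fold the duplicate
/// constants and combine repeated operands in a single linear pass.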
783 LoopInfo *LI, DominatorTree &DT) {
784 if (Ops.size() < 2) return; // Noop
785
786 // Whether LHS has provably less complexity than RHS.
787 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
788 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
789 return Complexity && *Complexity < 0;
790 };
791 if (Ops.size() == 2) {
792 // This is the common case, which also happens to be trivially simple.
793 // Special case it.
794 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
795 if (IsLessComplex(RHS, LHS))
796 std::swap(LHS, RHS);
797 return;
798 }
799
800 // Do the rough sort by complexity.
801 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
802 return IsLessComplex(LHS, RHS);
803 });
804
805 // Now that we are sorted by complexity, group elements of the same
806 // complexity. Note that this is, at worst, N^2, but the vector is likely to
807 // be extremely short in practice. Note that we take this approach because we
808 // do not want to depend on the addresses of the objects we are grouping.
809 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
810 const SCEV *S = Ops[i];
811 unsigned Complexity = S->getSCEVType();
812
813 // If there are any objects of the same complexity and same value as this
814 // one, group them.
815 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
816 if (Ops[j] == S) { // Found a duplicate.
817 // Move it to immediately after i'th element.
818 std::swap(Ops[i+1], Ops[j]);
819 ++i; // no need to rescan it.
820 if (i == e-2) return; // Done!
821 }
822 }
823 }
824}
825
826/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
827/// least HugeExprThreshold nodes).
829 return any_of(Ops, [](const SCEV *S) {
831 });
832}
833
834/// Performs a number of common optimizations on the passed \p Ops. If the
835/// whole expression reduces down to a single operand, it will be returned.
836///
837/// The following optimizations are performed:
838/// * Fold constants using the \p Fold function.
839/// * Remove identity constants satisfying \p IsIdentity.
840/// * If a constant satisfies \p IsAbsorber, return it.
841/// * Sort operands by complexity.
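///
/// As an illustrative example (added, not part of the original comment): for
/// an n-ary add, \p Fold would add the two APInts, \p IsIdentity would match
/// the constant 0 (which can simply be dropped), and \p IsAbsorber would
/// match nothing; for a multiply, the corresponding choices would be
/// multiplication, the constant 1, and the constant 0 (since 0 * x == 0).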
842template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
843static const SCEV *
846 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
847 const SCEVConstant *Folded = nullptr;
848 for (unsigned Idx = 0; Idx < Ops.size();) {
849 const SCEV *Op = Ops[Idx];
850 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
851 if (!Folded)
852 Folded = C;
853 else
854 Folded = cast<SCEVConstant>(
855 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
856 Ops.erase(Ops.begin() + Idx);
857 continue;
858 }
859 ++Idx;
860 }
861
862 if (Ops.empty()) {
863 assert(Folded && "Must have folded value");
864 return Folded;
865 }
866
867 if (Folded && IsAbsorber(Folded->getAPInt()))
868 return Folded;
869
870 GroupByComplexity(Ops, &LI, DT);
871 if (Folded && !IsIdentity(Folded->getAPInt()))
872 Ops.insert(Ops.begin(), Folded);
873
874 return Ops.size() == 1 ? Ops[0] : nullptr;
875}
876
877//===----------------------------------------------------------------------===//
878// Simple SCEV method implementations
879//===----------------------------------------------------------------------===//
880
881/// Compute BC(It, K). The result has width W. Assume K > 0.
882static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
883 ScalarEvolution &SE,
884 Type *ResultTy) {
885 // Handle the simplest case efficiently.
886 if (K == 1)
887 return SE.getTruncateOrZeroExtend(It, ResultTy);
888
889 // We are using the following formula for BC(It, K):
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
892 //
893 // Suppose, W is the bitwidth of the return value. We must be prepared for
894 // overflow. Hence, we must assure that the result of our computation is
895 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
896 // safe in modular arithmetic.
897 //
898 // However, this code doesn't use exactly that formula; the formula it uses
899 // is something like the following, where T is the number of factors of 2 in
900 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
901 // exponentiation:
902 //
903 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
904 //
905 // This formula is trivially equivalent to the previous formula. However,
906 // this formula can be implemented much more efficiently. The trick is that
907 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
908 // arithmetic. To do exact division in modular arithmetic, all we have
909 // to do is multiply by the inverse. Therefore, this step can be done at
910 // width W.
911 //
912 // The next issue is how to safely do the division by 2^T. The way this
913 // is done is by doing the multiplication step at a width of at least W + T
914 // bits. This way, the bottom W+T bits of the product are accurate. Then,
915 // when we perform the division by 2^T (which is equivalent to a right shift
916 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
917 // truncated out after the division by 2^T.
918 //
919 // In comparison to just directly using the first formula, this technique
920 // is much more efficient; using the first formula requires W * K bits,
921// but this formula needs less than W + K bits. Also, the first formula requires
922 // a division step, whereas this formula only requires multiplies and shifts.
923 //
924 // It doesn't matter whether the subtraction step is done in the calculation
925 // width or the input iteration count's width; if the subtraction overflows,
926 // the result must be zero anyway. We prefer here to do it in the width of
927 // the induction variable because it helps a lot for certain cases; CodeGen
928 // isn't smart enough to ignore the overflow, which leads to much less
929 // efficient code if the width of the subtraction is wider than the native
930 // register width.
931 //
932 // (It's possible to not widen at all by pulling out factors of 2 before
933 // the multiplication; for example, K=2 can be calculated as
934 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
935 // extra arithmetic, so it's not an obvious win, and it gets
936 // much more complicated for K > 3.)
937
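  // Illustrative worked example (added note, not in the original comment):
  // with W = 8, K = 3 and It = 10, the exact value is BC(10, 3) = 120.
  //   K! = 6 = 2^1 * 3, so T = 1 and OddFactorial = 3.
  //   CalculationBits = 9; Dividend = 10 * 9 * 8 = 720, which is 208 mod 2^9.
  //   Dividing by 2^T gives 104, and truncating to 8 bits keeps 104.
  //   MultiplyFactor = 3^-1 mod 2^8 = 171, and 104 * 171 mod 2^8 = 120,
  //   matching the exact binomial coefficient in the low W bits.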
938 // Protection from insane SCEVs; this bound is conservative,
939 // but it probably doesn't matter.
940 if (K > 1000)
941 return SE.getCouldNotCompute();
942
943 unsigned W = SE.getTypeSizeInBits(ResultTy);
944
945 // Calculate K! / 2^T and T; we divide out the factors of two before
946 // multiplying for calculating K! / 2^T to avoid overflow.
947 // Other overflow doesn't matter because we only care about the bottom
948 // W bits of the result.
949 APInt OddFactorial(W, 1);
950 unsigned T = 1;
951 for (unsigned i = 3; i <= K; ++i) {
952 unsigned TwoFactors = countr_zero(i);
953 T += TwoFactors;
954 OddFactorial *= (i >> TwoFactors);
955 }
956
957 // We need at least W + T bits for the multiplication step
958 unsigned CalculationBits = W + T;
959
960 // Calculate 2^T, at width T+W.
961 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
962
963 // Calculate the multiplicative inverse of K! / 2^T;
964 // this multiplication factor will perform the exact division by
965 // K! / 2^T.
966 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
967
968 // Calculate the product, at width T+W
969 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
970 CalculationBits);
971 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
972 for (unsigned i = 1; i != K; ++i) {
973 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
974 Dividend = SE.getMulExpr(Dividend,
975 SE.getTruncateOrZeroExtend(S, CalculationTy));
976 }
977
978 // Divide by 2^T
979 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
980
981 // Truncate the result, and divide by K! / 2^T.
982
983 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
984 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
985}
986
987/// Return the value of this chain of recurrences at the specified iteration
988/// number. We can evaluate this recurrence by multiplying each element in the
989/// chain by the binomial coefficient corresponding to it. In other words, we
990/// can evaluate {A,+,B,+,C,+,D} as:
991///
992/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
993///
994/// where BC(It, k) stands for binomial coefficient.
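///
/// Illustrative example (added, not in the original comment): the chain
/// {3,+,5,+,2} evaluated at It = 4 gives
///   3*BC(4,0) + 5*BC(4,1) + 2*BC(4,2) = 3 + 20 + 12 = 35,
/// which agrees with stepping the recurrence directly:
///   3 -> 8 -> 15 -> 24 -> 35 (the step itself grows 5, 7, 9, 11).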
996 ScalarEvolution &SE) const {
997 return evaluateAtIteration(operands(), It, SE);
998}
999
1000const SCEV *
1002 const SCEV *It, ScalarEvolution &SE) {
1003 assert(Operands.size() > 0);
1004 const SCEV *Result = Operands[0];
1005 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1006 // The computation is correct in the face of overflow provided that the
1007 // multiplication is performed _after_ the evaluation of the binomial
1008 // coefficient.
1009 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1010 if (isa<SCEVCouldNotCompute>(Coeff))
1011 return Coeff;
1012
1013 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1014 }
1015 return Result;
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// SCEV Expression folder implementations
1020//===----------------------------------------------------------------------===//
1021
1022/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1023/// which computes a pointer-typed value, and rewrites the whole expression
1024/// tree so that *all* the computations are done on integers, and the only
1025/// pointer-typed operands in the expression are SCEVUnknown.
1026/// The CreatePtrCast callback is invoked to create the actual conversion
1027/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
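///
/// Illustrative example (added note): rewriting (%base + 16), where %base is
/// a pointer-typed SCEVUnknown, with a ptrtoint conversion callback yields
/// (ptrtoint %base) + 16 -- the cast is applied only at the leaf while the
/// integer offset is left untouched.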
1029 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1031 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
1032 Type *TargetTy;
1033 ConversionFn CreatePtrCast;
1034
1035public:
1037 ConversionFn CreatePtrCast)
1038 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1039
1040 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1041 Type *TargetTy, ConversionFn CreatePtrCast) {
1042 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1043 return Rewriter.visit(Scev);
1044 }
1045
1046 const SCEV *visit(const SCEV *S) {
1047 Type *STy = S->getType();
1048 // If the expression is not pointer-typed, just keep it as-is.
1049 if (!STy->isPointerTy())
1050 return S;
1051 // Else, recursively sink the cast down into it.
1052 return Base::visit(S);
1053 }
1054
1055 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1056 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1057 // implementation drops.
1059 bool Changed = false;
1060 for (const auto *Op : Expr->operands()) {
1061 Operands.push_back(visit(Op));
1062 Changed |= Op != Operands.back();
1063 }
1064 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1065 }
1066
1067 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1069 bool Changed = false;
1070 for (const auto *Op : Expr->operands()) {
1071 Operands.push_back(visit(Op));
1072 Changed |= Op != Operands.back();
1073 }
1074 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1075 }
1076
1077 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1078 assert(Expr->getType()->isPointerTy() &&
1079 "Should only reach pointer-typed SCEVUnknown's.");
1080 // Perform some basic constant folding. If the operand of the cast is a
1081 // null pointer, don't create a cast SCEV expression (that will be left
1082 // as-is), but produce a zero constant.
1084 return SE.getZero(TargetTy);
1085 return CreatePtrCast(Expr);
1086 }
1087};
1088
1090 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1091
1092 // It isn't legal for optimizations to construct new ptrtoint expressions
1093 // for non-integral pointers.
1094 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1095 return getCouldNotCompute();
1096
1097 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1098
1099 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1100 // is sufficiently wide to represent all possible pointer values.
1101 // We could theoretically teach SCEV to truncate wider pointers, but
1102 // that isn't implemented for now.
1104 getDataLayout().getTypeSizeInBits(IntPtrTy))
1105 return getCouldNotCompute();
1106
1107 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1109 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1111 ID.AddInteger(scPtrToInt);
1112 ID.AddPointer(U);
1113 void *IP = nullptr;
1114 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1115 return S;
1116 SCEV *S = new (SCEVAllocator)
1117 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1118 UniqueSCEVs.InsertNode(S, IP);
1119 registerUser(S, U);
1120 return static_cast<const SCEV *>(S);
1121 });
1122 assert(IntOp->getType()->isIntegerTy() &&
1123 "We must have succeeded in sinking the cast, "
1124 "and ending up with an integer-typed expression!");
1125 return IntOp;
1126}
1127
1129 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1130 Type *Ty = DL.getAddressType(Op->getType());
1131
1132 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1133 // The rewriter handles null pointer constant folding.
1135 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1137 ID.AddInteger(scPtrToAddr);
1138 ID.AddPointer(U);
1139 ID.AddPointer(Ty);
1140 void *IP = nullptr;
1141 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1142 return S;
1143 SCEV *S = new (SCEVAllocator)
1144 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1145 UniqueSCEVs.InsertNode(S, IP);
1146 registerUser(S, U);
1147 return static_cast<const SCEV *>(S);
1148 });
1149 assert(IntOp->getType()->isIntegerTy() &&
1150 "We must have succeeded in sinking the cast, "
1151 "and ending up with an integer-typed expression!");
1152 return IntOp;
1153}
1154
1156 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1157
1158 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1159 if (isa<SCEVCouldNotCompute>(IntOp))
1160 return IntOp;
1161
1162 return getTruncateOrZeroExtend(IntOp, Ty);
1163}
1164
1166 unsigned Depth) {
1167 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1168 "This is not a truncating conversion!");
1169 assert(isSCEVable(Ty) &&
1170 "This is not a conversion to a SCEVable type!");
1171 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1172 Ty = getEffectiveSCEVType(Ty);
1173
1175 ID.AddInteger(scTruncate);
1176 ID.AddPointer(Op);
1177 ID.AddPointer(Ty);
1178 void *IP = nullptr;
1179 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1180
1181 // Fold if the operand is constant.
1182 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1183 return getConstant(
1184 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1185
1186 // trunc(trunc(x)) --> trunc(x)
1188 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1189
1190 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1192 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1193
1194 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1196 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1197
1198 if (Depth > MaxCastDepth) {
1199 SCEV *S =
1200 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1201 UniqueSCEVs.InsertNode(S, IP);
1202 registerUser(S, Op);
1203 return S;
1204 }
1205
1206 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1207 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1208 // if after transforming we have at most one truncate, not counting truncates
1209 // that replace other casts.
1211 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1213 unsigned numTruncs = 0;
1214 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1215 ++i) {
1216 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1217 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1219 numTruncs++;
1220 Operands.push_back(S);
1221 }
1222 if (numTruncs < 2) {
1223 if (isa<SCEVAddExpr>(Op))
1224 return getAddExpr(Operands);
1225 if (isa<SCEVMulExpr>(Op))
1226 return getMulExpr(Operands);
1227 llvm_unreachable("Unexpected SCEV type for Op.");
1228 }
1229 // Although we checked in the beginning that ID is not in the cache, it is
1230 // possible that during recursion and different modification ID was inserted
1231 // into the cache. So if we find it, just return it.
1232 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1233 return S;
1234 }
1235
1236 // If the input value is a chrec scev, truncate the chrec's operands.
1237 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1239 for (const SCEV *Op : AddRec->operands())
1240 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1241 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1242 }
1243
1244 // Return zero if truncating to known zeros.
1245 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1246 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1247 return getZero(Ty);
1248
1249 // The cast wasn't folded; create an explicit cast node. We can reuse
1250 // the existing insert position since if we get here, we won't have
1251 // made any changes which would invalidate it.
1252 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1253 Op, Ty);
1254 UniqueSCEVs.InsertNode(S, IP);
1255 registerUser(S, Op);
1256 return S;
1257}
1258
1259// Get the limit of a recurrence such that incrementing by Step cannot cause
1260// signed overflow as long as the value of the recurrence within the
1261// loop does not exceed this limit before incrementing.
1262static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1263 ICmpInst::Predicate *Pred,
1264 ScalarEvolution *SE) {
1265 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1266 if (SE->isKnownPositive(Step)) {
1267 *Pred = ICmpInst::ICMP_SLT;
1269 SE->getSignedRangeMax(Step));
1270 }
1271 if (SE->isKnownNegative(Step)) {
1272 *Pred = ICmpInst::ICMP_SGT;
1274 SE->getSignedRangeMin(Step));
1275 }
1276 return nullptr;
1277}
1278
1279// Get the limit of a recurrence such that incrementing by Step cannot cause
1280// unsigned overflow as long as the value of the recurrence within the loop does
1281// not exceed this limit before incrementing.
1283 ICmpInst::Predicate *Pred,
1284 ScalarEvolution *SE) {
1285 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1286 *Pred = ICmpInst::ICMP_ULT;
1287
1289 SE->getUnsignedRangeMax(Step));
1290}
1291
1292namespace {
1293
1294struct ExtendOpTraitsBase {
1295 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1296 unsigned);
1297};
1298
1299// Used to make code generic over signed and unsigned overflow.
1300template <typename ExtendOp> struct ExtendOpTraits {
1301 // Members present:
1302 //
1303 // static const SCEV::NoWrapFlags WrapType;
1304 //
1305 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1306 //
1307 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1308 // ICmpInst::Predicate *Pred,
1309 // ScalarEvolution *SE);
1310};
1311
1312template <>
1313struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1314 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1315
1316 static const GetExtendExprTy GetExtendExpr;
1317
1318 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1319 ICmpInst::Predicate *Pred,
1320 ScalarEvolution *SE) {
1321 return getSignedOverflowLimitForStep(Step, Pred, SE);
1322 }
1323};
1324
1325const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1327
1328template <>
1329struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1330 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1331
1332 static const GetExtendExprTy GetExtendExpr;
1333
1334 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1335 ICmpInst::Predicate *Pred,
1336 ScalarEvolution *SE) {
1337 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1338 }
1339};
1340
1341const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1343
1344} // end anonymous namespace
1345
1346// The recurrence AR has been shown to have no signed/unsigned wrap or something
1347// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1348// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1349// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1350// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1351// expression "Step + sext/zext(PreIncAR)" is congruent with
1352// "sext/zext(PostIncAR)"
1353template <typename ExtendOpTy>
1354static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1355 ScalarEvolution *SE, unsigned Depth) {
1356 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1357 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1358
1359 const Loop *L = AR->getLoop();
1360 const SCEV *Start = AR->getStart();
1361 const SCEV *Step = AR->getStepRecurrence(*SE);
1362
1363 // Check for a simple looking step prior to loop entry.
1364 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1365 if (!SA)
1366 return nullptr;
1367
1368 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1369 // subtraction is expensive. For this purpose, perform a quick and dirty
1370 // difference, by checking for Step in the operand list. Note, that
1371 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1373 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1374 if (*It == Step) {
1375 DiffOps.erase(It);
1376 break;
1377 }
1378
1379 if (DiffOps.size() == SA->getNumOperands())
1380 return nullptr;
1381
1382 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1383 // `Step`:
1384
1385 // 1. NSW/NUW flags on the step increment.
1386 auto PreStartFlags =
1388 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1390 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1391
1392 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1393 // "S+X does not sign/unsign-overflow".
1394 //
1395
1396 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1397 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1398 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1399 return PreStart;
1400
1401 // 2. Direct overflow check on the step operation's expression.
1402 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1403 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1404 const SCEV *OperandExtendedStart =
1405 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1406 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1407 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1408 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1409 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1410 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1411 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1412 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1413 }
1414 return PreStart;
1415 }
1416
1417 // 3. Loop precondition.
1419 const SCEV *OverflowLimit =
1420 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1421
1422 if (OverflowLimit &&
1423 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1424 return PreStart;
1425
1426 return nullptr;
1427}
1428
1429// Get the normalized zero or sign extended expression for this AddRec's Start.
1430template <typename ExtendOpTy>
1431static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1432 ScalarEvolution *SE,
1433 unsigned Depth) {
1434 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1435
1436 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1437 if (!PreStart)
1438 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1439
1440 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1441 Depth),
1442 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1443}
1444
1445// Try to prove away overflow by looking at "nearby" add recurrences. A
1446// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1447// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1448//
1449// Formally:
1450//
1451// {S,+,X} == {S-T,+,X} + T
1452// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1453//
1454// If ({S-T,+,X} + T) does not overflow ... (1)
1455//
1456// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1457//
1458// If {S-T,+,X} does not overflow ... (2)
1459//
1460// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1461// == {Ext(S-T)+Ext(T),+,Ext(X)}
1462//
1463// If (S-T)+T does not overflow ... (3)
1464//
1465// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1466// == {Ext(S),+,Ext(X)} == LHS
1467//
1468// Thus, if (1), (2) and (3) are true for some T, then
1469// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1470//
1471// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1472// does not overflow" restricted to the 0th iteration. Therefore we only need
1473// to check for (1) and (2).
1474//
1475// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1476// is `Delta` (defined below).
1477template <typename ExtendOpTy>
1478bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1479 const SCEV *Step,
1480 const Loop *L) {
1481 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1482
1483 // We restrict `Start` to a constant to prevent SCEV from spending too much
1484 // time here. It is correct (but more expensive) to continue with a
1485 // non-constant `Start` and do a general SCEV subtraction to compute
1486 // `PreStart` below.
1487 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1488 if (!StartC)
1489 return false;
1490
1491 APInt StartAI = StartC->getAPInt();
1492
1493 for (unsigned Delta : {-2, -1, 1, 2}) {
1494 const SCEV *PreStart = getConstant(StartAI - Delta);
1495
1496 FoldingSetNodeID ID;
1497 ID.AddInteger(scAddRecExpr);
1498 ID.AddPointer(PreStart);
1499 ID.AddPointer(Step);
1500 ID.AddPointer(L);
1501 void *IP = nullptr;
1502 const auto *PreAR =
1503 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1504
1505 // Give up if we don't already have the add recurrence we need because
1506 // actually constructing an add recurrence is relatively expensive.
1507 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1508 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1510 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1511 DeltaS, &Pred, this);
1512 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1513 return true;
1514 }
1515 }
1516
1517 return false;
1518}
1519
1520// Finds an integer D for an expression (C + x + y + ...) such that the top
1521// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1522// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1523// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1524// the (C + x + y + ...) expression is \p WholeAddExpr.
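// Illustrative example (added note): with a bit width of 8, C = 58 (0b00111010)
// and remaining operands whose minimum number of trailing zeros is 4, D is the
// low 4 bits of C, i.e. 10. Then (C - D + x + y + ...) = (48 + x + y + ...)
// keeps at least 4 trailing zeros, and adding D < 2^4 back cannot carry into
// those bits, so the top-level addition cannot wrap.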
1526 const SCEVConstant *ConstantTerm,
1527 const SCEVAddExpr *WholeAddExpr) {
1528 const APInt &C = ConstantTerm->getAPInt();
1529 const unsigned BitWidth = C.getBitWidth();
1530 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1531 uint32_t TZ = BitWidth;
1532 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1533 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1534 if (TZ) {
1535 // Set D to be as many least significant bits of C as possible while still
1536 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1537 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1538 }
1539 return APInt(BitWidth, 0);
1540}
1541
1542// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1543// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1544// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1545// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1547 const APInt &ConstantStart,
1548 const SCEV *Step) {
1549 const unsigned BitWidth = ConstantStart.getBitWidth();
1550 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1551 if (TZ)
1552 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1553 : ConstantStart;
1554 return APInt(BitWidth, 0);
1555}
1556
1558 const ScalarEvolution::FoldID &ID, const SCEV *S,
1561 &FoldCacheUser) {
1562 auto I = FoldCache.insert({ID, S});
1563 if (!I.second) {
1564 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1565 // entry.
1566 auto &UserIDs = FoldCacheUser[I.first->second];
1567 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1568 for (unsigned I = 0; I != UserIDs.size(); ++I)
1569 if (UserIDs[I] == ID) {
1570 std::swap(UserIDs[I], UserIDs.back());
1571 break;
1572 }
1573 UserIDs.pop_back();
1574 I.first->second = S;
1575 }
1576 FoldCacheUser[S].push_back(ID);
1577}
1578
1579const SCEV *
1581 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1582 "This is not an extending conversion!");
1583 assert(isSCEVable(Ty) &&
1584 "This is not a conversion to a SCEVable type!");
1585 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1586 Ty = getEffectiveSCEVType(Ty);
1587
1588 FoldID ID(scZeroExtend, Op, Ty);
1589 if (const SCEV *S = FoldCache.lookup(ID))
1590 return S;
1591
1592 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1594 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1595 return S;
1596}
1597
1599 unsigned Depth) {
1600 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1601 "This is not an extending conversion!");
1602 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1603 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1604
1605 // Fold if the operand is constant.
1606 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1607 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1608
1609 // zext(zext(x)) --> zext(x)
1611 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1612
1613 // Before doing any expensive analysis, check to see if we've already
1614 // computed a SCEV for this Op and Ty.
1616 ID.AddInteger(scZeroExtend);
1617 ID.AddPointer(Op);
1618 ID.AddPointer(Ty);
1619 void *IP = nullptr;
1620 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1621 if (Depth > MaxCastDepth) {
1622 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1623 Op, Ty);
1624 UniqueSCEVs.InsertNode(S, IP);
1625 registerUser(S, Op);
1626 return S;
1627 }
1628
1629 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1631 // It's possible the bits taken off by the truncate were all zero bits. If
1632 // so, we should be able to simplify this further.
1633 const SCEV *X = ST->getOperand();
1635 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1636 unsigned NewBits = getTypeSizeInBits(Ty);
1637 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1638 CR.zextOrTrunc(NewBits)))
1639 return getTruncateOrZeroExtend(X, Ty, Depth);
1640 }
1641
1642 // If the input value is a chrec scev, and we can prove that the value
1643 // did not overflow the old, smaller, value, we can zero extend all of the
1644 // operands (often constants). This allows analysis of something like
1645 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1647 if (AR->isAffine()) {
1648 const SCEV *Start = AR->getStart();
1649 const SCEV *Step = AR->getStepRecurrence(*this);
1650 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1651 const Loop *L = AR->getLoop();
1652
1653 // If we have special knowledge that this addrec won't overflow,
1654 // we don't need to do any further analysis.
1655 if (AR->hasNoUnsignedWrap()) {
1656 Start =
1658 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1659 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1660 }
1661
1662 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1663 // Note that this serves two purposes: It filters out loops that are
1664 // simply not analyzable, and it covers the case where this code is
1665 // being called from within backedge-taken count analysis, such that
1666 // attempting to ask for the backedge-taken count would likely result
 1667 // in infinite recursion. In the latter case, the analysis code will
1668 // cope with a conservative value, and it will take care to purge
1669 // that value once it has finished.
1670 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1671 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1672 // Manually compute the final value for AR, checking for overflow.
1673
1674 // Check whether the backedge-taken count can be losslessly casted to
1675 // the addrec's type. The count is always unsigned.
1676 const SCEV *CastedMaxBECount =
1677 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1678 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1679 CastedMaxBECount, MaxBECount->getType(), Depth);
1680 if (MaxBECount == RecastedMaxBECount) {
1681 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1682 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1683 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1684 SCEV::FlagAnyWrap, Depth + 1);
1685 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1686 SCEV::FlagAnyWrap,
1687 Depth + 1),
1688 WideTy, Depth + 1);
1689 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1690 const SCEV *WideMaxBECount =
1691 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1692 const SCEV *OperandExtendedAdd =
1693 getAddExpr(WideStart,
1694 getMulExpr(WideMaxBECount,
1695 getZeroExtendExpr(Step, WideTy, Depth + 1),
1696 SCEV::FlagAnyWrap, Depth + 1),
1697 SCEV::FlagAnyWrap, Depth + 1);
1698 if (ZAdd == OperandExtendedAdd) {
1699 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1700 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1701 // Return the expression with the addrec on the outside.
1702 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1703 Depth + 1);
1704 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1705 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1706 }
1707 // Similar to above, only this time treat the step value as signed.
1708 // This covers loops that count down.
1709 OperandExtendedAdd =
1710 getAddExpr(WideStart,
1711 getMulExpr(WideMaxBECount,
1712 getSignExtendExpr(Step, WideTy, Depth + 1),
1713 SCEV::FlagAnyWrap, Depth + 1),
1714 SCEV::FlagAnyWrap, Depth + 1);
1715 if (ZAdd == OperandExtendedAdd) {
1716 // Cache knowledge of AR NW, which is propagated to this AddRec.
1717 // Negative step causes unsigned wrap, but it still can't self-wrap.
1718 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1719 // Return the expression with the addrec on the outside.
1720 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1721 Depth + 1);
1722 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1723 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1724 }
1725 }
1726 }
1727
1728 // Normally, in the cases we can prove no-overflow via a
1729 // backedge guarding condition, we can also compute a backedge
1730 // taken count for the loop. The exceptions are assumptions and
1731 // guards present in the loop -- SCEV is not great at exploiting
1732 // these to compute max backedge taken counts, but can still use
1733 // these to prove lack of overflow. Use this fact to avoid
1734 // doing extra work that may not pay off.
1735 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1736 !AC.assumptions().empty()) {
1737
1738 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1739 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1740 if (AR->hasNoUnsignedWrap()) {
1741 // Same as nuw case above - duplicated here to avoid a compile time
1742 // issue. It's not clear that the order of checks matters, but it's one
1743 // of two possible causes of an issue with a change which was reverted.
1744 // Be conservative for the moment.
1745 Start =
1746 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1747 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1748 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1749 }
1750
1751 // For a negative step, we can extend the operands iff doing so only
1752 // traverses values in the range zext([0,UINT_MAX]).
1753 if (isKnownNegative(Step)) {
1754 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1755 getSignedRangeMin(Step));
1756 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1757 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1758 // Cache knowledge of AR NW, which is propagated to this
1759 // AddRec. Negative step causes unsigned wrap, but it
1760 // still can't self-wrap.
1761 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1762 // Return the expression with the addrec on the outside.
1763 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1764 Depth + 1);
1765 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1766 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1767 }
1768 }
1769 }
1770
1771 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1772 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1773 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1774 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1775 const APInt &C = SC->getAPInt();
1776 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1777 if (D != 0) {
1778 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1779 const SCEV *SResidual =
1780 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1781 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1782 return getAddExpr(SZExtD, SZExtR,
1783 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1784 Depth + 1);
1785 }
1786 }
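// For instance, zext i8 {6,+,4} to i32 splits with D = 2 (the low two bits of
// the start, matching the step's two trailing zero bits) into
// zext(2) + zext({4,+,4}); adding 2 back can never wrap unsigned because the
// residual always has its low two bits clear.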
1787
1788 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1789 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1790 Start =
1791 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1792 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1793 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1794 }
1795 }
1796
1797 // zext(A % B) --> zext(A) % zext(B)
1798 {
1799 const SCEV *LHS;
1800 const SCEV *RHS;
1801 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1802 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1803 getZeroExtendExpr(RHS, Ty, Depth + 1));
1804 }
1805
1806 // zext(A / B) --> zext(A) / zext(B).
1807 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1808 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1809 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1810
1811 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1812 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1813 if (SA->hasNoUnsignedWrap()) {
1814 // If the addition does not unsign overflow then we can, by definition,
1815 // commute the zero extension with the addition operation.
1816 SmallVector<const SCEV *, 4> Ops;
1817 for (const auto *Op : SA->operands())
1818 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1819 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1820 }
1821
1822 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1823 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1824 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1825 //
1826 // Often address arithmetics contain expressions like
1827 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1828 // This transformation is useful while proving that such expressions are
1829 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1830 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1831 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1832 if (D != 0) {
1833 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1834 const SCEV *SResidual =
1835 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1836 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1837 return getAddExpr(SZExtD, SZExtR,
1838 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1839 Depth + 1);
1840 }
1841 }
1842 }
1843
1844 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1845 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1846 if (SM->hasNoUnsignedWrap()) {
1847 // If the multiply does not unsign overflow then we can, by definition,
1848 // commute the zero extension with the multiply operation.
1849 SmallVector<const SCEV *, 4> Ops;
1850 for (const auto *Op : SM->operands())
1851 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1852 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1853 }
1854
1855 // zext(2^K * (trunc X to iN)) to iM ->
1856 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1857 //
1858 // Proof:
1859 //
1860 // zext(2^K * (trunc X to iN)) to iM
1861 // = zext((trunc X to iN) << K) to iM
1862 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1863 // (because shl removes the top K bits)
1864 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1865 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1866 //
1867 const APInt *C;
1868 const SCEV *TruncRHS;
1869 if (match(SM,
1870 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1871 C->isPowerOf2()) {
1872 int NewTruncBits =
1873 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1874 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1875 return getMulExpr(
1876 getZeroExtendExpr(SM->getOperand(0), Ty),
1877 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1878 SCEV::FlagNUW, Depth + 1);
1879 }
1880 }
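// For instance, zext i16 (8 * (trunc %x to i16)) to i64 has K = 3 and N = 16,
// so it becomes 8 * (zext (trunc %x to i13) to i64): truncating to 13 bits
// first guarantees the multiply by 8 cannot overflow the original 16 bits.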
1881
1882 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1883 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1884 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1885 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1886 SmallVector<const SCEV *, 4> Operands;
1887 for (auto *Operand : MinMax->operands())
1888 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1889 if (isa<SCEVUMinExpr>(MinMax))
1890 return getUMinExpr(Operands);
1891 return getUMaxExpr(Operands);
1892 }
1893
1894 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1895 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1896 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1897 SmallVector<const SCEV *, 4> Operands;
1898 for (auto *Operand : MinMax->operands())
1899 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1900 return getUMinExpr(Operands, /*Sequential*/ true);
1901 }
1902
1903 // The cast wasn't folded; create an explicit cast node.
1904 // Recompute the insert position, as it may have been invalidated.
1905 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1906 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1907 Op, Ty);
1908 UniqueSCEVs.InsertNode(S, IP);
1909 registerUser(S, Op);
1910 return S;
1911}
1912
1913const SCEV *
1914 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1915 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1916 "This is not an extending conversion!");
1917 assert(isSCEVable(Ty) &&
1918 "This is not a conversion to a SCEVable type!");
1919 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1920 Ty = getEffectiveSCEVType(Ty);
1921
1922 FoldID ID(scSignExtend, Op, Ty);
1923 if (const SCEV *S = FoldCache.lookup(ID))
1924 return S;
1925
1926 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1927 if (!isa<SCEVSignExtendExpr>(S))
1928 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1929 return S;
1930}
1931
1932 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1933 unsigned Depth) {
1934 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1935 "This is not an extending conversion!");
1936 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1937 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1938 Ty = getEffectiveSCEVType(Ty);
1939
1940 // Fold if the operand is constant.
1941 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1942 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1943
1944 // sext(sext(x)) --> sext(x)
1945 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1946 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1947
1948 // sext(zext(x)) --> zext(x)
1949 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1950 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1951
1952 // Before doing any expensive analysis, check to see if we've already
1953 // computed a SCEV for this Op and Ty.
1954 FoldingSetNodeID ID;
1955 ID.AddInteger(scSignExtend);
1956 ID.AddPointer(Op);
1957 ID.AddPointer(Ty);
1958 void *IP = nullptr;
1959 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1960 // Limit recursion depth.
1961 if (Depth > MaxCastDepth) {
1962 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1963 Op, Ty);
1964 UniqueSCEVs.InsertNode(S, IP);
1965 registerUser(S, Op);
1966 return S;
1967 }
1968
1969 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1970 if (auto *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1971 // It's possible the bits taken off by the truncate were all sign bits. If
1972 // so, we should be able to simplify this further.
1973 const SCEV *X = ST->getOperand();
1974 ConstantRange CR = getSignedRange(X);
1975 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1976 unsigned NewBits = getTypeSizeInBits(Ty);
1977 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1978 CR.sextOrTrunc(NewBits)))
1979 return getTruncateOrSignExtend(X, Ty, Depth);
1980 }
1981
1982 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1983 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1984 if (SA->hasNoSignedWrap()) {
1985 // If the addition does not sign overflow then we can, by definition,
1986 // commute the sign extension with the addition operation.
1987 SmallVector<const SCEV *, 4> Ops;
1988 for (const auto *Op : SA->operands())
1989 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1990 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1991 }
1992
1993 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1994 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1995 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1996 //
1997 // For instance, this will bring two seemingly different expressions:
1998 // 1 + sext(5 + 20 * %x + 24 * %y) and
1999 // sext(6 + 20 * %x + 24 * %y)
2000 // to the same form:
2001 // 2 + sext(4 + 20 * %x + 24 * %y)
2002 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2003 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2004 if (D != 0) {
2005 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2006 const SCEV *SResidual =
2007 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
2008 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2009 return getAddExpr(SSExtD, SSExtR,
2010 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2011 Depth + 1);
2012 }
2013 }
2014 }
2015 // If the input value is a chrec scev, and we can prove that the value
2016 // did not overflow the old, smaller, value, we can sign extend all of the
2017 // operands (often constants). This allows analysis of something like
2018 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2019 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(Op))
2020 if (AR->isAffine()) {
2021 const SCEV *Start = AR->getStart();
2022 const SCEV *Step = AR->getStepRecurrence(*this);
2023 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2024 const Loop *L = AR->getLoop();
2025
2026 // If we have special knowledge that this addrec won't overflow,
2027 // we don't need to do any further analysis.
2028 if (AR->hasNoSignedWrap()) {
2029 Start =
2030 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2031 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2032 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2033 }
2034
2035 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2036 // Note that this serves two purposes: It filters out loops that are
2037 // simply not analyzable, and it covers the case where this code is
2038 // being called from within backedge-taken count analysis, such that
2039 // attempting to ask for the backedge-taken count would likely result
2040 // in infinite recursion. In the latter case, the analysis code will
2041 // cope with a conservative value, and it will take care to purge
2042 // that value once it has finished.
2043 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2044 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2045 // Manually compute the final value for AR, checking for
2046 // overflow.
2047
2048 // Check whether the backedge-taken count can be losslessly casted to
2049 // the addrec's type. The count is always unsigned.
2050 const SCEV *CastedMaxBECount =
2051 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2052 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2053 CastedMaxBECount, MaxBECount->getType(), Depth);
2054 if (MaxBECount == RecastedMaxBECount) {
2055 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2056 // Check whether Start+Step*MaxBECount has no signed overflow.
2057 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2058 SCEV::FlagAnyWrap, Depth + 1);
2059 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2060 SCEV::FlagAnyWrap,
2061 Depth + 1),
2062 WideTy, Depth + 1);
2063 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2064 const SCEV *WideMaxBECount =
2065 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2066 const SCEV *OperandExtendedAdd =
2067 getAddExpr(WideStart,
2068 getMulExpr(WideMaxBECount,
2069 getSignExtendExpr(Step, WideTy, Depth + 1),
2070 SCEV::FlagAnyWrap, Depth + 1),
2071 SCEV::FlagAnyWrap, Depth + 1);
2072 if (SAdd == OperandExtendedAdd) {
2073 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2074 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2075 // Return the expression with the addrec on the outside.
2076 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2077 Depth + 1);
2078 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2079 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2080 }
2081 // Similar to above, only this time treat the step value as unsigned.
2082 // This covers loops that count up with an unsigned step.
2083 OperandExtendedAdd =
2084 getAddExpr(WideStart,
2085 getMulExpr(WideMaxBECount,
2086 getZeroExtendExpr(Step, WideTy, Depth + 1),
2087 SCEV::FlagAnyWrap, Depth + 1),
2088 SCEV::FlagAnyWrap, Depth + 1);
2089 if (SAdd == OperandExtendedAdd) {
2090 // If AR wraps around then
2091 //
2092 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2093 // => SAdd != OperandExtendedAdd
2094 //
2095 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2096 // (SAdd == OperandExtendedAdd => AR is NW)
2097
2098 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2099
2100 // Return the expression with the addrec on the outside.
2101 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2102 Depth + 1);
2103 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2104 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2105 }
2106 }
2107 }
2108
2109 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2110 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2111 if (AR->hasNoSignedWrap()) {
2112 // Same as nsw case above - duplicated here to avoid a compile time
2113 // issue. It's not clear that the order of checks matters, but it's one
2114 // of two possible causes of an issue with a change which was reverted.
2115 // Be conservative for the moment.
2116 Start =
2117 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2118 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2119 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2120 }
2121
2122 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2123 // if D + (C - D + Step * n) could be proven to not signed wrap
2124 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2125 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2126 const APInt &C = SC->getAPInt();
2127 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2128 if (D != 0) {
2129 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2130 const SCEV *SResidual =
2131 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2132 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2133 return getAddExpr(SSExtD, SSExtR,
2134 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2135 Depth + 1);
2136 }
2137 }
2138
2139 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2140 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2141 Start =
2142 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2143 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2144 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2145 }
2146 }
2147
2148 // If the input value is provably positive and we could not simplify
2149 // away the sext build a zext instead.
2150 if (isKnownNonNegative(Op))
2151 return getZeroExtendExpr(Op, Ty, Depth + 1);
2152
2153 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2154 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2155 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2156 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2157 SmallVector<const SCEV *, 4> Operands;
2158 for (auto *Operand : MinMax->operands())
2159 Operands.push_back(getSignExtendExpr(Operand, Ty));
2160 if (isa<SCEVSMinExpr>(MinMax))
2161 return getSMinExpr(Operands);
2162 return getSMaxExpr(Operands);
2163 }
2164
2165 // The cast wasn't folded; create an explicit cast node.
2166 // Recompute the insert position, as it may have been invalidated.
2167 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2168 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2169 Op, Ty);
2170 UniqueSCEVs.InsertNode(S, IP);
2171 registerUser(S, { Op });
2172 return S;
2173}
2174
2175 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2176 Type *Ty) {
2177 switch (Kind) {
2178 case scTruncate:
2179 return getTruncateExpr(Op, Ty);
2180 case scZeroExtend:
2181 return getZeroExtendExpr(Op, Ty);
2182 case scSignExtend:
2183 return getSignExtendExpr(Op, Ty);
2184 case scPtrToInt:
2185 return getPtrToIntExpr(Op, Ty);
2186 default:
2187 llvm_unreachable("Not a SCEV cast expression!");
2188 }
2189}
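// Illustrative use (assuming SE is this analysis and Cast/NewOp are in scope):
//   const SCEV *Recast = SE.getCastExpr(Cast->getSCEVType(), NewOp,
//                                       Cast->getType());
// re-creates the same kind of cast that Cast performs, but on NewOp.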
2190
2191/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2192/// unspecified bits out to the given type.
2193 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2194 Type *Ty) {
2195 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2196 "This is not an extending conversion!");
2197 assert(isSCEVable(Ty) &&
2198 "This is not a conversion to a SCEVable type!");
2199 Ty = getEffectiveSCEVType(Ty);
2200
2201 // Sign-extend negative constants.
2202 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2203 if (SC->getAPInt().isNegative())
2204 return getSignExtendExpr(Op, Ty);
2205
2206 // Peel off a truncate cast.
2207 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2208 const SCEV *NewOp = T->getOperand();
2209 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2210 return getAnyExtendExpr(NewOp, Ty);
2211 return getTruncateOrNoop(NewOp, Ty);
2212 }
2213
2214 // Next try a zext cast. If the cast is folded, use it.
2215 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2216 if (!isa<SCEVZeroExtendExpr>(ZExt))
2217 return ZExt;
2218
2219 // Next try a sext cast. If the cast is folded, use it.
2220 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2221 if (!isa<SCEVSignExtendExpr>(SExt))
2222 return SExt;
2223
2224 // Force the cast to be folded into the operands of an addrec.
2225 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2226 SmallVector<const SCEV *, 4> Ops;
2227 for (const SCEV *Op : AR->operands())
2228 Ops.push_back(getAnyExtendExpr(Op, Ty));
2229 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2230 }
2231
2232 // If the expression is obviously signed, use the sext cast value.
2233 if (isa<SCEVSMaxExpr>(Op))
2234 return SExt;
2235
2236 // Absent any other information, use the zext cast value.
2237 return ZExt;
2238}
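// For instance, any-extending an i8 addrec {0,+,-1} to i32 when neither the
// zext nor the sext folds cleanly yields the i32 addrec {0,+,-1} with only the
// <nw> flag: the extension is pushed into the operands above, and the negative
// step constant is sign-extended.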
2239
2240/// Process the given Ops list, which is a list of operands to be added under
2241/// the given scale, update the given map. This is a helper function for
2242/// getAddRecExpr. As an example of what it does, given a sequence of operands
2243/// that would form an add expression like this:
2244///
2245/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2246///
2247/// where A and B are constants, update the map with these values:
2248///
2249/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2250///
2251/// and add 13 + A*B*29 to AccumulatedConstant.
2252/// This will allow getAddRecExpr to produce this:
2253///
2254/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2255///
2256/// This form often exposes folding opportunities that are hidden in
2257/// the original operand list.
2258///
2259/// Return true iff it appears that any interesting folding opportunities
2260/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2261/// the common case where no interesting opportunities are present, and
2262/// is also used as a check to avoid infinite recursion.
2263static bool
2264 CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2265 SmallVectorImpl<const SCEV *> &NewOps,
2266 APInt &AccumulatedConstant,
2267 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2268 ScalarEvolution &SE) {
2269 bool Interesting = false;
2270
2271 // Iterate over the add operands. They are sorted, with constants first.
2272 unsigned i = 0;
2273 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2274 ++i;
2275 // Pull a buried constant out to the outside.
2276 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2277 Interesting = true;
2278 AccumulatedConstant += Scale * C->getAPInt();
2279 }
2280
2281 // Next comes everything else. We're especially interested in multiplies
2282 // here, but they're in the middle, so just visit the rest with one loop.
2283 for (; i != Ops.size(); ++i) {
2284 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2285 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2286 APInt NewScale =
2287 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2288 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2289 // A multiplication of a constant with another add; recurse.
2290 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2291 Interesting |=
2292 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2293 Add->operands(), NewScale, SE);
2294 } else {
2295 // A multiplication of a constant with some other value. Update
2296 // the map.
2297 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2298 const SCEV *Key = SE.getMulExpr(MulOps);
2299 auto Pair = M.insert({Key, NewScale});
2300 if (Pair.second) {
2301 NewOps.push_back(Pair.first->first);
2302 } else {
2303 Pair.first->second += NewScale;
2304 // The map already had an entry for this value, which may indicate
2305 // a folding opportunity.
2306 Interesting = true;
2307 }
2308 }
2309 } else {
2310 // An ordinary operand. Update the map.
2311 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2312 M.insert({Ops[i], Scale});
2313 if (Pair.second) {
2314 NewOps.push_back(Pair.first->first);
2315 } else {
2316 Pair.first->second += Scale;
2317 // The map already had an entry for this value, which may indicate
2318 // a folding opportunity.
2319 Interesting = true;
2320 }
2321 }
2322 }
2323
2324 return Interesting;
2325}
2326
2327 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2328 const SCEV *LHS, const SCEV *RHS,
2329 const Instruction *CtxI) {
2330 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2331 SCEV::NoWrapFlags, unsigned);
2332 switch (BinOp) {
2333 default:
2334 llvm_unreachable("Unsupported binary op");
2335 case Instruction::Add:
2336 Operation = &ScalarEvolution::getAddExpr;
2337 break;
2338 case Instruction::Sub:
2339 Operation = &ScalarEvolution::getMinusSCEV;
2340 break;
2341 case Instruction::Mul:
2342 Operation = &ScalarEvolution::getMulExpr;
2343 break;
2344 }
2345
2346 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2347 Signed ? &ScalarEvolution::getSignExtendExpr
2348 : &ScalarEvolution::getZeroExtendExpr;
2349
2350 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2351 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2352 auto *WideTy =
2353 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2354
2355 const SCEV *A = (this->*Extension)(
2356 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2357 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2358 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2359 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2360 if (A == B)
2361 return true;
2362 // Can we use context to prove the fact we need?
2363 if (!CtxI)
2364 return false;
2365 // TODO: Support mul.
2366 if (BinOp == Instruction::Mul)
2367 return false;
2368 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2369 // TODO: Lift this limitation.
2370 if (!RHSC)
2371 return false;
2372 APInt C = RHSC->getAPInt();
2373 unsigned NumBits = C.getBitWidth();
2374 bool IsSub = (BinOp == Instruction::Sub);
2375 bool IsNegativeConst = (Signed && C.isNegative());
2376 // Compute the direction and magnitude by which we need to check overflow.
2377 bool OverflowDown = IsSub ^ IsNegativeConst;
2378 APInt Magnitude = C;
2379 if (IsNegativeConst) {
2380 if (C == APInt::getSignedMinValue(NumBits))
2381 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2382 // want to deal with that.
2383 return false;
2384 Magnitude = -C;
2385 }
2386
2387 ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2388 if (OverflowDown) {
2389 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2390 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2391 : APInt::getMinValue(NumBits);
2392 APInt Limit = Min + Magnitude;
2393 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2394 } else {
2395 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2396 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2397 : APInt::getMaxValue(NumBits);
2398 APInt Limit = Max - Magnitude;
2399 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2400 }
2401}
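// Worked instance of the limit check above: proving that the i32 subtraction
// %x - 7 cannot wrap unsigned requires %x >= 0 + 7, so the query issued is
// whether 7 u<= %x holds at the context instruction.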
2402
2403std::optional<SCEV::NoWrapFlags>
2404 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2405 const OverflowingBinaryOperator *OBO) {
2406 // It cannot be done any better.
2407 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2408 return std::nullopt;
2409
2410 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2411
2412 if (OBO->hasNoUnsignedWrap())
2413 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2414 if (OBO->hasNoSignedWrap())
2415 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2416
2417 bool Deduced = false;
2418
2419 if (OBO->getOpcode() != Instruction::Add &&
2420 OBO->getOpcode() != Instruction::Sub &&
2421 OBO->getOpcode() != Instruction::Mul)
2422 return std::nullopt;
2423
2424 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2425 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2426
2427 const Instruction *CtxI =
2428 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2429 if (!OBO->hasNoUnsignedWrap() &&
2430 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2431 /* Signed */ false, LHS, RHS, CtxI)) {
2432 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2433 Deduced = true;
2434 }
2435
2436 if (!OBO->hasNoSignedWrap() &&
2437 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2438 /* Signed */ true, LHS, RHS, CtxI)) {
2439 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2440 Deduced = true;
2441 }
2442
2443 if (Deduced)
2444 return Flags;
2445 return std::nullopt;
2446}
2447
2448// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2449// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2450// can't-overflow flags for the operation if possible.
2451 static SCEV::NoWrapFlags
2452 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2453 const ArrayRef<const SCEV *> Ops,
2454 SCEV::NoWrapFlags Flags) {
2455 using namespace std::placeholders;
2456
2457 using OBO = OverflowingBinaryOperator;
2458
2459 bool CanAnalyze =
2460 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2461 (void)CanAnalyze;
2462 assert(CanAnalyze && "don't call from other places!");
2463
2464 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2465 SCEV::NoWrapFlags SignOrUnsignWrap =
2466 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2467
2468 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2469 auto IsKnownNonNegative = [&](const SCEV *S) {
2470 return SE->isKnownNonNegative(S);
2471 };
2472
2473 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2474 Flags =
2475 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2476
2477 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2478
2479 if (SignOrUnsignWrap != SignOrUnsignMask &&
2480 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2481 isa<SCEVConstant>(Ops[0])) {
2482
2483 auto Opcode = [&] {
2484 switch (Type) {
2485 case scAddExpr:
2486 return Instruction::Add;
2487 case scMulExpr:
2488 return Instruction::Mul;
2489 default:
2490 llvm_unreachable("Unexpected SCEV op.");
2491 }
2492 }();
2493
2494 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2495
2496 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2497 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2498 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2499 Opcode, C, OBO::NoSignedWrap);
2500 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2501 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2502 }
2503
2504 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2505 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2506 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2507 Opcode, C, OBO::NoUnsignedWrap);
2508 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2509 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2510 }
2511 }
2512
2513 // <0,+,nonnegative><nw> is also nuw
2514 // TODO: Add corresponding nsw case
2515 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2516 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2517 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2518 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2519
2520 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2521 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2522 Ops.size() == 2) {
2523 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2524 if (UDiv->getOperand(1) == Ops[1])
2525 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2526 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2527 if (UDiv->getOperand(1) == Ops[0])
2528 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2529 }
2530
2531 return Flags;
2532}
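// For instance, an add (%n + 4)<nsw> whose operands are all known non-negative
// is also marked <nuw> above: non-negative operands whose signed sum does not
// overflow stay within [0, SINT_MAX], so the unsigned sum cannot wrap either.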
2533
2534 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2535 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2536}
2537
2538/// Get a canonical add expression, or something simpler if possible.
2539 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2540 SCEV::NoWrapFlags OrigFlags,
2541 unsigned Depth) {
2542 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2543 "only nuw or nsw allowed");
2544 assert(!Ops.empty() && "Cannot get empty add!");
2545 if (Ops.size() == 1) return Ops[0];
2546#ifndef NDEBUG
2547 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2548 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2549 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2550 "SCEVAddExpr operand types don't match!");
2551 unsigned NumPtrs = count_if(
2552 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2553 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2554#endif
2555
2556 const SCEV *Folded = constantFoldAndGroupOps(
2557 *this, LI, DT, Ops,
2558 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2559 [](const APInt &C) { return C.isZero(); }, // identity
2560 [](const APInt &C) { return false; }); // absorber
2561 if (Folded)
2562 return Folded;
2563
2564 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2565
2566 // Delay expensive flag strengthening until necessary.
2567 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2568 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2569 };
2570
2571 // Limit recursion calls depth.
2572 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2573 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2574
2575 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2576 // Don't strengthen flags if we have no new information.
2577 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2578 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2579 Add->setNoWrapFlags(ComputeFlags(Ops));
2580 return S;
2581 }
2582
2583 // Okay, check to see if the same value occurs in the operand list more than
2584 // once. If so, merge them together into a multiply expression. Since we
2585 // sorted the list, these values are required to be adjacent.
2586 Type *Ty = Ops[0]->getType();
2587 bool FoundMatch = false;
2588 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2589 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2590 // Scan ahead to count how many equal operands there are.
2591 unsigned Count = 2;
2592 while (i+Count != e && Ops[i+Count] == Ops[i])
2593 ++Count;
2594 // Merge the values into a multiply.
2595 const SCEV *Scale = getConstant(Ty, Count);
2596 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2597 if (Ops.size() == Count)
2598 return Mul;
2599 Ops[i] = Mul;
2600 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2601 --i; e -= Count - 1;
2602 FoundMatch = true;
2603 }
2604 if (FoundMatch)
2605 return getAddExpr(Ops, OrigFlags, Depth + 1);
2606
2607 // Check for truncates. If all the operands are truncated from the same
2608 // type, see if factoring out the truncate would permit the result to be
2609 // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2610 // if the contents of the resulting outer trunc fold to something simple.
2611 auto FindTruncSrcType = [&]() -> Type * {
2612 // We're ultimately looking to fold an addrec of truncs and muls of only
2613 // constants and truncs, so if we find any other types of SCEV
2614 // as operands of the addrec then we bail and return nullptr here.
2615 // Otherwise, we return the type of the operand of a trunc that we find.
2616 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2617 return T->getOperand()->getType();
2618 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2619 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2620 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2621 return T->getOperand()->getType();
2622 }
2623 return nullptr;
2624 };
2625 if (auto *SrcType = FindTruncSrcType()) {
2626 SmallVector<const SCEV *, 8> LargeOps;
2627 bool Ok = true;
2628 // Check all the operands to see if they can be represented in the
2629 // source type of the truncate.
2630 for (const SCEV *Op : Ops) {
2631 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2632 if (T->getOperand()->getType() != SrcType) {
2633 Ok = false;
2634 break;
2635 }
2636 LargeOps.push_back(T->getOperand());
2637 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2638 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2639 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2640 SmallVector<const SCEV *, 8> LargeMulOps;
2641 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2642 if (const SCEVTruncateExpr *T =
2643 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2644 if (T->getOperand()->getType() != SrcType) {
2645 Ok = false;
2646 break;
2647 }
2648 LargeMulOps.push_back(T->getOperand());
2649 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2650 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2651 } else {
2652 Ok = false;
2653 break;
2654 }
2655 }
2656 if (Ok)
2657 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2658 } else {
2659 Ok = false;
2660 break;
2661 }
2662 }
2663 if (Ok) {
2664 // Evaluate the expression in the larger type.
2665 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2666 // If it folds to something simple, use it. Otherwise, don't.
2667 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2668 return getTruncateExpr(Fold, Ty);
2669 }
2670 }
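// For instance, trunc(%x to i32) + trunc(%y to i32) is re-evaluated as an i64
// add of %x and %y; only if that wide sum folds to a constant or a single
// unknown is the result re-expressed as one outer truncate.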
2671
2672 if (Ops.size() == 2) {
2673 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2674 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2675 // C1).
2676 const SCEV *A = Ops[0];
2677 const SCEV *B = Ops[1];
2678 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2679 auto *C = dyn_cast<SCEVConstant>(A);
2680 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2681 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2682 auto C2 = C->getAPInt();
2683 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2684
2685 APInt ConstAdd = C1 + C2;
2686 auto AddFlags = AddExpr->getNoWrapFlags();
2687 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2688 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2689 ConstAdd.ule(C1)) {
2690 PreservedFlags =
2691 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2692 }
2693
2694 // Adding a constant with the same sign and small magnitude is NSW, if the
2695 // original AddExpr was NSW.
2696 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2697 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2698 ConstAdd.abs().ule(C1.abs())) {
2699 PreservedFlags =
2700 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2701 }
2702
2703 if (PreservedFlags != SCEV::FlagAnyWrap) {
2704 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2705 NewOps[0] = getConstant(ConstAdd);
2706 return getAddExpr(NewOps, PreservedFlags);
2707 }
2708 }
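// For instance, (-3) + (%x + 10)<nuw> folds the two constants to 7, and since
// 7 u<= 10 the <nuw> of the inner add is preserved on the result (%x + 7)<nuw>.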
2709
2710 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2711 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2712 const SCEVAddExpr *InnerAdd;
2713 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2714 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2715 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2716 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2717 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2718 SCEV::FlagAnyWrap),
2719 SCEV::FlagNUW)) {
2720 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2721 }
2722 }
2723 }
2724
2725 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2726 const SCEV *Y;
2727 if (Ops.size() == 2 &&
2728 match(Ops[0],
2729 m_scev_Mul(m_scev_AllOnes(),
2730 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2731 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2732
2733 // Skip past any other cast SCEVs.
2734 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2735 ++Idx;
2736
2737 // If there are add operands they would be next.
2738 if (Idx < Ops.size()) {
2739 bool DeletedAdd = false;
2740 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2741 // common NUW flag for the expression after inlining. Other flags cannot be
2742 // preserved, because they may depend on the original order of operations.
2743 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2744 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2745 if (Ops.size() > AddOpsInlineThreshold ||
2746 Add->getNumOperands() > AddOpsInlineThreshold)
2747 break;
2748 // If we have an add, expand the add operands onto the end of the operands
2749 // list.
2750 Ops.erase(Ops.begin()+Idx);
2751 append_range(Ops, Add->operands());
2752 DeletedAdd = true;
2753 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2754 }
2755
2756 // If we deleted at least one add, we added operands to the end of the list,
2757 // and they are not necessarily sorted. Recurse to resort and resimplify
2758 // any operands we just acquired.
2759 if (DeletedAdd)
2760 return getAddExpr(Ops, CommonFlags, Depth + 1);
2761 }
2762
2763 // Skip over the add expression until we get to a multiply.
2764 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2765 ++Idx;
2766
2767 // Check to see if there are any folding opportunities present with
2768 // operands multiplied by constant values.
2769 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2770 uint64_t BitWidth = getTypeSizeInBits(Ty);
2771 SmallDenseMap<const SCEV *, APInt, 16> M;
2772 SmallVector<const SCEV *, 8> NewOps;
2773 APInt AccumulatedConstant(BitWidth, 0);
2774 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2775 Ops, APInt(BitWidth, 1), *this)) {
2776 struct APIntCompare {
2777 bool operator()(const APInt &LHS, const APInt &RHS) const {
2778 return LHS.ult(RHS);
2779 }
2780 };
2781
2782 // Some interesting folding opportunity is present, so it's worthwhile to
2783 // re-generate the operands list. Group the operands by constant scale,
2784 // to avoid multiplying by the same constant scale multiple times.
2785 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2786 for (const SCEV *NewOp : NewOps)
2787 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2788 // Re-generate the operands list.
2789 Ops.clear();
2790 if (AccumulatedConstant != 0)
2791 Ops.push_back(getConstant(AccumulatedConstant));
2792 for (auto &MulOp : MulOpLists) {
2793 if (MulOp.first == 1) {
2794 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2795 } else if (MulOp.first != 0) {
2796 Ops.push_back(getMulExpr(
2797 getConstant(MulOp.first),
2798 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2799 SCEV::FlagAnyWrap, Depth + 1));
2800 }
2801 }
2802 if (Ops.empty())
2803 return getZero(Ty);
2804 if (Ops.size() == 1)
2805 return Ops[0];
2806 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2807 }
2808 }
2809
2810 // If we are adding something to a multiply expression, make sure the
2811 // something is not already an operand of the multiply. If so, merge it into
2812 // the multiply.
2813 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2814 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2815 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2816 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2817 if (isa<SCEVConstant>(MulOpSCEV))
2818 continue;
2819 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2820 if (MulOpSCEV == Ops[AddOp]) {
2821 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2822 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2823 if (Mul->getNumOperands() != 2) {
2824 // If the multiply has more than two operands, we must get the
2825 // Y*Z term.
2826 SmallVector<const SCEV *, 4> MulOps(
2827 Mul->operands().take_front(MulOp));
2828 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2829 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2830 }
2831 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2832 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2833 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2834 SCEV::FlagAnyWrap, Depth + 1);
2835 if (Ops.size() == 2) return OuterMul;
2836 if (AddOp < Idx) {
2837 Ops.erase(Ops.begin()+AddOp);
2838 Ops.erase(Ops.begin()+Idx-1);
2839 } else {
2840 Ops.erase(Ops.begin()+Idx);
2841 Ops.erase(Ops.begin()+AddOp-1);
2842 }
2843 Ops.push_back(OuterMul);
2844 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2845 }
2846
2847 // Check this multiply against other multiplies being added together.
2848 for (unsigned OtherMulIdx = Idx+1;
2849 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2850 ++OtherMulIdx) {
2851 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2852 // If MulOp occurs in OtherMul, we can fold the two multiplies
2853 // together.
2854 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2855 OMulOp != e; ++OMulOp)
2856 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2857 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2858 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2859 if (Mul->getNumOperands() != 2) {
2860 SmallVector<const SCEV *, 4> MulOps(
2861 Mul->operands().take_front(MulOp));
2862 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2863 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2864 }
2865 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2866 if (OtherMul->getNumOperands() != 2) {
2867 SmallVector<const SCEV *, 4> MulOps(
2868 OtherMul->operands().take_front(OMulOp));
2869 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2870 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2871 }
2872 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2873 const SCEV *InnerMulSum =
2874 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2875 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2876 SCEV::FlagAnyWrap, Depth + 1);
2877 if (Ops.size() == 2) return OuterMul;
2878 Ops.erase(Ops.begin()+Idx);
2879 Ops.erase(Ops.begin()+OtherMulIdx-1);
2880 Ops.push_back(OuterMul);
2881 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2882 }
2883 }
2884 }
2885 }
2886
2887 // If there are any add recurrences in the operands list, see if any other
2888 // added values are loop invariant. If so, we can fold them into the
2889 // recurrence.
2890 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2891 ++Idx;
2892
2893 // Scan over all recurrences, trying to fold loop invariants into them.
2894 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2895 // Scan all of the other operands to this add and add them to the vector if
2896 // they are loop invariant w.r.t. the recurrence.
2897 SmallVector<const SCEV *, 8> LIOps;
2898 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2899 const Loop *AddRecLoop = AddRec->getLoop();
2900 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2901 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2902 LIOps.push_back(Ops[i]);
2903 Ops.erase(Ops.begin()+i);
2904 --i; --e;
2905 }
2906
2907 // If we found some loop invariants, fold them into the recurrence.
2908 if (!LIOps.empty()) {
2909 // Compute nowrap flags for the addition of the loop-invariant ops and
2910 // the addrec. Temporarily push it as an operand for that purpose. These
2911 // flags are valid in the scope of the addrec only.
2912 LIOps.push_back(AddRec);
2913 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2914 LIOps.pop_back();
2915
2916 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2917 LIOps.push_back(AddRec->getStart());
2918
2919 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2920
2921 // It is not in general safe to propagate flags valid on an add within
2922 // the addrec scope to one outside it. We must prove that the inner
2923 // scope is guaranteed to execute if the outer one does to be able to
2924 // safely propagate. We know the program is undefined if poison is
2925 // produced on the inner scoped addrec. We also know that *for this use*
2926 // the outer scoped add can't overflow (because of the flags we just
2927 // computed for the inner scoped add) without the program being undefined.
2928 // Proving that entry to the outer scope necessitates entry to the inner
2929 // scope thus proves the program undefined if the flags would be violated
2930 // in the outer scope.
2931 SCEV::NoWrapFlags AddFlags = Flags;
2932 if (AddFlags != SCEV::FlagAnyWrap) {
2933 auto *DefI = getDefiningScopeBound(LIOps);
2934 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2935 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2936 AddFlags = SCEV::FlagAnyWrap;
2937 }
2938 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2939
2940 // Build the new addrec. Propagate the NUW and NSW flags if both the
2941 // outer add and the inner addrec are guaranteed to have no overflow.
2942 // Always propagate NW.
2943 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2944 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2945
2946 // If all of the other operands were loop invariant, we are done.
2947 if (Ops.size() == 1) return NewRec;
2948
2949 // Otherwise, add the folded AddRec by the non-invariant parts.
2950 for (unsigned i = 0;; ++i)
2951 if (Ops[i] == AddRec) {
2952 Ops[i] = NewRec;
2953 break;
2954 }
2955 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2956 }
2957
2958 // Okay, if there weren't any loop invariants to be folded, check to see if
2959 // there are multiple AddRec's with the same loop induction variable being
2960 // added together. If so, we can fold them.
2961 for (unsigned OtherIdx = Idx+1;
2962 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2963 ++OtherIdx) {
2964 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2965 // so that the 1st found AddRecExpr is dominated by all others.
2966 assert(DT.dominates(
2967 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2968 AddRec->getLoop()->getHeader()) &&
2969 "AddRecExprs are not sorted in reverse dominance order?");
2970 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2971 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2972 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2973 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2974 ++OtherIdx) {
2975 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2976 if (OtherAddRec->getLoop() == AddRecLoop) {
2977 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2978 i != e; ++i) {
2979 if (i >= AddRecOps.size()) {
2980 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2981 break;
2982 }
2983 SmallVector<const SCEV *, 2> TwoOps = {
2984 AddRecOps[i], OtherAddRec->getOperand(i)};
2985 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2986 }
2987 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2988 }
2989 }
2990 // Step size has changed, so we cannot guarantee no self-wraparound.
2991 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2992 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2993 }
2994 }
2995
2996 // Otherwise couldn't fold anything into this recurrence. Move onto the
2997 // next one.
2998 }
2999
3000 // Okay, it looks like we really DO need an add expr. Check to see if we
3001 // already have one, otherwise create a new one.
3002 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3003}
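// Illustrative call (assuming SE is this analysis and A, B are Values with
// SCEV-able integer types):
//   const SCEV *Sum = SE.getAddExpr(SE.getSCEV(A), SE.getSCEV(B));
// The two-operand overload builds the operand list and forwards to the routine
// above with default no-wrap flags and depth 0.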
3004
3005const SCEV *
3006ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
3007 SCEV::NoWrapFlags Flags) {
3008 FoldingSetNodeID ID;
3009 ID.AddInteger(scAddExpr);
3010 for (const SCEV *Op : Ops)
3011 ID.AddPointer(Op);
3012 void *IP = nullptr;
3013 SCEVAddExpr *S =
3014 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3015 if (!S) {
3016 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3017 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3018 S = new (SCEVAllocator)
3019 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3020 UniqueSCEVs.InsertNode(S, IP);
3021 registerUser(S, Ops);
3022 }
3023 S->setNoWrapFlags(Flags);
3024 return S;
3025}
3026
3027const SCEV *
3028ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3029 const Loop *L, SCEV::NoWrapFlags Flags) {
3030 FoldingSetNodeID ID;
3031 ID.AddInteger(scAddRecExpr);
3032 for (const SCEV *Op : Ops)
3033 ID.AddPointer(Op);
3034 ID.AddPointer(L);
3035 void *IP = nullptr;
3036 SCEVAddRecExpr *S =
3037 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3038 if (!S) {
3039 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3040 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3041 S = new (SCEVAllocator)
3042 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3043 UniqueSCEVs.InsertNode(S, IP);
3044 LoopUsers[L].push_back(S);
3045 registerUser(S, Ops);
3046 }
3047 setNoWrapFlags(S, Flags);
3048 return S;
3049}
3050
3051const SCEV *
3052ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3053 SCEV::NoWrapFlags Flags) {
3054 FoldingSetNodeID ID;
3055 ID.AddInteger(scMulExpr);
3056 for (const SCEV *Op : Ops)
3057 ID.AddPointer(Op);
3058 void *IP = nullptr;
3059 SCEVMulExpr *S =
3060 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3061 if (!S) {
3062 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3063 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3064 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3065 O, Ops.size());
3066 UniqueSCEVs.InsertNode(S, IP);
3067 registerUser(S, Ops);
3068 }
3069 S->setNoWrapFlags(Flags);
3070 return S;
3071}
3072
3073static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3074 uint64_t k = i*j;
3075 if (j > 1 && k / j != i) Overflow = true;
3076 return k;
3077}
3078
3079/// Compute the result of "n choose k", the binomial coefficient. If an
3080/// intermediate computation overflows, Overflow will be set and the return will
3081/// be garbage. Overflow is not cleared on absence of overflow.
3082static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3083 // We use the multiplicative formula:
3084 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3085 // At each iteration, we take the n-th term of the numerator and divide by the
3086 // (k-n)th term of the denominator. This division will always produce an
3087 // integral result, and helps reduce the chance of overflow in the
3088 // intermediate computations. However, we can still overflow even when the
3089 // final result would fit.
3090
3091 if (n == 0 || n == k) return 1;
3092 if (k > n) return 0;
3093
3094 if (k > n/2)
3095 k = n-k;
3096
3097 uint64_t r = 1;
3098 for (uint64_t i = 1; i <= k; ++i) {
3099 r = umul_ov(r, n-(i-1), Overflow);
3100 r /= i;
3101 }
3102 return r;
3103}
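// For example, Choose(6, 2, Overflow) evaluates 6*5/2 = 15 without setting
// Overflow.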
3104
3105/// Determine if any of the operands in this SCEV are a constant or if
3106/// any of the add or multiply expressions in this SCEV contain a constant.
3107static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3108 struct FindConstantInAddMulChain {
3109 bool FoundConstant = false;
3110
3111 bool follow(const SCEV *S) {
3112 FoundConstant |= isa<SCEVConstant>(S);
3113 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3114 }
3115
3116 bool isDone() const {
3117 return FoundConstant;
3118 }
3119 };
3120
3121 FindConstantInAddMulChain F;
3122 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3123 ST.visitAll(StartExpr);
3124 return F.FoundConstant;
3125}
3126
3127/// Get a canonical multiply expression, or something simpler if possible.
3128 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3129 SCEV::NoWrapFlags OrigFlags,
3130 unsigned Depth) {
3131 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3132 "only nuw or nsw allowed");
3133 assert(!Ops.empty() && "Cannot get empty mul!");
3134 if (Ops.size() == 1) return Ops[0];
3135#ifndef NDEBUG
3136 Type *ETy = Ops[0]->getType();
3137 assert(!ETy->isPointerTy());
3138 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3139 assert(Ops[i]->getType() == ETy &&
3140 "SCEVMulExpr operand types don't match!");
3141#endif
3142
3143 const SCEV *Folded = constantFoldAndGroupOps(
3144 *this, LI, DT, Ops,
3145 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3146 [](const APInt &C) { return C.isOne(); }, // identity
3147 [](const APInt &C) { return C.isZero(); }); // absorber
3148 if (Folded)
3149 return Folded;
3150
3151 // Delay expensive flag strengthening until necessary.
3152 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3153 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3154 };
3155
3156 // Limit recursion calls depth.
3157 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3158 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3159
3160 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3161 // Don't strengthen flags if we have no new information.
3162 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3163 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3164 Mul->setNoWrapFlags(ComputeFlags(Ops));
3165 return S;
3166 }
3167
3168 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3169 if (Ops.size() == 2) {
3170 // C1*(C2+V) -> C1*C2 + C1*V
3171 // If any of Add's ops are Adds or Muls with a constant, apply this
3172 // transformation as well.
3173 //
3174 // TODO: There are some cases where this transformation is not
3175 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3176 // this transformation should be narrowed down.
3177 const SCEV *Op0, *Op1;
3178 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3179 containsConstantInAddMulChain(Ops[1])) {
3180 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3181 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3182 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3183 }
3184
3185 if (Ops[0]->isAllOnesValue()) {
3186 // If we have a mul by -1 of an add, try distributing the -1 among the
3187 // add operands.
3188 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3189 SmallVector<const SCEV *, 4> NewOps;
3190 bool AnyFolded = false;
3191 for (const SCEV *AddOp : Add->operands()) {
3192 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3193 Depth + 1);
3194 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3195 NewOps.push_back(Mul);
3196 }
3197 if (AnyFolded)
3198 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3199 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3200 // Negation preserves a recurrence's no self-wrap property.
3201 SmallVector<const SCEV *, 4> Operands;
3202 for (const SCEV *AddRecOp : AddRec->operands())
3203 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3204 Depth + 1));
3205 // Let M be the minimum representable signed value. AddRec with nsw
3206 // multiplied by -1 can have signed overflow if and only if it takes a
3207 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3208 // maximum signed value. In all other cases signed overflow is
3209 // impossible.
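// Worked example on a hypothetical i8 recurrence: -128 * -1 wraps back to
// -128, while -127 * -1 = 127; so nsw survives negation exactly when the
// recurrence provably never takes the value -128.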
3210 auto FlagsMask = SCEV::FlagNW;
3211 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3212 auto MinInt =
3213 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3214 if (getSignedRangeMin(AddRec) != MinInt)
3215 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3216 }
3217 return getAddRecExpr(Operands, AddRec->getLoop(),
3218 AddRec->getNoWrapFlags(FlagsMask));
3219 }
3220 }
3221
3222 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3223 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
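// Sketch of the fold with assumed i8/i16 types: 6 * zext i8 (40 + %a) to i16
// becomes zext (240 + 6 * %a) to i16, valid only when the narrow multiply
// 6 * (40 + %a) is proven not to wrap unsigned.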
3224 const SCEVAddExpr *InnerAdd;
3225 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3226 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3227 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3228 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3229 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3231 SCEV::FlagNUW)) {
3232 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3233 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3234 };
3235 }
3236
3237 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3238 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3239 // of C1, fold to (D /u (C2 /u C1)).
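// Worked examples with power-of-two constants, assuming %d is known to be a
// multiple of 8: 8 * (%d /u 2) folds to 4 * %d, and 2 * (%d /u 8) folds to
// %d /u 4; in both cases the division discards no set bits.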
3240 const SCEV *D;
3241 APInt C1V = LHSC->getAPInt();
3242 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3243 // as -1 * 1, as it won't enable additional folds.
3244 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3245 C1V = C1V.abs();
3246 const SCEVConstant *C2;
3247 if (C1V.isPowerOf2() &&
3248 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3249 C2->getAPInt().isPowerOf2() &&
3250 C1V.logBase2() <= getMinTrailingZeros(D)) {
3251 const SCEV *NewMul = nullptr;
3252 if (C1V.uge(C2->getAPInt())) {
3253 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3254 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3255 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3256 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3257 }
3258 if (NewMul)
3259 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3260 }
3261 }
3262 }
3263
3264 // Skip over the add expression until we get to a multiply.
3265 unsigned Idx = 0;
3266 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3267 ++Idx;
3268
3269 // If there are mul operands, inline them all into this expression.
3270 if (Idx < Ops.size()) {
3271 bool DeletedMul = false;
3272 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3273 if (Ops.size() > MulOpsInlineThreshold)
3274 break;
3275 // If we have a mul, expand the mul operands onto the end of the
3276 // operands list.
3277 Ops.erase(Ops.begin()+Idx);
3278 append_range(Ops, Mul->operands());
3279 DeletedMul = true;
3280 }
3281
3282 // If we deleted at least one mul, we added operands to the end of the
3283 // list, and they are not necessarily sorted. Recurse to resort and
3284 // resimplify any operands we just acquired.
3285 if (DeletedMul)
3286 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3287 }
3288
3289 // If there are any add recurrences in the operands list, see if any other
3290 // added values are loop invariant. If so, we can fold them into the
3291 // recurrence.
3292 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3293 ++Idx;
3294
3295 // Scan over all recurrences, trying to fold loop invariants into them.
3296 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3297 // Scan all of the other operands to this mul and add them to the vector
3298 // if they are loop invariant w.r.t. the recurrence.
3300 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3301 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3302 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3303 LIOps.push_back(Ops[i]);
3304 Ops.erase(Ops.begin()+i);
3305 --i; --e;
3306 }
3307
3308 // If we found some loop invariants, fold them into the recurrence.
3309 if (!LIOps.empty()) {
3310 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3312 NewOps.reserve(AddRec->getNumOperands());
3313 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3314
3315 // If both the mul and addrec are nuw, we can preserve nuw.
3316 // If both the mul and addrec are nsw, we can only preserve nsw if either
3317 // a) they are also nuw, or
3318 // b) all multiplications of addrec operands with scale are nsw.
3319 SCEV::NoWrapFlags Flags =
3320 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3321
3322 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3323 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3324 SCEV::FlagAnyWrap, Depth + 1));
3325
3326 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3328 Instruction::Mul, getSignedRange(Scale),
3330 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3331 Flags = clearFlags(Flags, SCEV::FlagNSW);
3332 }
3333 }
3334
3335 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3336
3337 // If all of the other operands were loop invariant, we are done.
3338 if (Ops.size() == 1) return NewRec;
3339
3340 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3341 for (unsigned i = 0;; ++i)
3342 if (Ops[i] == AddRec) {
3343 Ops[i] = NewRec;
3344 break;
3345 }
3346 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3347 }
3348
3349 // Okay, if there weren't any loop invariants to be folded, check to see
3350 // if there are multiple AddRec's with the same loop induction variable
3351 // being multiplied together. If so, we can fold them.
3352
3353 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3354 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3355 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3356 // ]]],+,...up to x=2n}.
3357 // Note that the arguments to choose() are always integers with values
3358 // known at compile time, never SCEV objects.
3359 //
3360 // The implementation avoids pointless extra computations when the two
3361 // addrec's are of different length (mathematically, it's equivalent to
3362 // an infinite stream of zeros on the right).
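// Small worked case: for two affine recurrences over the same loop,
//   {A1,+,A2}<L> * {B1,+,B2}<L>
//     = {A1*B1,+,A1*B2 + A2*B1 + A2*B2,+,2*A2*B2}<L>,
// which is what the nested loops below produce for two-operand inputs.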
3363 bool OpsModified = false;
3364 for (unsigned OtherIdx = Idx+1;
3365 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3366 ++OtherIdx) {
3367 const SCEVAddRecExpr *OtherAddRec =
3368 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3369 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3370 continue;
3371
3372 // Limit max number of arguments to avoid creation of unreasonably big
3373 // SCEVAddRecs with very complex operands.
3374 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3375 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3376 continue;
3377
3378 bool Overflow = false;
3379 Type *Ty = AddRec->getType();
3380 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3382 for (int x = 0, xe = AddRec->getNumOperands() +
3383 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3385 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3386 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3387 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3388 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3389 z < ze && !Overflow; ++z) {
3390 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3391 uint64_t Coeff;
3392 if (LargerThan64Bits)
3393 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3394 else
3395 Coeff = Coeff1*Coeff2;
3396 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3397 const SCEV *Term1 = AddRec->getOperand(y-z);
3398 const SCEV *Term2 = OtherAddRec->getOperand(z);
3399 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3400 SCEV::FlagAnyWrap, Depth + 1));
3401 }
3402 }
3403 if (SumOps.empty())
3404 SumOps.push_back(getZero(Ty));
3405 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3406 }
3407 if (!Overflow) {
3408 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3409 SCEV::FlagAnyWrap);
3410 if (Ops.size() == 2) return NewAddRec;
3411 Ops[Idx] = NewAddRec;
3412 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3413 OpsModified = true;
3414 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3415 if (!AddRec)
3416 break;
3417 }
3418 }
3419 if (OpsModified)
3420 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3421
3422 // Otherwise couldn't fold anything into this recurrence. Move onto the
3423 // next one.
3424 }
3425
3426 // Okay, it looks like we really DO need a mul expr. Check to see if we
3427 // already have one, otherwise create a new one.
3428 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3429}
3430
3431/// Represents an unsigned remainder expression based on unsigned division.
3432const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3433 const SCEV *RHS) {
3434 assert(getEffectiveSCEVType(LHS->getType()) ==
3435 getEffectiveSCEVType(RHS->getType()) &&
3436 "SCEVURemExpr operand types don't match!");
3437
3438 // Short-circuit easy cases
3439 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3440 // If constant is one, the result is trivial
3441 if (RHSC->getValue()->isOne())
3442 return getZero(LHS->getType()); // X urem 1 --> 0
3443
3444 // If constant is a power of two, fold into a zext(trunc(LHS)).
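// For example, assuming an i32 %x: %x urem 8 becomes
// zext (trunc %x to i3) to i32, i.e. the low three bits of %x.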
3445 if (RHSC->getAPInt().isPowerOf2()) {
3446 Type *FullTy = LHS->getType();
3447 Type *TruncTy =
3448 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3449 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3450 }
3451 }
3452
3453 // Fall back to %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y).
3454 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3455 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3456 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3457}
3458
3459/// Get a canonical unsigned division expression, or something simpler if
3460/// possible.
3461const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3462 const SCEV *RHS) {
3463 assert(!LHS->getType()->isPointerTy() &&
3464 "SCEVUDivExpr operand can't be pointer!");
3465 assert(LHS->getType() == RHS->getType() &&
3466 "SCEVUDivExpr operand types don't match!");
3467
3468 FoldingSetNodeID ID;
3469 ID.AddInteger(scUDivExpr);
3470 ID.AddPointer(LHS);
3471 ID.AddPointer(RHS);
3472 void *IP = nullptr;
3473 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3474 return S;
3475
3476 // 0 udiv Y == 0
3477 if (match(LHS, m_scev_Zero()))
3478 return LHS;
3479
3480 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3481 if (RHSC->getValue()->isOne())
3482 return LHS; // X udiv 1 --> x
3483 // If the denominator is zero, the result of the udiv is undefined. Don't
3484 // try to analyze it, because the resolution chosen here may differ from
3485 // the resolution chosen in other parts of the compiler.
3486 if (!RHSC->getValue()->isZero()) {
3487 // Determine if the division can be folded into the operands of
3488 // the dividend.
3489 // TODO: Generalize this to non-constants by using known-bits information.
3490 Type *Ty = LHS->getType();
3491 unsigned LZ = RHSC->getAPInt().countl_zero();
3492 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3493 // For non-power-of-two values, effectively round the value up to the
3494 // nearest power of two.
3495 if (!RHSC->getAPInt().isPowerOf2())
3496 ++MaxShiftAmt;
3497 IntegerType *ExtTy =
3498 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3499 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3500 if (const SCEVConstant *Step =
3501 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3502 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
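// e.g. {0,+,4}<L> /u 2 becomes {0,+,2}<L>, assuming the zero-extension
// comparison below shows the original recurrence cannot wrap.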
3503 const APInt &StepInt = Step->getAPInt();
3504 const APInt &DivInt = RHSC->getAPInt();
3505 if (!StepInt.urem(DivInt) &&
3506 getZeroExtendExpr(AR, ExtTy) ==
3507 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3508 getZeroExtendExpr(Step, ExtTy),
3509 AR->getLoop(), SCEV::FlagAnyWrap)) {
3511 for (const SCEV *Op : AR->operands())
3512 Operands.push_back(getUDivExpr(Op, RHS));
3513 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3514 }
3515 /// Get a canonical UDivExpr for a recurrence.
3516 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3517 const APInt *StartRem;
3518 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3519 m_scev_APInt(StartRem))) {
3520 bool NoWrap =
3521 getZeroExtendExpr(AR, ExtTy) ==
3522 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3523 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3525
3526 // With N <= C and both N, C as powers-of-2, the transformation
3527 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3528 // if wrapping occurs, as the division results remain equivalent for
3529 // all offsets in [[(X - X%N), X).
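// Worked case with N = 4, C = 8, and a start of 6: {6,+,4} /u 8 yields
// 0,1,1,2,... and the rewritten {4,+,4} /u 8 yields the same 0,1,1,2,...,
// so the fold stays correct even if the recurrence wraps.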
3530 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3531 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3532 // Only fold if the subtraction can be folded in the start
3533 // expression.
3534 const SCEV *NewStart =
3535 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3536 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3537 !isa<SCEVAddExpr>(NewStart)) {
3538 const SCEV *NewLHS =
3539 getAddRecExpr(NewStart, Step, AR->getLoop(),
3540 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3541 if (LHS != NewLHS) {
3542 LHS = NewLHS;
3543
3544 // Reset the ID to include the new LHS, and check if it is
3545 // already cached.
3546 ID.clear();
3547 ID.AddInteger(scUDivExpr);
3548 ID.AddPointer(LHS);
3549 ID.AddPointer(RHS);
3550 IP = nullptr;
3551 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3552 return S;
3553 }
3554 }
3555 }
3556 }
3557 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3558 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3560 for (const SCEV *Op : M->operands())
3561 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3562 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3563 // Find an operand that's safely divisible.
3564 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3565 const SCEV *Op = M->getOperand(i);
3566 const SCEV *Div = getUDivExpr(Op, RHSC);
3567 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3568 Operands = SmallVector<const SCEV *, 4>(M->operands());
3569 Operands[i] = Div;
3570 return getMulExpr(Operands);
3571 }
3572 }
3573 }
3574
3575 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3576 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3577 if (auto *DivisorConstant =
3578 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3579 bool Overflow = false;
3580 APInt NewRHS =
3581 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3582 if (Overflow) {
3583 return getConstant(RHSC->getType(), 0, false);
3584 }
3585 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3586 }
3587 }
3588
3589 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3590 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3592 for (const SCEV *Op : A->operands())
3593 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3594 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3595 Operands.clear();
3596 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3597 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3598 if (isa<SCEVUDivExpr>(Op) ||
3599 getMulExpr(Op, RHS) != A->getOperand(i))
3600 break;
3601 Operands.push_back(Op);
3602 }
3603 if (Operands.size() == A->getNumOperands())
3604 return getAddExpr(Operands);
3605 }
3606 }
3607
3608 // Fold if both operands are constant.
3609 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3610 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3611 }
3612 }
3613
3614 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
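// For instance, (-8 + (8 smax %x)) /u %x: the numerator is 0 whenever
// %x <= 8 and %x - 8 (which is less than %x) otherwise, so the unsigned
// quotient is always 0.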
3615 const APInt *NegC, *C;
3616 if (match(LHS,
3619 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3620 return getZero(LHS->getType());
3621
3622 // TODO: Generalize to handle any common factors.
3623 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3624 const SCEV *NewLHS, *NewRHS;
3625 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3626 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3627 return getUDivExpr(NewLHS, NewRHS);
3628
3629 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3630 // changes). Make sure we get a new one.
3631 IP = nullptr;
3632 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3633 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3634 LHS, RHS);
3635 UniqueSCEVs.InsertNode(S, IP);
3636 registerUser(S, {LHS, RHS});
3637 return S;
3638}
3639
3640APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3641 APInt A = C1->getAPInt().abs();
3642 APInt B = C2->getAPInt().abs();
3643 uint32_t ABW = A.getBitWidth();
3644 uint32_t BBW = B.getBitWidth();
3645
3646 if (ABW > BBW)
3647 B = B.zext(ABW);
3648 else if (ABW < BBW)
3649 A = A.zext(BBW);
3650
3651 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3652}
3653
3654/// Get a canonical unsigned division expression, or something simpler if
3655/// possible. There is no representation for an exact udiv in SCEV IR, but we
3656/// can attempt to remove factors from the LHS and RHS. We can't do this when
3657/// it's not exact because the udiv may be clearing bits.
3658const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3659 const SCEV *RHS) {
3660 // TODO: we could try to find factors in all sorts of things, but for now we
3661 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3662 // end of this file for inspiration.
3663
3664 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3665 if (!Mul || !Mul->hasNoUnsignedWrap())
3666 return getUDivExpr(LHS, RHS);
3667
3668 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3669 // If the mulexpr multiplies by a constant, then that constant must be the
3670 // first element of the mulexpr.
3671 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3672 if (LHSCst == RHSCst) {
3673 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3674 return getMulExpr(Operands);
3675 }
3676
3677 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3678 // that there's a factor provided by one of the other terms. We need to
3679 // check.
3680 APInt Factor = gcd(LHSCst, RHSCst);
3681 if (!Factor.isIntN(1)) {
3682 LHSCst =
3683 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3684 RHSCst =
3685 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3687 Operands.push_back(LHSCst);
3688 append_range(Operands, Mul->operands().drop_front());
3689 LHS = getMulExpr(Operands);
3690 RHS = RHSCst;
3691 Mul = dyn_cast<SCEVMulExpr>(LHS);
3692 if (!Mul)
3693 return getUDivExactExpr(LHS, RHS);
3694 }
3695 }
3696 }
3697
3698 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3699 if (Mul->getOperand(i) == RHS) {
3701 append_range(Operands, Mul->operands().take_front(i));
3702 append_range(Operands, Mul->operands().drop_front(i + 1));
3703 return getMulExpr(Operands);
3704 }
3705 }
3706
3707 return getUDivExpr(LHS, RHS);
3708}
3709
3710/// Get an add recurrence expression for the specified loop. Simplify the
3711/// expression as much as possible.
3712const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3713 const Loop *L,
3714 SCEV::NoWrapFlags Flags) {
3716 Operands.push_back(Start);
3717 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3718 if (StepChrec->getLoop() == L) {
3719 append_range(Operands, StepChrec->operands());
3720 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3721 }
3722
3723 Operands.push_back(Step);
3724 return getAddRecExpr(Operands, L, Flags);
3725}
3726
3727/// Get an add recurrence expression for the specified loop. Simplify the
3728/// expression as much as possible.
3729const SCEV *
3730ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3731 const Loop *L, SCEV::NoWrapFlags Flags) {
3732 if (Operands.size() == 1) return Operands[0];
3733#ifndef NDEBUG
3734 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3735 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3736 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3737 "SCEVAddRecExpr operand types don't match!");
3738 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3739 }
3740 for (const SCEV *Op : Operands)
3742 "SCEVAddRecExpr operand is not available at loop entry!");
3743#endif
3744
3745 if (Operands.back()->isZero()) {
3746 Operands.pop_back();
3747 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3748 }
3749
3750 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3751 // use that information to infer NUW and NSW flags. However, computing a
3752 // BE count requires calling getAddRecExpr, so we may not yet have a
3753 // meaningful BE count at this point (and if we don't, we'd be stuck
3754 // with a SCEVCouldNotCompute as the cached BE count).
3755
3756 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3757
3758 // Canonicalize nested AddRecs by nesting them in order of loop depth.
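// For example, {{S,+,X}<L_inner>,+,Y}<L_outer> is rewritten below as
// {{S,+,Y}<L_outer>,+,X}<L_inner>, assuming every operand is invariant in
// the loop it becomes attached to.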
3759 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3760 const Loop *NestedLoop = NestedAR->getLoop();
3761 if (L->contains(NestedLoop)
3762 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3763 : (!NestedLoop->contains(L) &&
3764 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3765 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3766 Operands[0] = NestedAR->getStart();
3767 // AddRecs require their operands be loop-invariant with respect to their
3768 // loops. Don't perform this transformation if it would break this
3769 // requirement.
3770 bool AllInvariant = all_of(
3771 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3772
3773 if (AllInvariant) {
3774 // Create a recurrence for the outer loop with the same step size.
3775 //
3776 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3777 // inner recurrence has the same property.
3778 SCEV::NoWrapFlags OuterFlags =
3779 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3780
3781 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3782 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3783 return isLoopInvariant(Op, NestedLoop);
3784 });
3785
3786 if (AllInvariant) {
3787 // Ok, both add recurrences are valid after the transformation.
3788 //
3789 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3790 // the outer recurrence has the same property.
3791 SCEV::NoWrapFlags InnerFlags =
3792 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3793 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3794 }
3795 }
3796 // Reset Operands to its original state.
3797 Operands[0] = NestedAR;
3798 }
3799 }
3800
3801 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3802 // already have one, otherwise create a new one.
3803 return getOrCreateAddRecExpr(Operands, L, Flags);
3804}
3805
3806const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3807 ArrayRef<const SCEV *> IndexExprs) {
3808 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3809 // getSCEV(Base)->getType() has the same address space as Base->getType()
3810 // because SCEV::getType() preserves the address space.
3811 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3812 if (NW != GEPNoWrapFlags::none()) {
3813 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3814 // but to do that, we have to ensure that said flag is valid in the entire
3815 // defined scope of the SCEV.
3816 // TODO: non-instructions have global scope. We might be able to prove
3817 // some global scope cases
3818 auto *GEPI = dyn_cast<Instruction>(GEP);
3819 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3820 NW = GEPNoWrapFlags::none();
3821 }
3822
3823 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3824}
3825
3826const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3827 ArrayRef<const SCEV *> IndexExprs,
3828 Type *SrcElementTy, GEPNoWrapFlags NW) {
3829 SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3830 if (NW.hasNoUnsignedSignedWrap())
3831 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3832 if (NW.hasNoUnsignedWrap())
3833 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3834
3835 Type *CurTy = BaseExpr->getType();
3836 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3837 bool FirstIter = true;
3839 for (const SCEV *IndexExpr : IndexExprs) {
3840 // Compute the (potentially symbolic) offset in bytes for this index.
3841 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3842 // For a struct, add the member offset.
3843 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3844 unsigned FieldNo = Index->getZExtValue();
3845 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3846 Offsets.push_back(FieldOffset);
3847
3848 // Update CurTy to the type of the field at Index.
3849 CurTy = STy->getTypeAtIndex(Index);
3850 } else {
3851 // Update CurTy to its element type.
3852 if (FirstIter) {
3853 assert(isa<PointerType>(CurTy) &&
3854 "The first index of a GEP indexes a pointer");
3855 CurTy = SrcElementTy;
3856 FirstIter = false;
3857 } else {
3859 }
3860 // For an array, add the element offset, explicitly scaled.
3861 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3862 // Getelementptr indices are signed.
3863 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3864
3865 // Multiply the index by the element size to compute the element offset.
3866 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3867 Offsets.push_back(LocalOffset);
3868 }
3869 }
3870
3871 // Handle degenerate case of GEP without offsets.
3872 if (Offsets.empty())
3873 return BaseExpr;
3874
3875 // Add the offsets together, assuming nsw if inbounds.
3876 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3877 // Add the base address and the offset. We cannot use the nsw flag, as the
3878 // base address is unsigned. However, if we know that the offset is
3879 // non-negative, we can use nuw.
3880 bool NUW = NW.hasNoUnsignedWrap() ||
3883 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3884 assert(BaseExpr->getType() == GEPExpr->getType() &&
3885 "GEP should not change type mid-flight.");
3886 return GEPExpr;
3887}
3888
3889SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3890 ArrayRef<const SCEV *> Ops) {
3891 FoldingSetNodeID ID;
3892 ID.AddInteger(SCEVType);
3893 for (const SCEV *Op : Ops)
3894 ID.AddPointer(Op);
3895 void *IP = nullptr;
3896 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3897}
3898
3899const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3900 SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3901 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3902}
3903
3904const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3905 SmallVectorImpl<const SCEV *> &Ops) {
3906 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3907 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3908 if (Ops.size() == 1) return Ops[0];
3909#ifndef NDEBUG
3910 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3911 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3912 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3913 "Operand types don't match!");
3914 assert(Ops[0]->getType()->isPointerTy() ==
3915 Ops[i]->getType()->isPointerTy() &&
3916 "min/max should be consistently pointerish");
3917 }
3918#endif
3919
3920 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3921 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3922
3923 const SCEV *Folded = constantFoldAndGroupOps(
3924 *this, LI, DT, Ops,
3925 [&](const APInt &C1, const APInt &C2) {
3926 switch (Kind) {
3927 case scSMaxExpr:
3928 return APIntOps::smax(C1, C2);
3929 case scSMinExpr:
3930 return APIntOps::smin(C1, C2);
3931 case scUMaxExpr:
3932 return APIntOps::umax(C1, C2);
3933 case scUMinExpr:
3934 return APIntOps::umin(C1, C2);
3935 default:
3936 llvm_unreachable("Unknown SCEV min/max opcode");
3937 }
3938 },
3939 [&](const APInt &C) {
3940 // identity
3941 if (IsMax)
3942 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3943 else
3944 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3945 },
3946 [&](const APInt &C) {
3947 // absorber
3948 if (IsMax)
3949 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3950 else
3951 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3952 });
3953 if (Folded)
3954 return Folded;
3955
3956 // Check if we have created the same expression before.
3957 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3958 return S;
3959 }
3960
3961 // Find the first operation of the same kind
3962 unsigned Idx = 0;
3963 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3964 ++Idx;
3965
3966 // Check to see if one of the operands is of the same kind. If so, expand its
3967 // operands onto our operand list, and recurse to simplify.
3968 if (Idx < Ops.size()) {
3969 bool DeletedAny = false;
3970 while (Ops[Idx]->getSCEVType() == Kind) {
3971 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3972 Ops.erase(Ops.begin()+Idx);
3973 append_range(Ops, SMME->operands());
3974 DeletedAny = true;
3975 }
3976
3977 if (DeletedAny)
3978 return getMinMaxExpr(Kind, Ops);
3979 }
3980
3981 // Okay, check to see if the same value occurs in the operand list twice. If
3982 // so, delete one. Since we sorted the list, these values are required to
3983 // be adjacent.
3988 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3989 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3990 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3991 if (Ops[i] == Ops[i + 1] ||
3992 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3993 // X op Y op Y --> X op Y
3994 // X op Y --> X, if we know X, Y are ordered appropriately
3995 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3996 --i;
3997 --e;
3998 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3999 Ops[i + 1])) {
4000 // X op Y --> Y, if we know X, Y are ordered appropriately
4001 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4002 --i;
4003 --e;
4004 }
4005 }
4006
4007 if (Ops.size() == 1) return Ops[0];
4008
4009 assert(!Ops.empty() && "Reduced smax down to nothing!");
4010
4011 // Okay, it looks like we really DO need an expr. Check to see if we
4012 // already have one, otherwise create a new one.
4013 FoldingSetNodeID ID;
4014 ID.AddInteger(Kind);
4015 for (const SCEV *Op : Ops)
4016 ID.AddPointer(Op);
4017 void *IP = nullptr;
4018 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4019 if (ExistingSCEV)
4020 return ExistingSCEV;
4021 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4022 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4023 SCEV *S = new (SCEVAllocator)
4024 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4025
4026 UniqueSCEVs.InsertNode(S, IP);
4027 registerUser(S, Ops);
4028 return S;
4029}
4030
4031namespace {
4032
4033class SCEVSequentialMinMaxDeduplicatingVisitor final
4034 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4035 std::optional<const SCEV *>> {
4036 using RetVal = std::optional<const SCEV *>;
4038
4039 ScalarEvolution &SE;
4040 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4041 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4043
4044 bool canRecurseInto(SCEVTypes Kind) const {
4045 // We can only recurse into the SCEV expression of the same effective type
4046 // as the type of our root SCEV expression.
4047 return RootKind == Kind || NonSequentialRootKind == Kind;
4048 };
4049
4050 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4052 "Only for min/max expressions.");
4053 SCEVTypes Kind = S->getSCEVType();
4054
4055 if (!canRecurseInto(Kind))
4056 return S;
4057
4058 auto *NAry = cast<SCEVNAryExpr>(S);
4060 bool Changed = visit(Kind, NAry->operands(), NewOps);
4061
4062 if (!Changed)
4063 return S;
4064 if (NewOps.empty())
4065 return std::nullopt;
4066
4068 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4069 : SE.getMinMaxExpr(Kind, NewOps);
4070 }
4071
4072 RetVal visit(const SCEV *S) {
4073 // Has the whole operand been seen already?
4074 if (!SeenOps.insert(S).second)
4075 return std::nullopt;
4076 return Base::visit(S);
4077 }
4078
4079public:
4080 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4081 SCEVTypes RootKind)
4082 : SE(SE), RootKind(RootKind),
4083 NonSequentialRootKind(
4084 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4085 RootKind)) {}
4086
4087 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4088 SmallVectorImpl<const SCEV *> &NewOps) {
4089 bool Changed = false;
4091 Ops.reserve(OrigOps.size());
4092
4093 for (const SCEV *Op : OrigOps) {
4094 RetVal NewOp = visit(Op);
4095 if (NewOp != Op)
4096 Changed = true;
4097 if (NewOp)
4098 Ops.emplace_back(*NewOp);
4099 }
4100
4101 if (Changed)
4102 NewOps = std::move(Ops);
4103 return Changed;
4104 }
4105
4106 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4107
4108 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4109
4110 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4111
4112 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4113
4114 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4115
4116 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4117
4118 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4119
4120 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4121
4122 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4123
4124 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4125
4126 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4127
4128 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4129 return visitAnyMinMaxExpr(Expr);
4130 }
4131
4132 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4133 return visitAnyMinMaxExpr(Expr);
4134 }
4135
4136 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4137 return visitAnyMinMaxExpr(Expr);
4138 }
4139
4140 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4141 return visitAnyMinMaxExpr(Expr);
4142 }
4143
4144 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4145 return visitAnyMinMaxExpr(Expr);
4146 }
4147
4148 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4149
4150 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4151};
4152
4153} // namespace
4154
4156 switch (Kind) {
4157 case scConstant:
4158 case scVScale:
4159 case scTruncate:
4160 case scZeroExtend:
4161 case scSignExtend:
4162 case scPtrToAddr:
4163 case scPtrToInt:
4164 case scAddExpr:
4165 case scMulExpr:
4166 case scUDivExpr:
4167 case scAddRecExpr:
4168 case scUMaxExpr:
4169 case scSMaxExpr:
4170 case scUMinExpr:
4171 case scSMinExpr:
4172 case scUnknown:
4173 // If any operand is poison, the whole expression is poison.
4174 return true;
4175 case scSequentialUMinExpr:
4176 // FIXME: if the *first* operand is poison, the whole expression is poison.
4177 return false; // Pessimistically, say that it does not propagate poison.
4178 case scCouldNotCompute:
4179 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4180 }
4181 llvm_unreachable("Unknown SCEV kind!");
4182}
4183
4184namespace {
4185// The only way poison may be introduced in a SCEV expression is from a
4186// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4187// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4188// introduce poison -- they encode guaranteed, non-speculated knowledge.
4189//
4190// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4191// with the notable exception of umin_seq, where only poison from the first
4192// operand is (unconditionally) propagated.
4193struct SCEVPoisonCollector {
4194 bool LookThroughMaybePoisonBlocking;
4195 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4196 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4197 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4198
4199 bool follow(const SCEV *S) {
4200 if (!LookThroughMaybePoisonBlocking &&
4202 return false;
4203
4204 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4205 if (!isGuaranteedNotToBePoison(SU->getValue()))
4206 MaybePoison.insert(SU);
4207 }
4208 return true;
4209 }
4210 bool isDone() const { return false; }
4211};
4212} // namespace
4213
4214/// Return true if V is poison given that AssumedPoison is already poison.
4215static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4216 // First collect all SCEVs that might result in AssumedPoison to be poison.
4217 // We need to look through potentially poison-blocking operations here,
4218 // because we want to find all SCEVs that *might* result in poison, not only
4219 // those that are *required* to.
4220 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4221 visitAll(AssumedPoison, PC1);
4222
4223 // AssumedPoison is never poison. As the assumption is false, the implication
4224 // is true. Don't bother walking the other SCEV in this case.
4225 if (PC1.MaybePoison.empty())
4226 return true;
4227
4228 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4229 // as well. We cannot look through potentially poison-blocking operations
4230 // here, as their arguments only *may* make the result poison.
4231 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4232 visitAll(S, PC2);
4233
4234 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4235 // it will also make S poison by being part of PC2.MaybePoison.
4236 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4237}
4238
4239void ScalarEvolution::getPoisonGeneratingValues(
4240 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4241 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4242 visitAll(S, PC);
4243 for (const SCEVUnknown *SU : PC.MaybePoison)
4244 Result.insert(SU->getValue());
4245}
4246
4247bool ScalarEvolution::canReuseInstruction(
4248 const SCEV *S, Instruction *I,
4249 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4250 // If the instruction cannot be poison, it's always safe to reuse.
4252 return true;
4253
4254 // Otherwise, it is possible that I is more poisonous than S. Collect the
4255 // poison-contributors of S, and then check whether I has any additional
4256 // poison-contributors. Poison that is contributed through poison-generating
4257 // flags is handled by dropping those flags instead.
4259 getPoisonGeneratingValues(PoisonVals, S);
4260
4261 SmallVector<Value *> Worklist;
4263 Worklist.push_back(I);
4264 while (!Worklist.empty()) {
4265 Value *V = Worklist.pop_back_val();
4266 if (!Visited.insert(V).second)
4267 continue;
4268
4269 // Avoid walking large instruction graphs.
4270 if (Visited.size() > 16)
4271 return false;
4272
4273 // Either the value can't be poison, or the S would also be poison if it
4274 // is.
4275 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4276 continue;
4277
4278 auto *I = dyn_cast<Instruction>(V);
4279 if (!I)
4280 return false;
4281
4282 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4283 // can't replace an arbitrary add with disjoint or, even if we drop the
4284 // flag. We would need to convert the or into an add.
4285 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4286 if (PDI->isDisjoint())
4287 return false;
4288
4289 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4290 // because SCEV currently assumes it can't be poison. Remove this special
4291 // case once we properly model when vscale can be poison.
4292 if (auto *II = dyn_cast<IntrinsicInst>(I);
4293 II && II->getIntrinsicID() == Intrinsic::vscale)
4294 continue;
4295
4296 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4297 return false;
4298
4299 // If the instruction can't create poison, we can recurse to its operands.
4300 if (I->hasPoisonGeneratingAnnotations())
4301 DropPoisonGeneratingInsts.push_back(I);
4302
4303 llvm::append_range(Worklist, I->operands());
4304 }
4305 return true;
4306}
4307
4308const SCEV *
4309ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4310 SmallVectorImpl<const SCEV *> &Ops) {
4311 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4312 "Not a SCEVSequentialMinMaxExpr!");
4313 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4314 if (Ops.size() == 1)
4315 return Ops[0];
4316#ifndef NDEBUG
4317 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4318 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4319 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4320 "Operand types don't match!");
4321 assert(Ops[0]->getType()->isPointerTy() ==
4322 Ops[i]->getType()->isPointerTy() &&
4323 "min/max should be consistently pointerish");
4324 }
4325#endif
4326
4327 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4328 // so we can *NOT* do any kind of sorting of the expressions!
4329
4330 // Check if we have created the same expression before.
4331 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4332 return S;
4333
4334 // FIXME: there are *some* simplifications that we can do here.
4335
4336 // Keep only the first instance of an operand.
4337 {
4338 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4339 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4340 if (Changed)
4341 return getSequentialMinMaxExpr(Kind, Ops);
4342 }
4343
4344 // Check to see if one of the operands is of the same kind. If so, expand its
4345 // operands onto our operand list, and recurse to simplify.
4346 {
4347 unsigned Idx = 0;
4348 bool DeletedAny = false;
4349 while (Idx < Ops.size()) {
4350 if (Ops[Idx]->getSCEVType() != Kind) {
4351 ++Idx;
4352 continue;
4353 }
4354 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4355 Ops.erase(Ops.begin() + Idx);
4356 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4357 SMME->operands().end());
4358 DeletedAny = true;
4359 }
4360
4361 if (DeletedAny)
4362 return getSequentialMinMaxExpr(Kind, Ops);
4363 }
4364
4365 const SCEV *SaturationPoint;
4366 ICmpInst::Predicate Pred;
4367 switch (Kind) {
4368 case scSequentialUMinExpr:
4369 SaturationPoint = getZero(Ops[0]->getType());
4370 Pred = ICmpInst::ICMP_ULE;
4371 break;
4372 default:
4373 llvm_unreachable("Not a sequential min/max type.");
4374 }
4375
4376 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4377 if (!isGuaranteedNotToCauseUB(Ops[i]))
4378 continue;
4379 // We can replace %x umin_seq %y with %x umin %y if either:
4380 // * %y being poison implies %x is also poison.
4381 // * %x cannot be the saturating value (e.g. zero for umin).
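// e.g. (%a umin %b) umin_seq %b can be relaxed to (%a umin %b) umin %b,
// since %b being poison already makes the left-hand operand poison.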
4382 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4383 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4384 SaturationPoint)) {
4385 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4386 Ops[i - 1] = getMinMaxExpr(
4388 SeqOps);
4389 Ops.erase(Ops.begin() + i);
4390 return getSequentialMinMaxExpr(Kind, Ops);
4391 }
4392 // Fold %x umin_seq %y to %x if %x ule %y.
4393 // TODO: We might be able to prove the predicate for a later operand.
4394 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4395 Ops.erase(Ops.begin() + i);
4396 return getSequentialMinMaxExpr(Kind, Ops);
4397 }
4398 }
4399
4400 // Okay, it looks like we really DO need an expr. Check to see if we
4401 // already have one, otherwise create a new one.
4402 FoldingSetNodeID ID;
4403 ID.AddInteger(Kind);
4404 for (const SCEV *Op : Ops)
4405 ID.AddPointer(Op);
4406 void *IP = nullptr;
4407 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4408 if (ExistingSCEV)
4409 return ExistingSCEV;
4410
4411 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4412 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4413 SCEV *S = new (SCEVAllocator)
4414 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4415
4416 UniqueSCEVs.InsertNode(S, IP);
4417 registerUser(S, Ops);
4418 return S;
4419}
4420
4421const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4422 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4423 return getSMaxExpr(Ops);
4424}
4425
4426const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4427 return getMinMaxExpr(scSMaxExpr, Ops);
4428}
4429
4430const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4431 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4432 return getUMaxExpr(Ops);
4433}
4434
4435const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4436 return getMinMaxExpr(scUMaxExpr, Ops);
4437}
4438
4439const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4440 const SCEV *RHS) {
4441 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4442 return getSMinExpr(Ops);
4443}
4444
4445const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4446 return getMinMaxExpr(scSMinExpr, Ops);
4447}
4448
4449const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4450 bool Sequential) {
4451 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4452 return getUMinExpr(Ops, Sequential);
4453}
4454
4455const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4456 bool Sequential) {
4457 return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4458 : getMinMaxExpr(scUMinExpr, Ops);
4459}
4460
4461const SCEV *
4462ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4463 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4464 if (Size.isScalable())
4465 Res = getMulExpr(Res, getVScale(IntTy));
4466 return Res;
4467}
4468
4469const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4470 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4471}
4472
4473const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4474 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4475}
4476
4477const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4478 StructType *STy,
4479 unsigned FieldNo) {
4480 // We can bypass creating a target-independent constant expression and then
4481 // folding it back into a ConstantInt. This is just a compile-time
4482 // optimization.
4483 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4484 assert(!SL->getSizeInBits().isScalable() &&
4485 "Cannot get offset for structure containing scalable vector types");
4486 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4487}
4488
4489const SCEV *ScalarEvolution::getUnknown(Value *V) {
4490 // Don't attempt to do anything other than create a SCEVUnknown object
4491 // here. createSCEV only calls getUnknown after checking for all other
4492 // interesting possibilities, and any other code that calls getUnknown
4493 // is doing so in order to hide a value from SCEV canonicalization.
4494
4495 FoldingSetNodeID ID;
4496 ID.AddInteger(scUnknown);
4497 ID.AddPointer(V);
4498 void *IP = nullptr;
4499 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4500 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4501 "Stale SCEVUnknown in uniquing map!");
4502 return S;
4503 }
4504 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4505 FirstUnknown);
4506 FirstUnknown = cast<SCEVUnknown>(S);
4507 UniqueSCEVs.InsertNode(S, IP);
4508 return S;
4509}
4510
4511//===----------------------------------------------------------------------===//
4512// Basic SCEV Analysis and PHI Idiom Recognition Code
4513//
4514
4515/// Test if values of the given type are analyzable within the SCEV
4516/// framework. This primarily includes integer types, and it can optionally
4517/// include pointer types if the ScalarEvolution class has access to
4518/// target-specific information.
4519bool ScalarEvolution::isSCEVable(Type *Ty) const {
4520 // Integers and pointers are always SCEVable.
4521 return Ty->isIntOrPtrTy();
4522}
4523
4524/// Return the size in bits of the specified type, for which isSCEVable must
4525/// return true.
4526uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4527 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4528 if (Ty->isPointerTy())
4530 return getDataLayout().getTypeSizeInBits(Ty);
4531}
4532
4533/// Return a type with the same bitwidth as the given type and which represents
4534/// how SCEV will treat the given type, for which isSCEVable must return
4535/// true. For pointer types, this is the pointer index sized integer type.
4536Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4537 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4538
4539 if (Ty->isIntegerTy())
4540 return Ty;
4541
4542 // The only other supported type is pointer.
4543 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4544 return getDataLayout().getIndexType(Ty);
4545}
4546
4547Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4548 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4549}
4550
4552 const SCEV *B) {
4553 /// For a valid use point to exist, the defining scope of one operand
4554 /// must dominate the other.
4555 bool PreciseA, PreciseB;
4556 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4557 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4558 if (!PreciseA || !PreciseB)
4559 // Can't tell.
4560 return false;
4561 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4562 DT.dominates(ScopeB, ScopeA);
4563}
4564
4565const SCEV *ScalarEvolution::getCouldNotCompute() {
4566 return CouldNotCompute.get();
4567}
4568
4569bool ScalarEvolution::checkValidity(const SCEV *S) const {
4570 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4571 auto *SU = dyn_cast<SCEVUnknown>(S);
4572 return SU && SU->getValue() == nullptr;
4573 });
4574
4575 return !ContainsNulls;
4576}
4577
4578bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4579 HasRecMapType::iterator I = HasRecMap.find(S);
4580 if (I != HasRecMap.end())
4581 return I->second;
4582
4583 bool FoundAddRec =
4584 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4585 HasRecMap.insert({S, FoundAddRec});
4586 return FoundAddRec;
4587}
4588
4589/// Return the ValueOffsetPair set for \p S. \p S can be represented
4590/// by the value and offset from any ValueOffsetPair in the set.
4591ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4592 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4593 if (SI == ExprValueMap.end())
4594 return {};
4595 return SI->second.getArrayRef();
4596}
4597
4598/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4599/// cannot be used separately. eraseValueFromMap should be used to remove
4600/// V from ValueExprMap and ExprValueMap at the same time.
4601void ScalarEvolution::eraseValueFromMap(Value *V) {
4602 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4603 if (I != ValueExprMap.end()) {
4604 auto EVIt = ExprValueMap.find(I->second);
4605 bool Removed = EVIt->second.remove(V);
4606 (void) Removed;
4607 assert(Removed && "Value not in ExprValueMap?");
4608 ValueExprMap.erase(I);
4609 }
4610}
4611
4612void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4613 // A recursive query may have already computed the SCEV. It should be
4614 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4615 // inferred nowrap flags.
4616 auto It = ValueExprMap.find_as(V);
4617 if (It == ValueExprMap.end()) {
4618 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4619 ExprValueMap[S].insert(V);
4620 }
4621}
4622
4623/// Return an existing SCEV if it exists, otherwise analyze the expression and
4624/// create a new one.
4625const SCEV *ScalarEvolution::getSCEV(Value *V) {
4626 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4627
4628 if (const SCEV *S = getExistingSCEV(V))
4629 return S;
4630 return createSCEVIter(V);
4631}
4632
4633const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4634 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4635
4636 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4637 if (I != ValueExprMap.end()) {
4638 const SCEV *S = I->second;
4639 assert(checkValidity(S) &&
4640 "existing SCEV has not been properly invalidated");
4641 return S;
4642 }
4643 return nullptr;
4644}
4645
4646/// Return a SCEV corresponding to -V = -1*V
4647const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4648 SCEV::NoWrapFlags Flags) {
4649 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4650 return getConstant(
4651 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4652
4653 Type *Ty = V->getType();
4654 Ty = getEffectiveSCEVType(Ty);
4655 return getMulExpr(V, getMinusOne(Ty), Flags);
4656}
4657
4658/// If Expr computes ~A, return A else return nullptr
4659static const SCEV *MatchNotExpr(const SCEV *Expr) {
4660 const SCEV *MulOp;
4661 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4662 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4663 return MulOp;
4664 return nullptr;
4665}
4666
4667/// Return a SCEV corresponding to ~V = -1-V
4668const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4669 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4670
4671 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4672 return getConstant(
4673 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4674
4675 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4676 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4677 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4678 SmallVector<const SCEV *, 2> MatchedOperands;
4679 for (const SCEV *Operand : MME->operands()) {
4680 const SCEV *Matched = MatchNotExpr(Operand);
4681 if (!Matched)
4682 return (const SCEV *)nullptr;
4683 MatchedOperands.push_back(Matched);
4684 }
4685 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4686 MatchedOperands);
4687 };
4688 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4689 return Replaced;
4690 }
4691
4692 Type *Ty = V->getType();
4693 Ty = getEffectiveSCEVType(Ty);
4694 return getMinusSCEV(getMinusOne(Ty), V);
4695}
4696
4697const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4698 assert(P->getType()->isPointerTy());
4699
4700 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4701 // The base of an AddRec is the first operand.
4702 SmallVector<const SCEV *> Ops{AddRec->operands()};
4703 Ops[0] = removePointerBase(Ops[0]);
4704 // Don't try to transfer nowrap flags for now. We could in some cases
4705 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4706 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4707 }
4708 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4709 // The base of an Add is the pointer operand.
4710 SmallVector<const SCEV *> Ops{Add->operands()};
4711 const SCEV **PtrOp = nullptr;
4712 for (const SCEV *&AddOp : Ops) {
4713 if (AddOp->getType()->isPointerTy()) {
4714 assert(!PtrOp && "Cannot have multiple pointer ops");
4715 PtrOp = &AddOp;
4716 }
4717 }
4718 *PtrOp = removePointerBase(*PtrOp);
4719 // Don't try to transfer nowrap flags for now. We could in some cases
4720 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4721 return getAddExpr(Ops);
4722 }
4723 // Any other expression must be a pointer base.
4724 return getZero(P->getType());
4725}
4726
4727const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4728 SCEV::NoWrapFlags Flags,
4729 unsigned Depth) {
4730 // Fast path: X - X --> 0.
4731 if (LHS == RHS)
4732 return getZero(LHS->getType());
4733
4734 // If we subtract two pointers with different pointer bases, bail.
4735 // Eventually, we're going to add an assertion to getMulExpr that we
4736 // can't multiply by a pointer.
4737 if (RHS->getType()->isPointerTy()) {
4738 if (!LHS->getType()->isPointerTy() ||
4739 getPointerBase(LHS) != getPointerBase(RHS))
4740 return getCouldNotCompute();
4741 LHS = removePointerBase(LHS);
4742 RHS = removePointerBase(RHS);
4743 }
4744
4745 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4746 // makes it so that we cannot make much use of NUW.
4747 auto AddFlags = SCEV::FlagAnyWrap;
4748 const bool RHSIsNotMinSigned =
4750 if (hasFlags(Flags, SCEV::FlagNSW)) {
4751 // Let M be the minimum representable signed value. Then (-1)*RHS
4752 // signed-wraps if and only if RHS is M. That can happen even for
4753 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4754 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4755 // (-1)*RHS, we need to prove that RHS != M.
4756 //
4757 // If LHS is non-negative and we know that LHS - RHS does not
4758 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4759 // either by proving that RHS > M or that LHS >= 0.
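// Concretely, on i8: if RHS may be -128, then (-1) * RHS wraps back to
// -128 even when LHS - RHS itself is nsw, so nsw is transferred to the
// add below only once RHS != -128 or LHS >= 0 has been established.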
4760 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4761 AddFlags = SCEV::FlagNSW;
4762 }
4763 }
4764
4765 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4766 // RHS is NSW and LHS >= 0.
4767 //
4768 // The difficulty here is that the NSW flag may have been proven
4769 // relative to a loop that is to be found in a recurrence in LHS and
4770 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4771 // larger scope than intended.
4772 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4773
4774 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4775}
4776
4777const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4778 unsigned Depth) {
4779 Type *SrcTy = V->getType();
4780 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4781 "Cannot truncate or zero extend with non-integer arguments!");
4782 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4783 return V; // No conversion
4784 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4785 return getTruncateExpr(V, Ty, Depth);
4786 return getZeroExtendExpr(V, Ty, Depth);
4787}
4788
4789const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4790 unsigned Depth) {
4791 Type *SrcTy = V->getType();
4792 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4793 "Cannot truncate or zero extend with non-integer arguments!");
4794 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4795 return V; // No conversion
4796 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4797 return getTruncateExpr(V, Ty, Depth);
4798 return getSignExtendExpr(V, Ty, Depth);
4799}
4800
4801const SCEV *
4802ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4803 Type *SrcTy = V->getType();
4804 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4805 "Cannot noop or zero extend with non-integer arguments!");
4807 "getNoopOrZeroExtend cannot truncate!");
4808 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4809 return V; // No conversion
4810 return getZeroExtendExpr(V, Ty);
4811}
4812
4813const SCEV *
4814ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4815 Type *SrcTy = V->getType();
4816 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4817 "Cannot noop or sign extend with non-integer arguments!");
4819 "getNoopOrSignExtend cannot truncate!");
4820 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4821 return V; // No conversion
4822 return getSignExtendExpr(V, Ty);
4823}
4824
4825const SCEV *
4826ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4827 Type *SrcTy = V->getType();
4828 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4829 "Cannot noop or any extend with non-integer arguments!");
4831 "getNoopOrAnyExtend cannot truncate!");
4832 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4833 return V; // No conversion
4834 return getAnyExtendExpr(V, Ty);
4835}
4836
4837const SCEV *
4838ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4839 Type *SrcTy = V->getType();
4840 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4841 "Cannot truncate or noop with non-integer arguments!");
4843 "getTruncateOrNoop cannot extend!");
4844 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4845 return V; // No conversion
4846 return getTruncateExpr(V, Ty);
4847}
4848
4849const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4850 const SCEV *RHS) {
4851 const SCEV *PromotedLHS = LHS;
4852 const SCEV *PromotedRHS = RHS;
4853
4854 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4855 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4856 else
4857 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4858
4859 return getUMaxExpr(PromotedLHS, PromotedRHS);
4860}
4861
4862const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4863 const SCEV *RHS,
4864 bool Sequential) {
4865 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4866 return getUMinFromMismatchedTypes(Ops, Sequential);
4867}
4868
4869const SCEV *
4871 bool Sequential) {
4872 assert(!Ops.empty() && "At least one operand must be!");
4873 // Trivial case.
4874 if (Ops.size() == 1)
4875 return Ops[0];
4876
4877 // Find the max type first.
4878 Type *MaxType = nullptr;
4879 for (const auto *S : Ops)
4880 if (MaxType)
4881 MaxType = getWiderType(MaxType, S->getType());
4882 else
4883 MaxType = S->getType();
4884 assert(MaxType && "Failed to find maximum type!");
4885
4886 // Extend all ops to max type.
4887 SmallVector<const SCEV *, 2> PromotedOps;
4888 for (const auto *S : Ops)
4889 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4890
4891 // Generate umin.
4892 return getUMinExpr(PromotedOps, Sequential);
4893}
4894
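// For illustration: getPointerBase strips AddRec starts and the single
// pointer operand of add expressions, so e.g. the base of
// {(4 + %p),+,8}<%loop> is %p.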
4895const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4896 // A pointer operand may evaluate to a nonpointer expression, such as null.
4897 if (!V->getType()->isPointerTy())
4898 return V;
4899
4900 while (true) {
4901 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4902 V = AddRec->getStart();
4903 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4904 const SCEV *PtrOp = nullptr;
4905 for (const SCEV *AddOp : Add->operands()) {
4906 if (AddOp->getType()->isPointerTy()) {
4907 assert(!PtrOp && "Cannot have multiple pointer ops");
4908 PtrOp = AddOp;
4909 }
4910 }
4911 assert(PtrOp && "Must have pointer op");
4912 V = PtrOp;
4913 } else // Not something we can look further into.
4914 return V;
4915 }
4916}
4917
4918/// Push users of the given Instruction onto the given Worklist.
4919static void PushDefUseChildren(Instruction *I,
4920 SmallVectorImpl<Instruction *> &Worklist,
4921 SmallPtrSetImpl<Instruction *> &Visited) {
4922 // Push the def-use children onto the Worklist stack.
4923 for (User *U : I->users()) {
4924 auto *UserInsn = cast<Instruction>(U);
4925 if (Visited.insert(UserInsn).second)
4926 Worklist.push_back(UserInsn);
4927 }
4928}
4929
4930namespace {
4931
4932/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4933/// use its start expression. For AddRecs of other loops, use the AddRec
4934/// itself if IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4935/// If S contains a SCEVUnknown that is not invariant in L, the rewrite
4936/// cannot be done either.
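/// For example, rewriting {%a,+,4}<%L> with respect to %L yields %a, while an
/// AddRec of a different loop is either kept as-is (IgnoreOtherLoops) or
/// causes the whole rewrite to fail.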
4937class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4938public:
4939 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4940 bool IgnoreOtherLoops = true) {
4941 SCEVInitRewriter Rewriter(L, SE);
4942 const SCEV *Result = Rewriter.visit(S);
4943 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4944 return SE.getCouldNotCompute();
4945 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4946 ? SE.getCouldNotCompute()
4947 : Result;
4948 }
4949
4950 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4951 if (!SE.isLoopInvariant(Expr, L))
4952 SeenLoopVariantSCEVUnknown = true;
4953 return Expr;
4954 }
4955
4956 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4957 // Only re-write AddRecExprs for this loop.
4958 if (Expr->getLoop() == L)
4959 return Expr->getStart();
4960 SeenOtherLoops = true;
4961 return Expr;
4962 }
4963
4964 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4965
4966 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4967
4968private:
4969 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4970 : SCEVRewriteVisitor(SE), L(L) {}
4971
4972 const Loop *L;
4973 bool SeenLoopVariantSCEVUnknown = false;
4974 bool SeenOtherLoops = false;
4975};
4976
4977/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4978/// use its post-increment expression; for AddRecs of other loops, use the
4979/// AddRec itself.
4980/// If S contains a SCEVUnknown that is not invariant in L, the rewrite fails.
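/// For example, rewriting {%a,+,4}<%L> with respect to %L yields
/// {(%a + 4),+,4}<%L>, i.e. the value after one increment.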
4981class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4982public:
4983 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4984 SCEVPostIncRewriter Rewriter(L, SE);
4985 const SCEV *Result = Rewriter.visit(S);
4986 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4987 ? SE.getCouldNotCompute()
4988 : Result;
4989 }
4990
4991 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4992 if (!SE.isLoopInvariant(Expr, L))
4993 SeenLoopVariantSCEVUnknown = true;
4994 return Expr;
4995 }
4996
4997 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4998 // Only re-write AddRecExprs for this loop.
4999 if (Expr->getLoop() == L)
5000 return Expr->getPostIncExpr(SE);
5001 SeenOtherLoops = true;
5002 return Expr;
5003 }
5004
5005 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5006
5007 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5008
5009private:
5010 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5011 : SCEVRewriteVisitor(SE), L(L) {}
5012
5013 const Loop *L;
5014 bool SeenLoopVariantSCEVUnknown = false;
5015 bool SeenOtherLoops = false;
5016};
5017
5018/// This class evaluates the compare condition by matching it against the
5019/// condition of the loop latch. If there is a match, we assume a true value
5020/// for the condition while building SCEV nodes.
5021class SCEVBackedgeConditionFolder
5022 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5023public:
5024 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5025 ScalarEvolution &SE) {
5026 bool IsPosBECond = false;
5027 Value *BECond = nullptr;
5028 if (BasicBlock *Latch = L->getLoopLatch()) {
5029 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5030 if (BI && BI->isConditional()) {
5031 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5032 "Both outgoing branches should not target same header!");
5033 BECond = BI->getCondition();
5034 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5035 } else {
5036 return S;
5037 }
5038 }
5039 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5040 return Rewriter.visit(S);
5041 }
5042
5043 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5044 const SCEV *Result = Expr;
5045 bool InvariantF = SE.isLoopInvariant(Expr, L);
5046
5047 if (!InvariantF) {
5048 Instruction *I = cast<Instruction>(Expr->getValue());
5049 switch (I->getOpcode()) {
5050 case Instruction::Select: {
5051 SelectInst *SI = cast<SelectInst>(I);
5052 std::optional<const SCEV *> Res =
5053 compareWithBackedgeCondition(SI->getCondition());
5054 if (Res) {
5055 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5056 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5057 }
5058 break;
5059 }
5060 default: {
5061 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5062 if (Res)
5063 Result = *Res;
5064 break;
5065 }
5066 }
5067 }
5068 return Result;
5069 }
5070
5071private:
5072 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5073 bool IsPosBECond, ScalarEvolution &SE)
5074 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5075 IsPositiveBECond(IsPosBECond) {}
5076
5077 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5078
5079 const Loop *L;
5080 /// Loop backedge condition.
5081 Value *BackedgeCond = nullptr;
5082 /// Set to true if the backedge is taken when the branch condition is true.
5083 bool IsPositiveBECond;
5084};
5085
5086std::optional<const SCEV *>
5087SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5088
5089 // If value matches the backedge condition for loop latch,
5090 // then return a constant evolution node based on loopback
5091 // branch taken.
5092 if (BackedgeCond == IC)
5093 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5094 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5095 return std::nullopt;
5096}
5097
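/// Shift an affine AddRec of loop L back by one iteration: {A,+,B}<L> is
/// rewritten to {(A - B),+,B}<L>. Loop-variant SCEVUnknowns and AddRecs of
/// other loops invalidate the rewrite. This is used below to recognize
/// PHI(f(0), f({1,+,1})) as f({0,+,1}).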
5098class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5099public:
5100 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5101 ScalarEvolution &SE) {
5102 SCEVShiftRewriter Rewriter(L, SE);
5103 const SCEV *Result = Rewriter.visit(S);
5104 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5105 }
5106
5107 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5108 // Only allow AddRecExprs for this loop.
5109 if (!SE.isLoopInvariant(Expr, L))
5110 Valid = false;
5111 return Expr;
5112 }
5113
5114 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5115 if (Expr->getLoop() == L && Expr->isAffine())
5116 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5117 Valid = false;
5118 return Expr;
5119 }
5120
5121 bool isValid() { return Valid; }
5122
5123private:
5124 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5125 : SCEVRewriteVisitor(SE), L(L) {}
5126
5127 const Loop *L;
5128 bool Valid = true;
5129};
5130
5131} // end anonymous namespace
5132
5133SCEV::NoWrapFlags
5134ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5135 if (!AR->isAffine())
5136 return SCEV::FlagAnyWrap;
5137
5138 using OBO = OverflowingBinaryOperator;
5139
5140 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5141
5142 if (!AR->hasNoSelfWrap()) {
5143 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5144 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5145 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5146 const APInt &BECountAP = BECountMax->getAPInt();
5147 unsigned NoOverflowBitWidth =
5148 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5149 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5150 Result = setFlags(Result, SCEV::FlagNW);
5151 }
5152 }
5153
5154 if (!AR->hasNoSignedWrap()) {
5155 ConstantRange AddRecRange = getSignedRange(AR);
5156 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5157
5158 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5159 Instruction::Add, IncRange, OBO::NoSignedWrap);
5160 if (NSWRegion.contains(AddRecRange))
5161 Result = setFlags(Result, SCEV::FlagNSW);
5162 }
5163
5164 if (!AR->hasNoUnsignedWrap()) {
5165 ConstantRange AddRecRange = getUnsignedRange(AR);
5166 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5167
5168 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5169 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5170 if (NUWRegion.contains(AddRecRange))
5171 Result = setFlags(Result, SCEV::FlagNUW);
5172 }
5173
5174 return Result;
5175}
5176
5177SCEV::NoWrapFlags
5178ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5179 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5180
5181 if (AR->hasNoSignedWrap())
5182 return Result;
5183
5184 if (!AR->isAffine())
5185 return Result;
5186
5187 // This function can be expensive, only try to prove NSW once per AddRec.
5188 if (!SignedWrapViaInductionTried.insert(AR).second)
5189 return Result;
5190
5191 const SCEV *Step = AR->getStepRecurrence(*this);
5192 const Loop *L = AR->getLoop();
5193
5194 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5195 // Note that this serves two purposes: It filters out loops that are
5196 // simply not analyzable, and it covers the case where this code is
5197 // being called from within backedge-taken count analysis, such that
5198 // attempting to ask for the backedge-taken count would likely result
5199 // in infinite recursion. In the latter case, the analysis code will
5200 // cope with a conservative value, and it will take care to purge
5201 // that value once it has finished.
5202 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5203
5204 // Normally, in the cases we can prove no-overflow via a
5205 // backedge guarding condition, we can also compute a backedge
5206 // taken count for the loop. The exceptions are assumptions and
5207 // guards present in the loop -- SCEV is not great at exploiting
5208 // these to compute max backedge taken counts, but can still use
5209 // these to prove lack of overflow. Use this fact to avoid
5210 // doing extra work that may not pay off.
5211
5212 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5213 AC.assumptions().empty())
5214 return Result;
5215
5216 // If the backedge is guarded by a comparison with the pre-inc value the
5217 // addrec is safe. Also, if the entry is guarded by a comparison with the
5218 // start value and the backedge is guarded by a comparison with the post-inc
5219 // value, the addrec is safe.
5220 ICmpInst::Predicate Pred;
5221 const SCEV *OverflowLimit =
5222 getSignedOverflowLimitForStep(Step, &Pred, this);
5223 if (OverflowLimit &&
5224 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5225 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5226 Result = setFlags(Result, SCEV::FlagNSW);
5227 }
5228 return Result;
5229}
5230SCEV::NoWrapFlags
5231ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5232 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5233
5234 if (AR->hasNoUnsignedWrap())
5235 return Result;
5236
5237 if (!AR->isAffine())
5238 return Result;
5239
5240 // This function can be expensive, only try to prove NUW once per AddRec.
5241 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5242 return Result;
5243
5244 const SCEV *Step = AR->getStepRecurrence(*this);
5245 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5246 const Loop *L = AR->getLoop();
5247
5248 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5249 // Note that this serves two purposes: It filters out loops that are
5250 // simply not analyzable, and it covers the case where this code is
5251 // being called from within backedge-taken count analysis, such that
5252 // attempting to ask for the backedge-taken count would likely result
5253 // in infinite recursion. In the latter case, the analysis code will
5254 // cope with a conservative value, and it will take care to purge
5255 // that value once it has finished.
5256 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5257
5258 // Normally, in the cases we can prove no-overflow via a
5259 // backedge guarding condition, we can also compute a backedge
5260 // taken count for the loop. The exceptions are assumptions and
5261 // guards present in the loop -- SCEV is not great at exploiting
5262 // these to compute max backedge taken counts, but can still use
5263 // these to prove lack of overflow. Use this fact to avoid
5264 // doing extra work that may not pay off.
5265
5266 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5267 AC.assumptions().empty())
5268 return Result;
5269
5270 // If the backedge is guarded by a comparison with the pre-inc value the
5271 // addrec is safe. Also, if the entry is guarded by a comparison with the
5272 // start value and the backedge is guarded by a comparison with the post-inc
5273 // value, the addrec is safe.
5274 if (isKnownPositive(Step)) {
5275 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5276 getUnsignedRangeMax(Step));
5277 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5278 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5279 Result = setFlags(Result, SCEV::FlagNUW);
5280 }
5281 }
5282
5283 return Result;
5284}
5285
5286namespace {
5287
5288/// Represents an abstract binary operation. This may exist as a
5289/// normal instruction or constant expression, or may have been
5290/// derived from an expression tree.
5291struct BinaryOp {
5292 unsigned Opcode;
5293 Value *LHS;
5294 Value *RHS;
5295 bool IsNSW = false;
5296 bool IsNUW = false;
5297
5298 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5299 /// constant expression.
5300 Operator *Op = nullptr;
5301
5302 explicit BinaryOp(Operator *Op)
5303 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5304 Op(Op) {
5305 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5306 IsNSW = OBO->hasNoSignedWrap();
5307 IsNUW = OBO->hasNoUnsignedWrap();
5308 }
5309 }
5310
5311 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5312 bool IsNUW = false)
5313 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5314};
5315
5316} // end anonymous namespace
5317
5318/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5319static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5320 AssumptionCache &AC,
5321 const DominatorTree &DT,
5322 const Instruction *CxtI) {
5323 auto *Op = dyn_cast<Operator>(V);
5324 if (!Op)
5325 return std::nullopt;
5326
5327 // Implementation detail: all the cleverness here should happen without
5328 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5329 // SCEV expressions when possible, and we should not break that.
5330
5331 switch (Op->getOpcode()) {
5332 case Instruction::Add:
5333 case Instruction::Sub:
5334 case Instruction::Mul:
5335 case Instruction::UDiv:
5336 case Instruction::URem:
5337 case Instruction::And:
5338 case Instruction::AShr:
5339 case Instruction::Shl:
5340 return BinaryOp(Op);
5341
5342 case Instruction::Or: {
5343 // Convert or disjoint into add nuw nsw.
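// E.g. (or disjoint i32 %a, %b) guarantees %a and %b share no set bits,
// so it is equivalent to (add nuw nsw i32 %a, %b).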
5344 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5345 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5346 /*IsNSW=*/true, /*IsNUW=*/true);
5347 return BinaryOp(Op);
5348 }
5349
5350 case Instruction::Xor:
5351 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5352 // If the RHS of the xor is a signmask, then this is just an add.
5353 // Instcombine turns add of signmask into xor as a strength reduction step.
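// E.g. for i32, (xor %x, -2147483648) flips only the sign bit, which is
// the same as (add %x, -2147483648) because the carry out of the top bit
// is discarded.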
5354 if (RHSC->getValue().isSignMask())
5355 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5356 // For i1 operands, `xor` is the same as `add` (addition modulo 2).
5357 if (V->getType()->isIntegerTy(1))
5358 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5359 return BinaryOp(Op);
5360
5361 case Instruction::LShr:
5362 // Turn a logical shift right by a constant amount into an unsigned divide.
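// E.g. (lshr i32 %x, 3) is modeled as (udiv i32 %x, 8).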
5363 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5364 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5365
5366 // If the shift count is not less than the bitwidth, the result of
5367 // the shift is undefined. Don't try to analyze it, because the
5368 // resolution chosen here may differ from the resolution chosen in
5369 // other parts of the compiler.
5370 if (SA->getValue().ult(BitWidth)) {
5371 Constant *X =
5372 ConstantInt::get(SA->getContext(),
5373 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5374 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5375 }
5376 }
5377 return BinaryOp(Op);
5378
5379 case Instruction::ExtractValue: {
5380 auto *EVI = cast<ExtractValueInst>(Op);
5381 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5382 break;
5383
5384 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5385 if (!WO)
5386 break;
5387
5388 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5389 bool Signed = WO->isSigned();
5390 // TODO: Should add nuw/nsw flags for mul as well.
5391 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5392 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5393
5394 // Now that we know that all uses of the arithmetic-result component of
5395 // CI are guarded by the overflow check, we can go ahead and pretend
5396 // that the arithmetic is non-overflowing.
5397 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5398 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5399 }
5400
5401 default:
5402 break;
5403 }
5404
5405 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5406 // semantics as a Sub, return a binary sub expression.
5407 if (auto *II = dyn_cast<IntrinsicInst>(V))
5408 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5409 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5410
5411 return std::nullopt;
5412}
5413
5414/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5415/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5416/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5417/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5418/// follows one of the following patterns:
5419/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5420/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5421/// If the SCEV expression of \p Op conforms with one of the expected patterns
5422/// we return the type of the truncation operation, and indicate whether the
5423/// truncated type should be treated as signed/unsigned by setting
5424/// \p Signed to true/false, respectively.
5425static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5426 bool &Signed, ScalarEvolution &SE) {
5427 // The case where Op == SymbolicPHI (that is, with no type conversions on
5428 // the way) is handled by the regular add recurrence creating logic and
5429 // would have already been triggered in createAddRecFromPHI. Reaching it here
5430 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5431 // because one of the other operands of the SCEVAddExpr updating this PHI is
5432 // not invariant).
5433 //
5434 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5435 // this case predicates that allow us to prove that Op == SymbolicPHI will
5436 // be added.
5437 if (Op == SymbolicPHI)
5438 return nullptr;
5439
5440 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5441 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5442 if (SourceBits != NewBits)
5443 return nullptr;
5444
5445 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5446 Signed = true;
5447 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5448 }
5449 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5450 Signed = false;
5451 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5452 }
5453 return nullptr;
5454}
5455
5456static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5457 if (!PN->getType()->isIntegerTy())
5458 return nullptr;
5459 const Loop *L = LI.getLoopFor(PN->getParent());
5460 if (!L || L->getHeader() != PN->getParent())
5461 return nullptr;
5462 return L;
5463}
5464
5465// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5466// computation that updates the phi follows the following pattern:
5467// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5468// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5469// If so, try to see if it can be rewritten as an AddRecExpr under some
5470// Predicates. If successful, return them as a pair. Also cache the results
5471// of the analysis.
5472//
5473// Example usage scenario:
5474// Say the Rewriter is called for the following SCEV:
5475// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5476// where:
5477// %X = phi i64 (%Start, %BEValue)
5478// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5479// and call this function with %SymbolicPHI = %X.
5480//
5481// The analysis will find that the value coming around the backedge has
5482// the following SCEV:
5483// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5484// Upon concluding that this matches the desired pattern, the function
5485// will return the pair {NewAddRec, SmallPredsVec} where:
5486// NewAddRec = {%Start,+,%Step}
5487// SmallPredsVec = {P1, P2, P3} as follows:
5488// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5489// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5490// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5491// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5492// under the predicates {P1,P2,P3}.
5493// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5494// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5495//
5496// TODO's:
5497//
5498// 1) Extend the Induction descriptor to also support inductions that involve
5499// casts: When needed (namely, when we are called in the context of the
5500// vectorizer induction analysis), a Set of cast instructions will be
5501// populated by this method, and provided back to isInductionPHI. This is
5502// needed to allow the vectorizer to properly record them to be ignored by
5503// the cost model and to avoid vectorizing them (otherwise these casts,
5504// which are redundant under the runtime overflow checks, will be
5505// vectorized, which can be costly).
5506//
5507// 2) Support additional induction/PHISCEV patterns: We also want to support
5508// inductions where the sext-trunc / zext-trunc operations (partly) occur
5509// after the induction update operation (the induction increment):
5510//
5511// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5512// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5513//
5514// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5515// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5516//
5517// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5518std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5519ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5520 SmallVector<const SCEVPredicate *, 3> Predicates;
5521
5522 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5523 // return an AddRec expression under some predicate.
5524
5525 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5526 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5527 assert(L && "Expecting an integer loop header phi");
5528
5529 // The loop may have multiple entrances or multiple exits; we can analyze
5530 // this phi as an addrec if it has a unique entry value and a unique
5531 // backedge value.
5532 Value *BEValueV = nullptr, *StartValueV = nullptr;
5533 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5534 Value *V = PN->getIncomingValue(i);
5535 if (L->contains(PN->getIncomingBlock(i))) {
5536 if (!BEValueV) {
5537 BEValueV = V;
5538 } else if (BEValueV != V) {
5539 BEValueV = nullptr;
5540 break;
5541 }
5542 } else if (!StartValueV) {
5543 StartValueV = V;
5544 } else if (StartValueV != V) {
5545 StartValueV = nullptr;
5546 break;
5547 }
5548 }
5549 if (!BEValueV || !StartValueV)
5550 return std::nullopt;
5551
5552 const SCEV *BEValue = getSCEV(BEValueV);
5553
5554 // If the value coming around the backedge is an add with the symbolic
5555 // value we just inserted, possibly with casts that we can ignore under
5556 // an appropriate runtime guard, then we found a simple induction variable!
5557 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5558 if (!Add)
5559 return std::nullopt;
5560
5561 // If there is a single occurrence of the symbolic value, possibly
5562 // casted, replace it with a recurrence.
5563 unsigned FoundIndex = Add->getNumOperands();
5564 Type *TruncTy = nullptr;
5565 bool Signed;
5566 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5567 if ((TruncTy =
5568 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5569 if (FoundIndex == e) {
5570 FoundIndex = i;
5571 break;
5572 }
5573
5574 if (FoundIndex == Add->getNumOperands())
5575 return std::nullopt;
5576
5577 // Create an add with everything but the specified operand.
5578 SmallVector<const SCEV *, 8> Ops;
5579 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5580 if (i != FoundIndex)
5581 Ops.push_back(Add->getOperand(i));
5582 const SCEV *Accum = getAddExpr(Ops);
5583
5584 // The runtime checks will not be valid if the step amount is
5585 // varying inside the loop.
5586 if (!isLoopInvariant(Accum, L))
5587 return std::nullopt;
5588
5589 // *** Part2: Create the predicates
5590
5591 // Analysis was successful: we have a phi-with-cast pattern for which we
5592 // can return an AddRec expression under the following predicates:
5593 //
5594 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5595 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5596 // P2: An Equal predicate that guarantees that
5597 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5598 // P3: An Equal predicate that guarantees that
5599 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5600 //
5601 // As we next prove, the above predicates guarantee that:
5602 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5603 //
5604 //
5605 // More formally, we want to prove that:
5606 // Expr(i+1) = Start + (i+1) * Accum
5607 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5608 //
5609 // Given that:
5610 // 1) Expr(0) = Start
5611 // 2) Expr(1) = Start + Accum
5612 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5613 // 3) Induction hypothesis (step i):
5614 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5615 //
5616 // Proof:
5617 // Expr(i+1) =
5618 // = Start + (i+1)*Accum
5619 // = (Start + i*Accum) + Accum
5620 // = Expr(i) + Accum
5621 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5622 // :: from step i
5623 //
5624 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5625 //
5626 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5627 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5628 // + Accum :: from P3
5629 //
5630 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5631 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5632 //
5633 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5634 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5635 //
5636 // By induction, the same applies to all iterations 1<=i<n:
5637 //
5638
5639 // Create a truncated addrec for which we will add a no overflow check (P1).
5640 const SCEV *StartVal = getSCEV(StartValueV);
5641 const SCEV *PHISCEV =
5642 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5643 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5644
5645 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5646 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5647 // will be constant.
5648 //
5649 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5650 // add P1.
5651 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5655 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5656 Predicates.push_back(AddRecPred);
5657 }
5658
5659 // Create the Equal Predicates P2,P3:
5660
5661 // It is possible that the predicates P2 and/or P3 are computable at
5662 // compile time due to StartVal and/or Accum being constants.
5663 // If either one is, then we can check that now and escape if either P2
5664 // or P3 is false.
5665
5666 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5667 // for each of StartVal and Accum
5668 auto getExtendedExpr = [&](const SCEV *Expr,
5669 bool CreateSignExtend) -> const SCEV * {
5670 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5671 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5672 const SCEV *ExtendedExpr =
5673 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5674 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5675 return ExtendedExpr;
5676 };
5677
5678 // Given:
5679 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5680 // = getExtendedExpr(Expr)
5681 // Determine whether the predicate P: Expr == ExtendedExpr
5682 // is known to be false at compile time
5683 auto PredIsKnownFalse = [&](const SCEV *Expr,
5684 const SCEV *ExtendedExpr) -> bool {
5685 return Expr != ExtendedExpr &&
5686 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5687 };
5688
5689 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5690 if (PredIsKnownFalse(StartVal, StartExtended)) {
5691 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5692 return std::nullopt;
5693 }
5694
5695 // The Step is always Signed (because the overflow checks are either
5696 // NSSW or NUSW)
5697 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5698 if (PredIsKnownFalse(Accum, AccumExtended)) {
5699 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5700 return std::nullopt;
5701 }
5702
5703 auto AppendPredicate = [&](const SCEV *Expr,
5704 const SCEV *ExtendedExpr) -> void {
5705 if (Expr != ExtendedExpr &&
5706 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5707 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5708 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5709 Predicates.push_back(Pred);
5710 }
5711 };
5712
5713 AppendPredicate(StartVal, StartExtended);
5714 AppendPredicate(Accum, AccumExtended);
5715
5716 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5717 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5718 // into NewAR if it will also add the runtime overflow checks specified in
5719 // Predicates.
5720 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5721
5722 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5723 std::make_pair(NewAR, Predicates);
5724 // Remember the result of the analysis for this SCEV at this location.
5725 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5726 return PredRewrite;
5727}
5728
5729std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5731 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5732 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5733 if (!L)
5734 return std::nullopt;
5735
5736 // Check to see if we already analyzed this PHI.
5737 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5738 if (I != PredicatedSCEVRewrites.end()) {
5739 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5740 I->second;
5741 // Analysis was done before and failed to create an AddRec:
5742 if (Rewrite.first == SymbolicPHI)
5743 return std::nullopt;
5744 // Analysis was done before and succeeded to create an AddRec under
5745 // a predicate:
5746 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5747 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5748 return Rewrite;
5749 }
5750
5751 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5752 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5753
5754 // Record in the cache that the analysis failed
5755 if (!Rewrite) {
5756 SmallVector<const SCEVPredicate *, 3> Predicates;
5757 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5758 return std::nullopt;
5759 }
5760
5761 return Rewrite;
5762}
5763
5764// FIXME: This utility is currently required because the Rewriter currently
5765// does not rewrite this expression:
5766// {0, +, (sext ix (trunc iy to ix) to iy)}
5767// into {0, +, %step},
5768// even when the following Equal predicate exists:
5769// "%step == (sext ix (trunc iy to ix) to iy)".
5770bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5771 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5772 if (AR1 == AR2)
5773 return true;
5774
5775 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5776 if (Expr1 != Expr2 &&
5777 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5778 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5779 return false;
5780 return true;
5781 };
5782
5783 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5784 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5785 return false;
5786 return true;
5787}
5788
5789/// A helper function for createAddRecFromPHI to handle simple cases.
5790///
5791/// This function tries to find an AddRec expression for the simplest (yet most
5792/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5793/// If it fails, createAddRecFromPHI will use a more general, but slow,
5794/// technique for finding the AddRec expression.
5795const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5796 Value *BEValueV,
5797 Value *StartValueV) {
5798 const Loop *L = LI.getLoopFor(PN->getParent());
5799 assert(L && L->getHeader() == PN->getParent());
5800 assert(BEValueV && StartValueV);
5801
5802 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5803 if (!BO)
5804 return nullptr;
5805
5806 if (BO->Opcode != Instruction::Add)
5807 return nullptr;
5808
5809 const SCEV *Accum = nullptr;
5810 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5811 Accum = getSCEV(BO->RHS);
5812 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5813 Accum = getSCEV(BO->LHS);
5814
5815 if (!Accum)
5816 return nullptr;
5817
5818 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5819 if (BO->IsNUW)
5820 Flags = setFlags(Flags, SCEV::FlagNUW);
5821 if (BO->IsNSW)
5822 Flags = setFlags(Flags, SCEV::FlagNSW);
5823
5824 const SCEV *StartVal = getSCEV(StartValueV);
5825 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5826 insertValueToMap(PN, PHISCEV);
5827
5828 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5829 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5830 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5831 proveNoWrapViaConstantRanges(AR)));
5832 }
5833
5834 // We can add Flags to the post-inc expression only if we
5835 // know that it is *undefined behavior* for BEValueV to
5836 // overflow.
5837 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5838 assert(isLoopInvariant(Accum, L) &&
5839 "Accum is defined outside L, but is not invariant?");
5840 if (isAddRecNeverPoison(BEInst, L))
5841 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5842 }
5843
5844 return PHISCEV;
5845}
5846
5847const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5848 const Loop *L = LI.getLoopFor(PN->getParent());
5849 if (!L || L->getHeader() != PN->getParent())
5850 return nullptr;
5851
5852 // The loop may have multiple entrances or multiple exits; we can analyze
5853 // this phi as an addrec if it has a unique entry value and a unique
5854 // backedge value.
5855 Value *BEValueV = nullptr, *StartValueV = nullptr;
5856 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5857 Value *V = PN->getIncomingValue(i);
5858 if (L->contains(PN->getIncomingBlock(i))) {
5859 if (!BEValueV) {
5860 BEValueV = V;
5861 } else if (BEValueV != V) {
5862 BEValueV = nullptr;
5863 break;
5864 }
5865 } else if (!StartValueV) {
5866 StartValueV = V;
5867 } else if (StartValueV != V) {
5868 StartValueV = nullptr;
5869 break;
5870 }
5871 }
5872 if (!BEValueV || !StartValueV)
5873 return nullptr;
5874
5875 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5876 "PHI node already processed?");
5877
5878 // First, try to find an AddRec expression without creating a fictitious symbolic
5879 // value for PN.
5880 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5881 return S;
5882
5883 // Handle PHI node value symbolically.
5884 const SCEV *SymbolicName = getUnknown(PN);
5885 insertValueToMap(PN, SymbolicName);
5886
5887 // Using this symbolic name for the PHI, analyze the value coming around
5888 // the back-edge.
5889 const SCEV *BEValue = getSCEV(BEValueV);
5890
5891 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5892 // has a special value for the first iteration of the loop.
5893
5894 // If the value coming around the backedge is an add with the symbolic
5895 // value we just inserted, then we found a simple induction variable!
5896 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5897 // If there is a single occurrence of the symbolic value, replace it
5898 // with a recurrence.
5899 unsigned FoundIndex = Add->getNumOperands();
5900 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5901 if (Add->getOperand(i) == SymbolicName)
5902 if (FoundIndex == e) {
5903 FoundIndex = i;
5904 break;
5905 }
5906
5907 if (FoundIndex != Add->getNumOperands()) {
5908 // Create an add with everything but the specified operand.
5909 SmallVector<const SCEV *, 8> Ops;
5910 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5911 if (i != FoundIndex)
5912 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5913 L, *this));
5914 const SCEV *Accum = getAddExpr(Ops);
5915
5916 // This is not a valid addrec if the step amount is varying each
5917 // loop iteration, but is not itself an addrec in this loop.
5918 if (isLoopInvariant(Accum, L) ||
5919 (isa<SCEVAddRecExpr>(Accum) &&
5920 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5921 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5922
5923 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5924 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5925 if (BO->IsNUW)
5926 Flags = setFlags(Flags, SCEV::FlagNUW);
5927 if (BO->IsNSW)
5928 Flags = setFlags(Flags, SCEV::FlagNSW);
5929 }
5930 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5931 if (GEP->getOperand(0) == PN) {
5932 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5933 // If the increment has any nowrap flags, then we know the address
5934 // space cannot be wrapped around.
5935 if (NW != GEPNoWrapFlags::none())
5936 Flags = setFlags(Flags, SCEV::FlagNW);
5937 // If the GEP is nuw or nusw with non-negative offset, we know that
5938 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5939 // offset is treated as signed, while the base is unsigned.
5940 if (NW.hasNoUnsignedWrap() ||
5941 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5942 Flags = setFlags(Flags, SCEV::FlagNUW);
5943 }
5944
5945 // We cannot transfer nuw and nsw flags from subtraction
5946 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5947 // for instance.
5948 }
5949
5950 const SCEV *StartVal = getSCEV(StartValueV);
5951 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5952
5953 // Okay, for the entire analysis of this edge we assumed the PHI
5954 // to be symbolic. We now need to go back and purge all of the
5955 // entries for the scalars that use the symbolic expression.
5956 forgetMemoizedResults(SymbolicName);
5957 insertValueToMap(PN, PHISCEV);
5958
5959 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5960 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5961 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5962 proveNoWrapViaConstantRanges(AR)));
5963 }
5964
5965 // We can add Flags to the post-inc expression only if we
5966 // know that it is *undefined behavior* for BEValueV to
5967 // overflow.
5968 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5969 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5970 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5971
5972 return PHISCEV;
5973 }
5974 }
5975 } else {
5976 // Otherwise, this could be a loop like this:
5977 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5978 // In this case, j = {1,+,1} and BEValue is j.
5979 // Because the other in-value of i (0) fits the evolution of BEValue,
5980 // i really is an addrec evolution.
5981 //
5982 // We can generalize this saying that i is the shifted value of BEValue
5983 // by one iteration:
5984 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5985
5986 // Do not allow refinement in rewriting of BEValue.
5987 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5988 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5989 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5990 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5991 const SCEV *StartVal = getSCEV(StartValueV);
5992 if (Start == StartVal) {
5993 // Okay, for the entire analysis of this edge we assumed the PHI
5994 // to be symbolic. We now need to go back and purge all of the
5995 // entries for the scalars that use the symbolic expression.
5996 forgetMemoizedResults(SymbolicName);
5997 insertValueToMap(PN, Shifted);
5998 return Shifted;
5999 }
6000 }
6001 }
6002
6003 // Remove the temporary PHI node SCEV that has been inserted while intending
6004 // to create an AddRecExpr for this PHI node. We cannot keep this temporary,
6005 // as it would prevent later (possibly simpler) SCEV expressions from being
6006 // added to the ValueExprMap.
6007 eraseValueFromMap(PN);
6008
6009 return nullptr;
6010}
6011
6012// Try to match a control flow sequence that branches out at BI and merges back
6013// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6014// match.
6015static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
6016 Value *&C, Value *&LHS, Value *&RHS) {
6017 C = BI->getCondition();
6018
6019 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6020 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6021
6022 if (!LeftEdge.isSingleEdge())
6023 return false;
6024
6025 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6026
6027 Use &LeftUse = Merge->getOperandUse(0);
6028 Use &RightUse = Merge->getOperandUse(1);
6029
6030 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6031 LHS = LeftUse;
6032 RHS = RightUse;
6033 return true;
6034 }
6035
6036 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6037 LHS = RightUse;
6038 RHS = LeftUse;
6039 return true;
6040 }
6041
6042 return false;
6043}
6044
6045const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6046 auto IsReachable =
6047 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6048 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6049 // Try to match
6050 //
6051 // br %cond, label %left, label %right
6052 // left:
6053 // br label %merge
6054 // right:
6055 // br label %merge
6056 // merge:
6057 // V = phi [ %x, %left ], [ %y, %right ]
6058 //
6059 // as "select %cond, %x, %y"
6060
6061 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6062 assert(IDom && "At least the entry block should dominate PN");
6063
6064 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6065 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6066
6067 if (BI && BI->isConditional() &&
6068 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6069 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6070 properlyDominates(getSCEV(RHS), PN->getParent()))
6071 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6072 }
6073
6074 return nullptr;
6075}
6076
6077/// Returns SCEV for the first operand of a phi if all phi operands have
6078/// identical opcodes and operands
6079/// eg.
6080/// a: %add = %a + %b
6081/// br %c
6082/// b: %add1 = %a + %b
6083/// br %c
6084/// c: %phi = phi [%add, a], [%add1, b]
6085/// scev(%phi) => scev(%add)
6086const SCEV *
6087ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6088 BinaryOperator *CommonInst = nullptr;
6089 // Check if instructions are identical.
6090 for (Value *Incoming : PN->incoming_values()) {
6091 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6092 if (!IncomingInst)
6093 return nullptr;
6094 if (CommonInst) {
6095 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6096 return nullptr; // Not identical, give up
6097 } else {
6098 // Remember binary operator
6099 CommonInst = IncomingInst;
6100 }
6101 }
6102 if (!CommonInst)
6103 return nullptr;
6104
6105 // Check if SCEV exprs for instructions are identical.
6106 const SCEV *CommonSCEV = getSCEV(CommonInst);
6107 bool SCEVExprsIdentical =
6108 all_of(drop_begin(PN->incoming_values()),
6109 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6110 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6111}
6112
6113const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6114 if (const SCEV *S = createAddRecFromPHI(PN))
6115 return S;
6116
6117 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6118 // phi node for X.
6119 if (Value *V = simplifyInstruction(
6120 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6121 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6122 return getSCEV(V);
6123
6124 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6125 return S;
6126
6127 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6128 return S;
6129
6130 // If it's not a loop phi, we can't handle it yet.
6131 return getUnknown(PN);
6132}
6133
6134bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6135 SCEVTypes RootKind) {
6136 struct FindClosure {
6137 const SCEV *OperandToFind;
6138 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6139 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6140
6141 bool Found = false;
6142
6143 bool canRecurseInto(SCEVTypes Kind) const {
6144 // We can only recurse into the SCEV expression of the same effective type
6145 // as the type of our root SCEV expression, and into zero-extensions.
6146 return RootKind == Kind || NonSequentialRootKind == Kind ||
6147 scZeroExtend == Kind;
6148 };
6149
6150 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6151 : OperandToFind(OperandToFind), RootKind(RootKind),
6152 NonSequentialRootKind(
6153 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6154 RootKind)) {}
6155
6156 bool follow(const SCEV *S) {
6157 Found = S == OperandToFind;
6158
6159 return !isDone() && canRecurseInto(S->getSCEVType());
6160 }
6161
6162 bool isDone() const { return Found; }
6163 };
6164
6165 FindClosure FC(OperandToFind, RootKind);
6166 visitAll(Root, FC);
6167 return FC.Found;
6168}
6169
6170std::optional<const SCEV *>
6171ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6172 ICmpInst *Cond,
6173 Value *TrueVal,
6174 Value *FalseVal) {
6175 // Try to match some simple smax or umax patterns.
6176 auto *ICI = Cond;
6177
6178 Value *LHS = ICI->getOperand(0);
6179 Value *RHS = ICI->getOperand(1);
6180
6181 switch (ICI->getPredicate()) {
6182 case ICmpInst::ICMP_SLT:
6183 case ICmpInst::ICMP_SLE:
6184 case ICmpInst::ICMP_ULT:
6185 case ICmpInst::ICMP_ULE:
6186 std::swap(LHS, RHS);
6187 [[fallthrough]];
6188 case ICmpInst::ICMP_SGT:
6189 case ICmpInst::ICMP_SGE:
6190 case ICmpInst::ICMP_UGT:
6191 case ICmpInst::ICMP_UGE:
6192 // a > b ? a+x : b+x -> max(a, b)+x
6193 // a > b ? b+x : a+x -> min(a, b)+x
6194 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty)) {
6195 bool Signed = ICI->isSigned();
6196 const SCEV *LA = getSCEV(TrueVal);
6197 const SCEV *RA = getSCEV(FalseVal);
6198 const SCEV *LS = getSCEV(LHS);
6199 const SCEV *RS = getSCEV(RHS);
6200 if (LA->getType()->isPointerTy()) {
6201 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6202 // Need to make sure we can't produce weird expressions involving
6203 // negated pointers.
6204 if (LA == LS && RA == RS)
6205 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6206 if (LA == RS && RA == LS)
6207 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6208 }
6209 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6210 if (Op->getType()->isPointerTy()) {
6211 Op = getLosslessPtrToIntExpr(Op);
6212 if (isa<SCEVCouldNotCompute>(Op))
6213 return Op;
6214 }
6215 if (Signed)
6216 Op = getNoopOrSignExtend(Op, Ty);
6217 else
6218 Op = getNoopOrZeroExtend(Op, Ty);
6219 return Op;
6220 };
6221 LS = CoerceOperand(LS);
6222 RS = CoerceOperand(RS);
6223 if (isa<SCEVCouldNotCompute>(LS) || isa<SCEVCouldNotCompute>(RS))
6224 break;
6225 const SCEV *LDiff = getMinusSCEV(LA, LS);
6226 const SCEV *RDiff = getMinusSCEV(RA, RS);
6227 if (LDiff == RDiff)
6228 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6229 LDiff);
6230 LDiff = getMinusSCEV(LA, RS);
6231 RDiff = getMinusSCEV(RA, LS);
6232 if (LDiff == RDiff)
6233 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6234 LDiff);
6235 }
6236 break;
6237 case ICmpInst::ICMP_NE:
6238 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6239 std::swap(TrueVal, FalseVal);
6240 [[fallthrough]];
6241 case ICmpInst::ICMP_EQ:
6242 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6243 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6244 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6245 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6246 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6247 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6248 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6249 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6250 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6251 return getAddExpr(getUMaxExpr(X, C), Y);
6252 }
6253 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6254 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6255 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6256 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6257 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6258 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6259 const SCEV *X = getSCEV(LHS);
6260 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6261 X = ZExt->getOperand();
6262 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6263 const SCEV *FalseValExpr = getSCEV(FalseVal);
6264 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6265 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6266 /*Sequential=*/true);
6267 }
6268 }
6269 break;
6270 default:
6271 break;
6272 }
6273
6274 return std::nullopt;
6275}
6276
6277static std::optional<const SCEV *>
6278createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6279 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6280 assert(CondExpr->getType()->isIntegerTy(1) &&
6281 TrueExpr->getType() == FalseExpr->getType() &&
6282 TrueExpr->getType()->isIntegerTy(1) &&
6283 "Unexpected operands of a select.");
6284
6285 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6286 // --> C + (umin_seq cond, x - C)
6287 //
6288 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6289 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6290 // --> C + (umin_seq ~cond, x - C)
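// For example, (select i1 %cond, i1 %x, i1 false) becomes
// (umin_seq %cond, %x): the result is 0 as soon as %cond is false, and
// poison in %x does not propagate in that case.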
6291
6292 // FIXME: while we can't legally model the case where both of the hands
6293 // are fully variable, we only require that the *difference* is constant.
6294 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6295 return std::nullopt;
6296
6297 const SCEV *X, *C;
6298 if (isa<SCEVConstant>(TrueExpr)) {
6299 CondExpr = SE->getNotSCEV(CondExpr);
6300 X = FalseExpr;
6301 C = TrueExpr;
6302 } else {
6303 X = TrueExpr;
6304 C = FalseExpr;
6305 }
6306 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6307 /*Sequential=*/true));
6308}
6309
6310static std::optional<const SCEV *>
6311createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6312 Value *FalseVal) {
6313 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6314 return std::nullopt;
6315
6316 const auto *SECond = SE->getSCEV(Cond);
6317 const auto *SETrue = SE->getSCEV(TrueVal);
6318 const auto *SEFalse = SE->getSCEV(FalseVal);
6319 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6320}
6321
6322const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6323 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6324 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6325 assert(TrueVal->getType() == FalseVal->getType() &&
6326 V->getType() == TrueVal->getType() &&
6327 "Types of select hands and of the result must match.");
6328
6329 // For now, only deal with i1-typed `select`s.
6330 if (!V->getType()->isIntegerTy(1))
6331 return getUnknown(V);
6332
6333 if (std::optional<const SCEV *> S =
6334 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6335 return *S;
6336
6337 return getUnknown(V);
6338}
6339
6340const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6341 Value *TrueVal,
6342 Value *FalseVal) {
6343 // Handle "constant" branch or select. This can occur for instance when a
6344 // loop pass transforms an inner loop and moves on to process the outer loop.
6345 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6346 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6347
6348 if (auto *I = dyn_cast<Instruction>(V)) {
6349 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6350 if (std::optional<const SCEV *> S =
6351 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6352 TrueVal, FalseVal))
6353 return *S;
6354 }
6355 }
6356
6357 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6358}
6359
6360/// Expand GEP instructions into add and multiply operations. This allows them
6361/// to be analyzed by regular SCEV code.
6362const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6363 assert(GEP->getSourceElementType()->isSized() &&
6364 "GEP source element type must be sized");
6365
6366 SmallVector<const SCEV *, 4> IndexExprs;
6367 for (Value *Index : GEP->indices())
6368 IndexExprs.push_back(getSCEV(Index));
6369 return getGEPExpr(GEP, IndexExprs);
6370}
6371
6372APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6373 const Instruction *CtxI) {
6374 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6375 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6376 return TrailingZeros >= BitWidth
6377 ? APInt::getZero(BitWidth)
6378 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6379 };
6380 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6381 // The result is GCD of all operands results.
6382 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6383 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6384 Res = APIntOps::GreatestCommonDivisor(
6385 Res, getConstantMultiple(N->getOperand(I), CtxI));
6386 return Res;
6387 };
6388
6389 switch (S->getSCEVType()) {
6390 case scConstant:
6391 return cast<SCEVConstant>(S)->getAPInt();
6392 case scPtrToAddr:
6393 case scPtrToInt:
6394 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6395 case scUDivExpr:
6396 case scVScale:
6397 return APInt(BitWidth, 1);
6398 case scTruncate: {
6399 // Only multiples that are a power of 2 will hold after truncation.
6400 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6401 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6402 return GetShiftedByZeros(TZ);
6403 }
6404 case scZeroExtend: {
6405 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6406 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6407 }
6408 case scSignExtend: {
6409 // Only multiples that are a power of 2 will hold after sext.
6410 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6411 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6412 return GetShiftedByZeros(TZ);
6413 }
6414 case scMulExpr: {
6415 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6416 if (M->hasNoUnsignedWrap()) {
6417 // The result is the product of all operand results.
6418 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6419 for (const SCEV *Operand : M->operands().drop_front())
6420 Res = Res * getConstantMultiple(Operand, CtxI);
6421 return Res;
6422 }
6423
6424     // If there are no wrap guarantees, find the trailing zeros, which is the
6425 // sum of trailing zeros for all its operands.
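    // E.g. (informal): (4 * %x) without <nuw> still has at least
    // countTrailingZeros(4) = 2 trailing zero bits, so it is a multiple of 4.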
6426 uint32_t TZ = 0;
6427 for (const SCEV *Operand : M->operands())
6428 TZ += getMinTrailingZeros(Operand, CtxI);
6429 return GetShiftedByZeros(TZ);
6430 }
6431 case scAddExpr:
6432 case scAddRecExpr: {
6433 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6434 if (N->hasNoUnsignedWrap())
6435 return GetGCDMultiple(N);
6436     // Find the trailing zero bits, which is the minimum over all operands.
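    // E.g. (informal): (4 + (8 * %x)) without <nuw> keeps
    // min(tz(4), tz(8 * %x)) = min(2, 3) = 2 trailing zero bits, i.e. it is
    // still provably a multiple of 4.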
6437 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6438 for (const SCEV *Operand : N->operands().drop_front())
6439 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6440 return GetShiftedByZeros(TZ);
6441 }
6442 case scUMaxExpr:
6443 case scSMaxExpr:
6444 case scUMinExpr:
6445 case scSMinExpr:
6446   case scSequentialUMinExpr:
6447     return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6448 case scUnknown: {
6449     // Ask ValueTracking for known bits. A SCEVUnknown only becomes available
6450     // at the point its underlying IR instruction has been defined. If CtxI was
6451 // not provided, use:
6452 // * the first instruction in the entry block if it is an argument
6453 // * the instruction itself otherwise.
6454 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6455 if (!CtxI) {
6456 if (isa<Argument>(U->getValue()))
6457 CtxI = &*F.getEntryBlock().begin();
6458 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6459 CtxI = I;
6460 }
6461 unsigned Known =
6462 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6463 .countMinTrailingZeros();
6464 return GetShiftedByZeros(Known);
6465 }
6466 case scCouldNotCompute:
6467 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6468 }
6469 llvm_unreachable("Unknown SCEV kind!");
6470}
6471
6472 APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6473                                            const Instruction *CtxI) {
6474 // Skip looking up and updating the cache if there is a context instruction,
6475 // as the result will only be valid in the specified context.
6476 if (CtxI)
6477 return getConstantMultipleImpl(S, CtxI);
6478
6479 auto I = ConstantMultipleCache.find(S);
6480 if (I != ConstantMultipleCache.end())
6481 return I->second;
6482
6483 APInt Result = getConstantMultipleImpl(S, CtxI);
6484 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6485 assert(InsertPair.second && "Should insert a new key");
6486 return InsertPair.first->second;
6487}
6488
6489 APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6490   APInt Multiple = getConstantMultiple(S);
6491 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6492}
6493
6494 uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6495                                               const Instruction *CtxI) {
6496 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6497 (unsigned)getTypeSizeInBits(S->getType()));
6498}
6499
6500/// Helper method to assign a range to V from metadata present in the IR.
6501static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6502   if (Instruction *I = dyn_cast<Instruction>(V)) {
6503     if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6504 return getConstantRangeFromMetadata(*MD);
6505 if (const auto *CB = dyn_cast<CallBase>(V))
6506 if (std::optional<ConstantRange> Range = CB->getRange())
6507 return Range;
6508 }
6509 if (auto *A = dyn_cast<Argument>(V))
6510 if (std::optional<ConstantRange> Range = A->getRange())
6511 return Range;
6512
6513 return std::nullopt;
6514}
6515
6516 void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6517                                      SCEV::NoWrapFlags Flags) {
6518 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6519 AddRec->setNoWrapFlags(Flags);
6520 UnsignedRanges.erase(AddRec);
6521 SignedRanges.erase(AddRec);
6522 ConstantMultipleCache.erase(AddRec);
6523 }
6524}
6525
6526ConstantRange ScalarEvolution::
6527getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6528 const DataLayout &DL = getDataLayout();
6529
6530 unsigned BitWidth = getTypeSizeInBits(U->getType());
6531 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6532
6533 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6534 // use information about the trip count to improve our available range. Note
6535 // that the trip count independent cases are already handled by known bits.
6536 // WARNING: The definition of recurrence used here is subtly different than
6537 // the one used by AddRec (and thus most of this file). Step is allowed to
6538 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6539 // and other addrecs in the same loop (for non-affine addrecs). The code
6540 // below intentionally handles the case where step is not loop invariant.
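  // Illustrative example: with a constant max trip count of 3 and
  //   %iv = phi i8 [ %start, %preheader ], [ %iv.next, %latch ]
  //   %iv.next = lshr i8 %iv, 1
  // the value is shifted right at most twice, so the unsigned range can be
  // narrowed to roughly [known_min(%start) >> 2, known_max(%start)] instead
  // of the full set.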
6541 auto *P = dyn_cast<PHINode>(U->getValue());
6542 if (!P)
6543 return FullSet;
6544
6545 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6546 // even the values that are not available in these blocks may come from them,
6547   // and this leads to a false-positive recurrence test.
6548 for (auto *Pred : predecessors(P->getParent()))
6549 if (!DT.isReachableFromEntry(Pred))
6550 return FullSet;
6551
6552 BinaryOperator *BO;
6553 Value *Start, *Step;
6554 if (!matchSimpleRecurrence(P, BO, Start, Step))
6555 return FullSet;
6556
6557 // If we found a recurrence in reachable code, we must be in a loop. Note
6558 // that BO might be in some subloop of L, and that's completely okay.
6559 auto *L = LI.getLoopFor(P->getParent());
6560 assert(L && L->getHeader() == P->getParent());
6561 if (!L->contains(BO->getParent()))
6562 // NOTE: This bailout should be an assert instead. However, asserting
6563 // the condition here exposes a case where LoopFusion is querying SCEV
6564 // with malformed loop information during the midst of the transform.
6565 // There doesn't appear to be an obvious fix, so for the moment bailout
6566 // until the caller issue can be fixed. PR49566 tracks the bug.
6567 return FullSet;
6568
6569   // TODO: Extend to other opcodes such as mul and div.
6570 switch (BO->getOpcode()) {
6571 default:
6572 return FullSet;
6573 case Instruction::AShr:
6574 case Instruction::LShr:
6575 case Instruction::Shl:
6576 break;
6577 };
6578
6579 if (BO->getOperand(0) != P)
6580 // TODO: Handle the power function forms some day.
6581 return FullSet;
6582
6583 unsigned TC = getSmallConstantMaxTripCount(L);
6584 if (!TC || TC >= BitWidth)
6585 return FullSet;
6586
6587 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6588 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6589 assert(KnownStart.getBitWidth() == BitWidth &&
6590 KnownStep.getBitWidth() == BitWidth);
6591
6592 // Compute total shift amount, being careful of overflow and bitwidths.
6593 auto MaxShiftAmt = KnownStep.getMaxValue();
6594 APInt TCAP(BitWidth, TC-1);
6595 bool Overflow = false;
6596 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6597 if (Overflow)
6598 return FullSet;
6599
6600 switch (BO->getOpcode()) {
6601 default:
6602 llvm_unreachable("filtered out above");
6603 case Instruction::AShr: {
6604 // For each ashr, three cases:
6605 // shift = 0 => unchanged value
6606 // saturation => 0 or -1
6607 // other => a value closer to zero (of the same sign)
6608 // Thus, the end value is closer to zero than the start.
6609 auto KnownEnd = KnownBits::ashr(KnownStart,
6610 KnownBits::makeConstant(TotalShift));
6611 if (KnownStart.isNonNegative())
6612 // Analogous to lshr (simply not yet canonicalized)
6613 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6614 KnownStart.getMaxValue() + 1);
6615 if (KnownStart.isNegative())
6616 // End >=u Start && End <=s Start
6617 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6618 KnownEnd.getMaxValue() + 1);
6619 break;
6620 }
6621 case Instruction::LShr: {
6622 // For each lshr, three cases:
6623 // shift = 0 => unchanged value
6624 // saturation => 0
6625 // other => a smaller positive number
6626 // Thus, the low end of the unsigned range is the last value produced.
6627 auto KnownEnd = KnownBits::lshr(KnownStart,
6628 KnownBits::makeConstant(TotalShift));
6629 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6630 KnownStart.getMaxValue() + 1);
6631 }
6632 case Instruction::Shl: {
6633 // Iff no bits are shifted out, value increases on every shift.
6634 auto KnownEnd = KnownBits::shl(KnownStart,
6635 KnownBits::makeConstant(TotalShift));
6636 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6637 return ConstantRange(KnownStart.getMinValue(),
6638 KnownEnd.getMaxValue() + 1);
6639 break;
6640 }
6641 };
6642 return FullSet;
6643}
6644
6645const ConstantRange &
6646ScalarEvolution::getRangeRefIter(const SCEV *S,
6647 ScalarEvolution::RangeSignHint SignHint) {
6648 DenseMap<const SCEV *, ConstantRange> &Cache =
6649 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6650 : SignedRanges;
6651   SmallVector<const SCEV *> WorkList;
6652   SmallPtrSet<const SCEV *, 8> Seen;
6653
6654 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6655 // SCEVUnknown PHI node.
6656 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6657 if (!Seen.insert(Expr).second)
6658 return;
6659 if (Cache.contains(Expr))
6660 return;
6661 switch (Expr->getSCEVType()) {
6662 case scUnknown:
6663 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6664 break;
6665 [[fallthrough]];
6666 case scConstant:
6667 case scVScale:
6668 case scTruncate:
6669 case scZeroExtend:
6670 case scSignExtend:
6671 case scPtrToAddr:
6672 case scPtrToInt:
6673 case scAddExpr:
6674 case scMulExpr:
6675 case scUDivExpr:
6676 case scAddRecExpr:
6677 case scUMaxExpr:
6678 case scSMaxExpr:
6679 case scUMinExpr:
6680 case scSMinExpr:
6681     case scSequentialUMinExpr:
6682       WorkList.push_back(Expr);
6683 break;
6684 case scCouldNotCompute:
6685 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6686 }
6687 };
6688 AddToWorklist(S);
6689
6690 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6691 for (unsigned I = 0; I != WorkList.size(); ++I) {
6692 const SCEV *P = WorkList[I];
6693 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6694 // If it is not a `SCEVUnknown`, just recurse into operands.
6695 if (!UnknownS) {
6696 for (const SCEV *Op : P->operands())
6697 AddToWorklist(Op);
6698 continue;
6699 }
6700 // `SCEVUnknown`'s require special treatment.
6701 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6702 if (!PendingPhiRangesIter.insert(P).second)
6703 continue;
6704 for (auto &Op : reverse(P->operands()))
6705 AddToWorklist(getSCEV(Op));
6706 }
6707 }
6708
6709 if (!WorkList.empty()) {
6710 // Use getRangeRef to compute ranges for items in the worklist in reverse
6711 // order. This will force ranges for earlier operands to be computed before
6712 // their users in most cases.
6713 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6714 getRangeRef(P, SignHint);
6715
6716 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6717 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6718 PendingPhiRangesIter.erase(P);
6719 }
6720 }
6721
6722 return getRangeRef(S, SignHint, 0);
6723}
6724
6725/// Determine the range for a particular SCEV. If SignHint is
6726/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6727/// with a "cleaner" unsigned (resp. signed) representation.
6728const ConstantRange &ScalarEvolution::getRangeRef(
6729 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6730 DenseMap<const SCEV *, ConstantRange> &Cache =
6731 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6732 : SignedRanges;
6733   ConstantRange::PreferredRangeType RangeType =
6734       SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6735                                                         : ConstantRange::Signed;
6736
6737 // See if we've computed this range already.
6738   DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6739   if (I != Cache.end())
6740 return I->second;
6741
6742 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6743 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6744
6745 // Switch to iteratively computing the range for S, if it is part of a deeply
6746 // nested expression.
6747   if (Depth > RangeIterThreshold)
6748     return getRangeRefIter(S, SignHint);
6749
6750 unsigned BitWidth = getTypeSizeInBits(S->getType());
6751 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6752 using OBO = OverflowingBinaryOperator;
6753
6754 // If the value has known zeros, the maximum value will have those known zeros
6755 // as well.
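  // For instance (informal), an i8 value known to be a multiple of 4 can be at
  // most 252, so its unsigned range can be tightened to [0, 253).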
6756 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6757 APInt Multiple = getNonZeroConstantMultiple(S);
6758 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6759 if (!Remainder.isZero())
6760 ConservativeResult =
6761 ConstantRange(APInt::getMinValue(BitWidth),
6762 APInt::getMaxValue(BitWidth) - Remainder + 1);
6763 }
6764 else {
6765 uint32_t TZ = getMinTrailingZeros(S);
6766 if (TZ != 0) {
6767 ConservativeResult = ConstantRange(
6768           APInt::getSignedMinValue(BitWidth),
6769           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6770 }
6771 }
6772
6773 switch (S->getSCEVType()) {
6774 case scConstant:
6775 llvm_unreachable("Already handled above.");
6776 case scVScale:
6777 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6778 case scTruncate: {
6779 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6780 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6781 return setRange(
6782 Trunc, SignHint,
6783 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6784 }
6785 case scZeroExtend: {
6786 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6787 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6788 return setRange(
6789 ZExt, SignHint,
6790 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6791 }
6792 case scSignExtend: {
6793 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6794 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6795 return setRange(
6796 SExt, SignHint,
6797 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6798 }
6799 case scPtrToAddr:
6800 case scPtrToInt: {
6801 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6802 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6803 return setRange(Cast, SignHint, X);
6804 }
6805 case scAddExpr: {
6806 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6807 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6808 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6809 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6810 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6811 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6812 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6813 ConservativeResult =
6814 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6815 }
6816 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6817 unsigned WrapType = OBO::AnyWrap;
6818 if (Add->hasNoSignedWrap())
6819 WrapType |= OBO::NoSignedWrap;
6820 if (Add->hasNoUnsignedWrap())
6821 WrapType |= OBO::NoUnsignedWrap;
6822 for (const SCEV *Op : drop_begin(Add->operands()))
6823 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6824 RangeType);
6825 return setRange(Add, SignHint,
6826 ConservativeResult.intersectWith(X, RangeType));
6827 }
6828 case scMulExpr: {
6829 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6830 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6831 for (const SCEV *Op : drop_begin(Mul->operands()))
6832 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6833 return setRange(Mul, SignHint,
6834 ConservativeResult.intersectWith(X, RangeType));
6835 }
6836 case scUDivExpr: {
6837 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6838 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6839 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6840 return setRange(UDiv, SignHint,
6841 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6842 }
6843 case scAddRecExpr: {
6844 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6845 // If there's no unsigned wrap, the value will never be less than its
6846 // initial value.
6847 if (AddRec->hasNoUnsignedWrap()) {
6848 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6849 if (!UnsignedMinValue.isZero())
6850 ConservativeResult = ConservativeResult.intersectWith(
6851 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6852 }
6853
6854     // If there's no signed wrap, and all the operands except the initial value
6855     // have the same sign or are zero, the value won't ever be:
6856     // 1: smaller than the initial value if the operands are non-negative,
6857     // 2: bigger than the initial value if the operands are non-positive.
6858     // In both cases, the value cannot cross the signed min/max boundary.
6859 if (AddRec->hasNoSignedWrap()) {
6860 bool AllNonNeg = true;
6861 bool AllNonPos = true;
6862 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6863 if (!isKnownNonNegative(AddRec->getOperand(i)))
6864 AllNonNeg = false;
6865 if (!isKnownNonPositive(AddRec->getOperand(i)))
6866 AllNonPos = false;
6867 }
6868 if (AllNonNeg)
6869 ConservativeResult = ConservativeResult.intersectWith(
6870             ConstantRange(getSignedRangeMin(AddRec->getStart()),
6871                           APInt::getSignedMinValue(BitWidth)),
6872             RangeType);
6873 else if (AllNonPos)
6874 ConservativeResult = ConservativeResult.intersectWith(
6875             ConstantRange(APInt::getSignedMinValue(BitWidth),
6876                           getSignedRangeMax(AddRec->getStart()) +
6877 1),
6878 RangeType);
6879 }
6880
6881 // TODO: non-affine addrec
6882 if (AddRec->isAffine()) {
6883 const SCEV *MaxBEScev =
6884           getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6885       if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6886 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6887
6888 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6889 // MaxBECount's active bits are all <= AddRec's bit width.
6890 if (MaxBECount.getBitWidth() > BitWidth &&
6891 MaxBECount.getActiveBits() <= BitWidth)
6892 MaxBECount = MaxBECount.trunc(BitWidth);
6893 else if (MaxBECount.getBitWidth() < BitWidth)
6894 MaxBECount = MaxBECount.zext(BitWidth);
6895
6896 if (MaxBECount.getBitWidth() == BitWidth) {
6897 auto RangeFromAffine = getRangeForAffineAR(
6898 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6899 ConservativeResult =
6900 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6901
6902 auto RangeFromFactoring = getRangeViaFactoring(
6903 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6904 ConservativeResult =
6905 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6906 }
6907 }
6908
6909 // Now try symbolic BE count and more powerful methods.
6910       if (UseExpensiveRangeSharpening) {
6911         const SCEV *SymbolicMaxBECount =
6912             getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6913 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6914 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6915 AddRec->hasNoSelfWrap()) {
6916 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6917 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6918 ConservativeResult =
6919 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6920 }
6921 }
6922 }
6923
6924 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6925 }
6926 case scUMaxExpr:
6927 case scSMaxExpr:
6928 case scUMinExpr:
6929 case scSMinExpr:
6930 case scSequentialUMinExpr: {
6931     Intrinsic::ID ID;
6932     switch (S->getSCEVType()) {
6933 case scUMaxExpr:
6934 ID = Intrinsic::umax;
6935 break;
6936 case scSMaxExpr:
6937 ID = Intrinsic::smax;
6938 break;
6939 case scUMinExpr:
6940     case scSequentialUMinExpr:
6941       ID = Intrinsic::umin;
6942 break;
6943 case scSMinExpr:
6944 ID = Intrinsic::smin;
6945 break;
6946 default:
6947 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6948 }
6949
6950 const auto *NAry = cast<SCEVNAryExpr>(S);
6951 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6952 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6953 X = X.intrinsic(
6954 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6955 return setRange(S, SignHint,
6956 ConservativeResult.intersectWith(X, RangeType));
6957 }
6958 case scUnknown: {
6959 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6960 Value *V = U->getValue();
6961
6962 // Check if the IR explicitly contains !range metadata.
6963 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6964 if (MDRange)
6965 ConservativeResult =
6966 ConservativeResult.intersectWith(*MDRange, RangeType);
6967
6968 // Use facts about recurrences in the underlying IR. Note that add
6969 // recurrences are AddRecExprs and thus don't hit this path. This
6970 // primarily handles shift recurrences.
6971 auto CR = getRangeForUnknownRecurrence(U);
6972 ConservativeResult = ConservativeResult.intersectWith(CR);
6973
6974 // See if ValueTracking can give us a useful range.
6975 const DataLayout &DL = getDataLayout();
6976 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6977 if (Known.getBitWidth() != BitWidth)
6978 Known = Known.zextOrTrunc(BitWidth);
6979
6980 // ValueTracking may be able to compute a tighter result for the number of
6981 // sign bits than for the value of those sign bits.
6982 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6983 if (U->getType()->isPointerTy()) {
6984       // If the pointer size is larger than the index type size, this can cause
6985 // NS to be larger than BitWidth. So compensate for this.
6986 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6987 int ptrIdxDiff = ptrSize - BitWidth;
6988 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6989 NS -= ptrIdxDiff;
6990 }
6991
6992 if (NS > 1) {
6993 // If we know any of the sign bits, we know all of the sign bits.
6994 if (!Known.Zero.getHiBits(NS).isZero())
6995 Known.Zero.setHighBits(NS);
6996 if (!Known.One.getHiBits(NS).isZero())
6997 Known.One.setHighBits(NS);
6998 }
6999
7000 if (Known.getMinValue() != Known.getMaxValue() + 1)
7001 ConservativeResult = ConservativeResult.intersectWith(
7002 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7003 RangeType);
7004 if (NS > 1)
7005 ConservativeResult = ConservativeResult.intersectWith(
7006 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7007 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7008 RangeType);
7009
7010 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7011 // Strengthen the range if the underlying IR value is a
7012 // global/alloca/heap allocation using the size of the object.
7013 bool CanBeNull, CanBeFreed;
7014 uint64_t DerefBytes =
7015 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7016 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7017 // The highest address the object can start is DerefBytes bytes before
7018 // the end (unsigned max value). If this value is not a multiple of the
7019 // alignment, the last possible start value is the next lowest multiple
7020 // of the alignment. Note: The computations below cannot overflow,
7021 // because if they would there's no possible start address for the
7022 // object.
7023 APInt MaxVal =
7024 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7025 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7026 uint64_t Rem = MaxVal.urem(Align);
7027 MaxVal -= APInt(BitWidth, Rem);
7028 APInt MinVal = APInt::getZero(BitWidth);
7029 if (llvm::isKnownNonZero(V, DL))
7030 MinVal = Align;
7031 ConservativeResult = ConservativeResult.intersectWith(
7032 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7033 }
7034 }
7035
7036 // A range of Phi is a subset of union of all ranges of its input.
7037 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7038       // Make sure that we do not recurse endlessly over cyclic Phis.
7039 if (PendingPhiRanges.insert(Phi).second) {
7040 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7041
7042 for (const auto &Op : Phi->operands()) {
7043 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7044 RangeFromOps = RangeFromOps.unionWith(OpRange);
7045           // No point in continuing if we already have a full set.
7046 if (RangeFromOps.isFullSet())
7047 break;
7048 }
7049 ConservativeResult =
7050 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7051 bool Erased = PendingPhiRanges.erase(Phi);
7052 assert(Erased && "Failed to erase Phi properly?");
7053 (void)Erased;
7054 }
7055 }
7056
7057 // vscale can't be equal to zero
7058 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7059 if (II->getIntrinsicID() == Intrinsic::vscale) {
7060 ConstantRange Disallowed = APInt::getZero(BitWidth);
7061 ConservativeResult = ConservativeResult.difference(Disallowed);
7062 }
7063
7064 return setRange(U, SignHint, std::move(ConservativeResult));
7065 }
7066 case scCouldNotCompute:
7067 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7068 }
7069
7070 return setRange(S, SignHint, std::move(ConservativeResult));
7071}
7072
7073// Given a StartRange, Step and MaxBECount for an expression compute a range of
7074// values that the expression can take. Initially, the expression has a value
7075// from StartRange and then is changed by Step up to MaxBECount times. Signed
7076// argument defines if we treat Step as signed or unsigned.
7077 static ConstantRange getRangeForAffineARHelper(APInt Step,
7078                                                const ConstantRange &StartRange,
7079 const APInt &MaxBECount,
7080 bool Signed) {
7081 unsigned BitWidth = Step.getBitWidth();
7082 assert(BitWidth == StartRange.getBitWidth() &&
7083 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7084 // If either Step or MaxBECount is 0, then the expression won't change, and we
7085 // just need to return the initial range.
7086 if (Step == 0 || MaxBECount == 0)
7087 return StartRange;
7088
7089 // If we don't know anything about the initial value (i.e. StartRange is
7090 // FullRange), then we don't know anything about the final range either.
7091 // Return FullRange.
7092 if (StartRange.isFullSet())
7093 return ConstantRange::getFull(BitWidth);
7094
7095 // If Step is signed and negative, then we use its absolute value, but we also
7096 // note that we're moving in the opposite direction.
7097 bool Descending = Signed && Step.isNegative();
7098
7099 if (Signed)
7100 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7101 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7102   // This equation holds true due to the well-defined wrap-around behavior of
7103 // APInt.
7104 Step = Step.abs();
7105
7106   // Check if Offset is more than the full span of BitWidth. If it is, the
7107 // expression is guaranteed to overflow.
7108 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7109 return ConstantRange::getFull(BitWidth);
7110
7111 // Offset is by how much the expression can change. Checks above guarantee no
7112 // overflow here.
7113 APInt Offset = Step * MaxBECount;
7114
7115 // Minimum value of the final range will match the minimal value of StartRange
7116 // if the expression is increasing and will be decreased by Offset otherwise.
7117 // Maximum value of the final range will match the maximal value of StartRange
7118 // if the expression is decreasing and will be increased by Offset otherwise.
7119 APInt StartLower = StartRange.getLower();
7120 APInt StartUpper = StartRange.getUpper() - 1;
7121 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7122 : (StartUpper + std::move(Offset));
7123
7124 // It's possible that the new minimum/maximum value will fall into the initial
7125 // range (due to wrap around). This means that the expression can take any
7126 // value in this bitwidth, and we have to return full range.
7127 if (StartRange.contains(MovedBoundary))
7128 return ConstantRange::getFull(BitWidth);
7129
7130 APInt NewLower =
7131 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7132 APInt NewUpper =
7133 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7134 NewUpper += 1;
7135
7136 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
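  // E.g. (informal): StartRange = [10, 21), Step = 3, MaxBECount = 5 gives
  // Offset = 15 and, for an increasing expression, the range [10, 36).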
7137 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7138}
7139
7140ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7141 const SCEV *Step,
7142 const APInt &MaxBECount) {
7143 assert(getTypeSizeInBits(Start->getType()) ==
7144 getTypeSizeInBits(Step->getType()) &&
7145 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7146 "mismatched bit widths");
7147
7148 // First, consider step signed.
7149 ConstantRange StartSRange = getSignedRange(Start);
7150 ConstantRange StepSRange = getSignedRange(Step);
7151
7152 // If Step can be both positive and negative, we need to find ranges for the
7153 // maximum absolute step values in both directions and union them.
7154 ConstantRange SR = getRangeForAffineARHelper(
7155 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7156   SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7157                                               StartSRange, MaxBECount,
7158 /* Signed = */ true));
7159
7160 // Next, consider step unsigned.
7161 ConstantRange UR = getRangeForAffineARHelper(
7162 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7163 /* Signed = */ false);
7164
7165 // Finally, intersect signed and unsigned ranges.
7166   return SR.intersectWith(UR, ConstantRange::Smallest);
7167 }
7168
7169ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7170 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7171 ScalarEvolution::RangeSignHint SignHint) {
7172   assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7173 assert(AddRec->hasNoSelfWrap() &&
7174 "This only works for non-self-wrapping AddRecs!");
7175 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7176 const SCEV *Step = AddRec->getStepRecurrence(*this);
7177 // Only deal with constant step to save compile time.
7178 if (!isa<SCEVConstant>(Step))
7179 return ConstantRange::getFull(BitWidth);
7180 // Let's make sure that we can prove that we do not self-wrap during
7181 // MaxBECount iterations. We need this because MaxBECount is a maximum
7182   // iteration count estimate, and we might infer the nw flag from some exit
7183   // for which we do not know the max exit count (or any other side reasoning).
7184 // TODO: Turn into assert at some point.
7185 if (getTypeSizeInBits(MaxBECount->getType()) >
7186 getTypeSizeInBits(AddRec->getType()))
7187 return ConstantRange::getFull(BitWidth);
7188 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7189 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7190 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7191 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7192 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7193 MaxItersWithoutWrap))
7194 return ConstantRange::getFull(BitWidth);
7195
7196 ICmpInst::Predicate LEPred =
7197       IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7198   ICmpInst::Predicate GEPred =
7199       IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7200 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7201
7202 // We know that there is no self-wrap. Let's take Start and End values and
7203 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7204 // the iteration. They either lie inside the range [Min(Start, End),
7205 // Max(Start, End)] or outside it:
7206 //
7207 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7208 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7209 //
7210 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7211 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7212 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7213 // Start <= End and step is positive, or Start >= End and step is negative.
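  // Illustrative example: for {%a,+,1}<nsw> with no self-wrap and a symbolic
  // max BE count %n, End is (%a + %n); once %a <=s (%a + %n) is proven, every
  // intermediate value lies in [%a, %a + %n] (Case 1), so the union of the
  // Start and End ranges bounds the whole recurrence.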
7214 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7215 ConstantRange StartRange = getRangeRef(Start, SignHint);
7216 ConstantRange EndRange = getRangeRef(End, SignHint);
7217 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7218 // If they already cover full iteration space, we will know nothing useful
7219 // even if we prove what we want to prove.
7220 if (RangeBetween.isFullSet())
7221 return RangeBetween;
7222 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7223 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7224 : RangeBetween.isWrappedSet();
7225 if (IsWrappedSet)
7226 return ConstantRange::getFull(BitWidth);
7227
7228 if (isKnownPositive(Step) &&
7229 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7230 return RangeBetween;
7231 if (isKnownNegative(Step) &&
7232 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7233 return RangeBetween;
7234 return ConstantRange::getFull(BitWidth);
7235}
7236
7237ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7238 const SCEV *Step,
7239 const APInt &MaxBECount) {
7240 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7241 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
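  // E.g. (informal): for {c ? 0 : 10,+,c ? 1 : 2} with MaxBECount = 5, this
  // returns the union of RangeOf({0,+,1}) = [0, 6) and
  // RangeOf({10,+,2}) = [10, 21), i.e. the ConstantRange [0, 21).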
7242
7243 unsigned BitWidth = MaxBECount.getBitWidth();
7244 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7245 getTypeSizeInBits(Step->getType()) == BitWidth &&
7246 "mismatched bit widths");
7247
7248 struct SelectPattern {
7249 Value *Condition = nullptr;
7250 APInt TrueValue;
7251 APInt FalseValue;
7252
7253 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7254 const SCEV *S) {
7255 std::optional<unsigned> CastOp;
7256 APInt Offset(BitWidth, 0);
7257
7258       assert(SE.getTypeSizeInBits(S->getType()) == BitWidth &&
7259              "Should be!");
7260
7261 // Peel off a constant offset. In the future we could consider being
7262 // smarter here and handle {Start+Step,+,Step} too.
7263 const APInt *Off;
7264 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7265 Offset = *Off;
7266
7267 // Peel off a cast operation
7268 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7269 CastOp = SCast->getSCEVType();
7270 S = SCast->getOperand();
7271 }
7272
7273 using namespace llvm::PatternMatch;
7274
7275 auto *SU = dyn_cast<SCEVUnknown>(S);
7276 const APInt *TrueVal, *FalseVal;
7277 if (!SU ||
7278 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7279 m_APInt(FalseVal)))) {
7280 Condition = nullptr;
7281 return;
7282 }
7283
7284 TrueValue = *TrueVal;
7285 FalseValue = *FalseVal;
7286
7287 // Re-apply the cast we peeled off earlier
7288 if (CastOp)
7289 switch (*CastOp) {
7290 default:
7291 llvm_unreachable("Unknown SCEV cast type!");
7292
7293 case scTruncate:
7294 TrueValue = TrueValue.trunc(BitWidth);
7295 FalseValue = FalseValue.trunc(BitWidth);
7296 break;
7297 case scZeroExtend:
7298 TrueValue = TrueValue.zext(BitWidth);
7299 FalseValue = FalseValue.zext(BitWidth);
7300 break;
7301 case scSignExtend:
7302 TrueValue = TrueValue.sext(BitWidth);
7303 FalseValue = FalseValue.sext(BitWidth);
7304 break;
7305 }
7306
7307 // Re-apply the constant offset we peeled off earlier
7308 TrueValue += Offset;
7309 FalseValue += Offset;
7310 }
7311
7312 bool isRecognized() { return Condition != nullptr; }
7313 };
7314
7315 SelectPattern StartPattern(*this, BitWidth, Start);
7316 if (!StartPattern.isRecognized())
7317 return ConstantRange::getFull(BitWidth);
7318
7319 SelectPattern StepPattern(*this, BitWidth, Step);
7320 if (!StepPattern.isRecognized())
7321 return ConstantRange::getFull(BitWidth);
7322
7323 if (StartPattern.Condition != StepPattern.Condition) {
7324 // We don't handle this case today; but we could, by considering four
7325 // possibilities below instead of two. I'm not sure if there are cases where
7326 // that will help over what getRange already does, though.
7327 return ConstantRange::getFull(BitWidth);
7328 }
7329
7330 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7331 // construct arbitrary general SCEV expressions here. This function is called
7332 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7333 // say) can end up caching a suboptimal value.
7334
7335 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7336 // C2352 and C2512 (otherwise it isn't needed).
7337
7338 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7339 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7340 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7341 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7342
7343 ConstantRange TrueRange =
7344 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7345 ConstantRange FalseRange =
7346 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7347
7348 return TrueRange.unionWith(FalseRange);
7349}
7350
7351SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7352 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7353 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7354
7355 // Return early if there are no flags to propagate to the SCEV.
7356   SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7357   if (BinOp->hasNoUnsignedWrap())
7358     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7359   if (BinOp->hasNoSignedWrap())
7360     Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7361 if (Flags == SCEV::FlagAnyWrap)
7362 return SCEV::FlagAnyWrap;
7363
7364 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7365}
7366
7367const Instruction *
7368ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7369 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7370 return &*AddRec->getLoop()->getHeader()->begin();
7371 if (auto *U = dyn_cast<SCEVUnknown>(S))
7372 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7373 return I;
7374 return nullptr;
7375}
7376
7377const Instruction *
7378ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7379 bool &Precise) {
7380 Precise = true;
7381 // Do a bounded search of the def relation of the requested SCEVs.
7382 SmallPtrSet<const SCEV *, 16> Visited;
7383   SmallVector<const SCEV *> Worklist;
7384   auto pushOp = [&](const SCEV *S) {
7385 if (!Visited.insert(S).second)
7386 return;
7387 // Threshold of 30 here is arbitrary.
7388 if (Visited.size() > 30) {
7389 Precise = false;
7390 return;
7391 }
7392 Worklist.push_back(S);
7393 };
7394
7395 for (const auto *S : Ops)
7396 pushOp(S);
7397
7398 const Instruction *Bound = nullptr;
7399 while (!Worklist.empty()) {
7400 auto *S = Worklist.pop_back_val();
7401 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7402 if (!Bound || DT.dominates(Bound, DefI))
7403 Bound = DefI;
7404 } else {
7405 for (const auto *Op : S->operands())
7406 pushOp(Op);
7407 }
7408 }
7409 return Bound ? Bound : &*F.getEntryBlock().begin();
7410}
7411
7412const Instruction *
7413ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7414 bool Discard;
7415 return getDefiningScopeBound(Ops, Discard);
7416}
7417
7418bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7419 const Instruction *B) {
7420 if (A->getParent() == B->getParent() &&
7421       isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7422                                                  B->getIterator()))
7423 return true;
7424
7425 auto *BLoop = LI.getLoopFor(B->getParent());
7426 if (BLoop && BLoop->getHeader() == B->getParent() &&
7427 BLoop->getLoopPreheader() == A->getParent() &&
7428       isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7429                                                  A->getParent()->end()) &&
7430 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7431 B->getIterator()))
7432 return true;
7433 return false;
7434}
7435
7436bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7437 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7438 visitAll(Op, PC);
7439 return PC.MaybePoison.empty();
7440}
7441
7442bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7443 return !SCEVExprContains(Op, [this](const SCEV *S) {
7444 const SCEV *Op1;
7445 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7446 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7447 // is a non-zero constant, we have to assume the UDiv may be UB.
7448 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7449 });
7450}
7451
7452bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7453 // Only proceed if we can prove that I does not yield poison.
7454   if (!programUndefinedIfPoison(I))
7455     return false;
7456
7457 // At this point we know that if I is executed, then it does not wrap
7458 // according to at least one of NSW or NUW. If I is not executed, then we do
7459 // not know if the calculation that I represents would wrap. Multiple
7460 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7461 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7462 // derived from other instructions that map to the same SCEV. We cannot make
7463   // that guarantee for cases where I is not executed. So we need to find an
7464 // upper bound on the defining scope for the SCEV, and prove that I is
7465 // executed every time we enter that scope. When the bounding scope is a
7466 // loop (the common case), this is equivalent to proving I executes on every
7467 // iteration of that loop.
7468   SmallVector<const SCEV *, 4> SCEVOps;
7469   for (const Use &Op : I->operands()) {
7470 // I could be an extractvalue from a call to an overflow intrinsic.
7471 // TODO: We can do better here in some cases.
7472 if (isSCEVable(Op->getType()))
7473 SCEVOps.push_back(getSCEV(Op));
7474 }
7475 auto *DefI = getDefiningScopeBound(SCEVOps);
7476 return isGuaranteedToTransferExecutionTo(DefI, I);
7477}
7478
7479bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7480 // If we know that \c I can never be poison period, then that's enough.
7481 if (isSCEVExprNeverPoison(I))
7482 return true;
7483
7484 // If the loop only has one exit, then we know that, if the loop is entered,
7485 // any instruction dominating that exit will be executed. If any such
7486 // instruction would result in UB, the addrec cannot be poison.
7487 //
7488 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7489 // also handles uses outside the loop header (they just need to dominate the
7490 // single exit).
7491
7492 auto *ExitingBB = L->getExitingBlock();
7493 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7494 return false;
7495
7496 SmallPtrSet<const Value *, 16> KnownPoison;
7497   SmallVector<const Instruction *, 8> Worklist;
7498
7499 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7500 // things that are known to be poison under that assumption go on the
7501 // Worklist.
7502 KnownPoison.insert(I);
7503 Worklist.push_back(I);
7504
7505 while (!Worklist.empty()) {
7506 const Instruction *Poison = Worklist.pop_back_val();
7507
7508 for (const Use &U : Poison->uses()) {
7509 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7510 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7511 DT.dominates(PoisonUser->getParent(), ExitingBB))
7512 return true;
7513
7514 if (propagatesPoison(U) && L->contains(PoisonUser))
7515 if (KnownPoison.insert(PoisonUser).second)
7516 Worklist.push_back(PoisonUser);
7517 }
7518 }
7519
7520 return false;
7521}
7522
7523ScalarEvolution::LoopProperties
7524ScalarEvolution::getLoopProperties(const Loop *L) {
7525 using LoopProperties = ScalarEvolution::LoopProperties;
7526
7527 auto Itr = LoopPropertiesCache.find(L);
7528 if (Itr == LoopPropertiesCache.end()) {
7529 auto HasSideEffects = [](Instruction *I) {
7530 if (auto *SI = dyn_cast<StoreInst>(I))
7531 return !SI->isSimple();
7532
7533 if (I->mayThrow())
7534 return true;
7535
7536       // Non-volatile memset / memcpy do not count as side effects for forward
7537 // progress.
7538 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7539 return false;
7540
7541 return I->mayWriteToMemory();
7542 };
7543
7544 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7545 /*HasNoSideEffects*/ true};
7546
7547 for (auto *BB : L->getBlocks())
7548 for (auto &I : *BB) {
7549         if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7550           LP.HasNoAbnormalExits = false;
7551 if (HasSideEffects(&I))
7552 LP.HasNoSideEffects = false;
7553 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7554 break; // We're already as pessimistic as we can get.
7555 }
7556
7557 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7558 assert(InsertPair.second && "We just checked!");
7559 Itr = InsertPair.first;
7560 }
7561
7562 return Itr->second;
7563}
7564
7565 bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7566   // A mustprogress loop without side effects must be finite.
7567 // TODO: The check used here is very conservative. It's only *specific*
7568 // side effects which are well defined in infinite loops.
7569 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7570}
7571
7572const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7573 // Worklist item with a Value and a bool indicating whether all operands have
7574 // been visited already.
7575   using PointerTy = PointerIntPair<Value *, 1, bool>;
7576   SmallVector<PointerTy> Stack;
7577
7578 Stack.emplace_back(V, true);
7579 Stack.emplace_back(V, false);
7580 while (!Stack.empty()) {
7581 auto E = Stack.pop_back_val();
7582 Value *CurV = E.getPointer();
7583
7584 if (getExistingSCEV(CurV))
7585 continue;
7586
7587     SmallVector<Value *> Ops;
7588     const SCEV *CreatedSCEV = nullptr;
7589 // If all operands have been visited already, create the SCEV.
7590 if (E.getInt()) {
7591 CreatedSCEV = createSCEV(CurV);
7592 } else {
7593       // Otherwise get the operands we need to create SCEVs for before creating
7594 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7595 // just use it.
7596 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7597 }
7598
7599 if (CreatedSCEV) {
7600 insertValueToMap(CurV, CreatedSCEV);
7601 } else {
7602       // Queue CurV for SCEV creation, followed by its operands, which need to
7603 // be constructed first.
7604 Stack.emplace_back(CurV, true);
7605 for (Value *Op : Ops)
7606 Stack.emplace_back(Op, false);
7607 }
7608 }
7609
7610 return getExistingSCEV(V);
7611}
7612
7613const SCEV *
7614ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7615 if (!isSCEVable(V->getType()))
7616 return getUnknown(V);
7617
7618 if (Instruction *I = dyn_cast<Instruction>(V)) {
7619 // Don't attempt to analyze instructions in blocks that aren't
7620 // reachable. Such instructions don't matter, and they aren't required
7621 // to obey basic rules for definitions dominating uses which this
7622 // analysis depends on.
7623 if (!DT.isReachableFromEntry(I->getParent()))
7624 return getUnknown(PoisonValue::get(V->getType()));
7625 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7626 return getConstant(CI);
7627 else if (isa<GlobalAlias>(V))
7628 return getUnknown(V);
7629 else if (!isa<ConstantExpr>(V))
7630 return getUnknown(V);
7631
7632   Operator *U = cast<Operator>(V);
7633   if (auto BO =
7634           MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7635 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7636 switch (BO->Opcode) {
7637 case Instruction::Add:
7638 case Instruction::Mul: {
7639 // For additions and multiplications, traverse add/mul chains for which we
7640 // can potentially create a single SCEV, to reduce the number of
7641 // get{Add,Mul}Expr calls.
7642 do {
7643 if (BO->Op) {
7644 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7645 Ops.push_back(BO->Op);
7646 break;
7647 }
7648 }
7649 Ops.push_back(BO->RHS);
7650 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7651                                  dyn_cast<Instruction>(V));
7652       if (!NewBO ||
7653 (BO->Opcode == Instruction::Add &&
7654 (NewBO->Opcode != Instruction::Add &&
7655 NewBO->Opcode != Instruction::Sub)) ||
7656 (BO->Opcode == Instruction::Mul &&
7657 NewBO->Opcode != Instruction::Mul)) {
7658 Ops.push_back(BO->LHS);
7659 break;
7660 }
7661 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7662 // requires a SCEV for the LHS.
7663 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7664 auto *I = dyn_cast<Instruction>(BO->Op);
7665 if (I && programUndefinedIfPoison(I)) {
7666 Ops.push_back(BO->LHS);
7667 break;
7668 }
7669 }
7670 BO = NewBO;
7671 } while (true);
7672 return nullptr;
7673 }
7674 case Instruction::Sub:
7675 case Instruction::UDiv:
7676 case Instruction::URem:
7677 break;
7678 case Instruction::AShr:
7679 case Instruction::Shl:
7680 case Instruction::Xor:
7681 if (!IsConstArg)
7682 return nullptr;
7683 break;
7684 case Instruction::And:
7685 case Instruction::Or:
7686 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7687 return nullptr;
7688 break;
7689 case Instruction::LShr:
7690 return getUnknown(V);
7691 default:
7692 llvm_unreachable("Unhandled binop");
7693 break;
7694 }
7695
7696 Ops.push_back(BO->LHS);
7697 Ops.push_back(BO->RHS);
7698 return nullptr;
7699 }
7700
7701 switch (U->getOpcode()) {
7702 case Instruction::Trunc:
7703 case Instruction::ZExt:
7704 case Instruction::SExt:
7705 case Instruction::PtrToAddr:
7706 case Instruction::PtrToInt:
7707 Ops.push_back(U->getOperand(0));
7708 return nullptr;
7709
7710 case Instruction::BitCast:
7711 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7712 Ops.push_back(U->getOperand(0));
7713 return nullptr;
7714 }
7715 return getUnknown(V);
7716
7717 case Instruction::SDiv:
7718 case Instruction::SRem:
7719 Ops.push_back(U->getOperand(0));
7720 Ops.push_back(U->getOperand(1));
7721 return nullptr;
7722
7723 case Instruction::GetElementPtr:
7724 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7725 "GEP source element type must be sized");
7726 llvm::append_range(Ops, U->operands());
7727 return nullptr;
7728
7729 case Instruction::IntToPtr:
7730 return getUnknown(V);
7731
7732 case Instruction::PHI:
7733     // Keep constructing SCEVs for phis recursively for now.
7734 return nullptr;
7735
7736 case Instruction::Select: {
7737 // Check if U is a select that can be simplified to a SCEVUnknown.
7738 auto CanSimplifyToUnknown = [this, U]() {
7739 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7740 return false;
7741
7742 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7743 if (!ICI)
7744 return false;
7745 Value *LHS = ICI->getOperand(0);
7746 Value *RHS = ICI->getOperand(1);
7747 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7748 ICI->getPredicate() == CmpInst::ICMP_NE) {
7749         if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7750           return true;
7751 } else if (getTypeSizeInBits(LHS->getType()) >
7752 getTypeSizeInBits(U->getType()))
7753 return true;
7754 return false;
7755 };
7756 if (CanSimplifyToUnknown())
7757 return getUnknown(U);
7758
7759 llvm::append_range(Ops, U->operands());
7760 return nullptr;
7761 break;
7762 }
7763 case Instruction::Call:
7764 case Instruction::Invoke:
7765 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7766 Ops.push_back(RV);
7767 return nullptr;
7768 }
7769
7770 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7771 switch (II->getIntrinsicID()) {
7772 case Intrinsic::abs:
7773 Ops.push_back(II->getArgOperand(0));
7774 return nullptr;
7775 case Intrinsic::umax:
7776 case Intrinsic::umin:
7777 case Intrinsic::smax:
7778 case Intrinsic::smin:
7779 case Intrinsic::usub_sat:
7780 case Intrinsic::uadd_sat:
7781 Ops.push_back(II->getArgOperand(0));
7782 Ops.push_back(II->getArgOperand(1));
7783 return nullptr;
7784 case Intrinsic::start_loop_iterations:
7785 case Intrinsic::annotation:
7786 case Intrinsic::ptr_annotation:
7787 Ops.push_back(II->getArgOperand(0));
7788 return nullptr;
7789 default:
7790 break;
7791 }
7792 }
7793 break;
7794 }
7795
7796 return nullptr;
7797}
7798
7799const SCEV *ScalarEvolution::createSCEV(Value *V) {
7800 if (!isSCEVable(V->getType()))
7801 return getUnknown(V);
7802
7803 if (Instruction *I = dyn_cast<Instruction>(V)) {
7804 // Don't attempt to analyze instructions in blocks that aren't
7805 // reachable. Such instructions don't matter, and they aren't required
7806 // to obey basic rules for definitions dominating uses which this
7807 // analysis depends on.
7808 if (!DT.isReachableFromEntry(I->getParent()))
7809 return getUnknown(PoisonValue::get(V->getType()));
7810 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7811 return getConstant(CI);
7812 else if (isa<GlobalAlias>(V))
7813 return getUnknown(V);
7814 else if (!isa<ConstantExpr>(V))
7815 return getUnknown(V);
7816
7817 const SCEV *LHS;
7818 const SCEV *RHS;
7819
7820   Operator *U = cast<Operator>(V);
7821   if (auto BO =
7822           MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7823 switch (BO->Opcode) {
7824 case Instruction::Add: {
7825 // The simple thing to do would be to just call getSCEV on both operands
7826 // and call getAddExpr with the result. However if we're looking at a
7827 // bunch of things all added together, this can be quite inefficient,
7828 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7829 // Instead, gather up all the operands and make a single getAddExpr call.
7830 // LLVM IR canonical form means we need only traverse the left operands.
7831       SmallVector<const SCEV *, 4> AddOps;
7832       do {
7833 if (BO->Op) {
7834 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7835 AddOps.push_back(OpSCEV);
7836 break;
7837 }
7838
7839 // If a NUW or NSW flag can be applied to the SCEV for this
7840 // addition, then compute the SCEV for this addition by itself
7841 // with a separate call to getAddExpr. We need to do that
7842 // instead of pushing the operands of the addition onto AddOps,
7843 // since the flags are only known to apply to this particular
7844 // addition - they may not apply to other additions that can be
7845 // formed with operands from AddOps.
7846 const SCEV *RHS = getSCEV(BO->RHS);
7847 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7848 if (Flags != SCEV::FlagAnyWrap) {
7849 const SCEV *LHS = getSCEV(BO->LHS);
7850 if (BO->Opcode == Instruction::Sub)
7851 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7852 else
7853 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7854 break;
7855 }
7856 }
7857
7858 if (BO->Opcode == Instruction::Sub)
7859 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7860 else
7861 AddOps.push_back(getSCEV(BO->RHS));
7862
7863 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7864                                    dyn_cast<Instruction>(V));
7865         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7866 NewBO->Opcode != Instruction::Sub)) {
7867 AddOps.push_back(getSCEV(BO->LHS));
7868 break;
7869 }
7870 BO = NewBO;
7871 } while (true);
7872
7873 return getAddExpr(AddOps);
7874 }
7875
7876 case Instruction::Mul: {
7877       SmallVector<const SCEV *, 4> MulOps;
7878       do {
7879 if (BO->Op) {
7880 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7881 MulOps.push_back(OpSCEV);
7882 break;
7883 }
7884
7885 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7886 if (Flags != SCEV::FlagAnyWrap) {
7887 LHS = getSCEV(BO->LHS);
7888 RHS = getSCEV(BO->RHS);
7889 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7890 break;
7891 }
7892 }
7893
7894 MulOps.push_back(getSCEV(BO->RHS));
7895 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7896                                    dyn_cast<Instruction>(V));
7897         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7898 MulOps.push_back(getSCEV(BO->LHS));
7899 break;
7900 }
7901 BO = NewBO;
7902 } while (true);
7903
7904 return getMulExpr(MulOps);
7905 }
7906 case Instruction::UDiv:
7907 LHS = getSCEV(BO->LHS);
7908 RHS = getSCEV(BO->RHS);
7909 return getUDivExpr(LHS, RHS);
7910 case Instruction::URem:
7911 LHS = getSCEV(BO->LHS);
7912 RHS = getSCEV(BO->RHS);
7913 return getURemExpr(LHS, RHS);
7914 case Instruction::Sub: {
7915       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7916       if (BO->Op)
7917 Flags = getNoWrapFlagsFromUB(BO->Op);
7918 LHS = getSCEV(BO->LHS);
7919 RHS = getSCEV(BO->RHS);
7920 return getMinusSCEV(LHS, RHS, Flags);
7921 }
7922 case Instruction::And:
7923 // For an expression like x&255 that merely masks off the high bits,
7924 // use zext(trunc(x)) as the SCEV expression.
7925 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7926 if (CI->isZero())
7927 return getSCEV(BO->RHS);
7928 if (CI->isMinusOne())
7929 return getSCEV(BO->LHS);
7930 const APInt &A = CI->getValue();
7931
7932 // Instcombine's ShrinkDemandedConstant may strip bits out of
7933 // constants, obscuring what would otherwise be a low-bits mask.
7934 // Use computeKnownBits to compute what ShrinkDemandedConstant
7935 // knew about to reconstruct a low-bits mask value.
7936 unsigned LZ = A.countl_zero();
7937 unsigned TZ = A.countr_zero();
7938 unsigned BitWidth = A.getBitWidth();
7939 KnownBits Known(BitWidth);
7940 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7941
7942 APInt EffectiveMask =
7943 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7944 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7945 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7946 const SCEV *LHS = getSCEV(BO->LHS);
7947 const SCEV *ShiftedLHS = nullptr;
7948 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7949 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7950 // For an expression like (x * 8) & 8, simplify the multiply.
7951 unsigned MulZeros = OpC->getAPInt().countr_zero();
7952 unsigned GCD = std::min(MulZeros, TZ);
7953 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7954             SmallVector<const SCEV *, 4> MulOps;
7955             MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7956 append_range(MulOps, LHSMul->operands().drop_front());
7957 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7958 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7959 }
7960 }
7961 if (!ShiftedLHS)
7962 ShiftedLHS = getUDivExpr(LHS, MulCount);
7963 return getMulExpr(
7964 getZeroExtendExpr(
7965 getTruncateExpr(ShiftedLHS,
7966 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7967 BO->LHS->getType()),
7968 MulCount);
7969 }
7970 }
7971 // Binary `and` is a bit-wise `umin`.
7972 if (BO->LHS->getType()->isIntegerTy(1)) {
7973 LHS = getSCEV(BO->LHS);
7974 RHS = getSCEV(BO->RHS);
7975 return getUMinExpr(LHS, RHS);
7976 }
7977 break;
7978
7979 case Instruction::Or:
7980 // Binary `or` is a bit-wise `umax`.
7981 if (BO->LHS->getType()->isIntegerTy(1)) {
7982 LHS = getSCEV(BO->LHS);
7983 RHS = getSCEV(BO->RHS);
7984 return getUMaxExpr(LHS, RHS);
7985 }
7986 break;
7987
7988 case Instruction::Xor:
7989 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7990 // If the RHS of xor is -1, then this is a not operation.
7991 if (CI->isMinusOne())
7992 return getNotSCEV(getSCEV(BO->LHS));
7993
7994 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7995 // This is a variant of the check for xor with -1, and it handles
7996 // the case where instcombine has trimmed non-demanded bits out
7997 // of an xor with -1.
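// For example, with a low-bits mask C == 0x0f, xor(and(x, 0x0f), 0x0f) equals
// and(~x, 0x0f), which the code below recognizes as zext(not(trunc(x))) when
// the operand's SCEV is a zero-extend.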
7998 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7999 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8000 if (LBO->getOpcode() == Instruction::And &&
8001 LCI->getValue() == CI->getValue())
8002 if (const SCEVZeroExtendExpr *Z =
8003 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
8004 Type *UTy = BO->LHS->getType();
8005 const SCEV *Z0 = Z->getOperand();
8006 Type *Z0Ty = Z0->getType();
8007 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8008
8009 // If C is a low-bits mask, the zero extend is serving to
8010 // mask off the high bits. Complement the operand and
8011 // re-apply the zext.
8012 if (CI->getValue().isMask(Z0TySize))
8013 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8014
8015 // If C is a single bit, it may be in the sign-bit position
8016 // before the zero-extend. In this case, represent the xor
8017 // using an add, which is equivalent, and re-apply the zext.
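// (A single bit C at the narrow type's sign-bit position satisfies
// x ^ C == x + C (mod 2^N), since adding the sign bit simply flips it.)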
8018 APInt Trunc = CI->getValue().trunc(Z0TySize);
8019 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8020 Trunc.isSignMask())
8021 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8022 UTy);
8023 }
8024 }
8025 break;
8026
8027 case Instruction::Shl:
8028 // Turn shift left of a constant amount into a multiply.
8029 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8030 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8031
8032 // If the shift count is not less than the bitwidth, the result of
8033 // the shift is undefined. Don't try to analyze it, because the
8034 // resolution chosen here may differ from the resolution chosen in
8035 // other parts of the compiler.
8036 if (SA->getValue().uge(BitWidth))
8037 break;
8038
8039 // We can safely preserve the nuw flag in all cases. It's also safe to
8040 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8041 // requires special handling. It can be preserved as long as we're not
8042 // left shifting by bitwidth - 1.
8043 auto Flags = SCEV::FlagAnyWrap;
8044 if (BO->Op) {
8045 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8046 if ((MulFlags & SCEV::FlagNSW) &&
8047 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8048 Flags = setFlags(Flags, SCEV::FlagNSW);
8049 if (MulFlags & SCEV::FlagNUW)
8050 Flags = setFlags(Flags, SCEV::FlagNUW);
8051 }
8052
8053 ConstantInt *X = ConstantInt::get(
8054 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8055 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8056 }
8057 break;
8058
8059 case Instruction::AShr:
8060 // AShr X, C, where C is a constant.
8061 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8062 if (!CI)
8063 break;
8064
8065 Type *OuterTy = BO->LHS->getType();
8066 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8067 // If the shift count is not less than the bitwidth, the result of
8068 // the shift is undefined. Don't try to analyze it, because the
8069 // resolution chosen here may differ from the resolution chosen in
8070 // other parts of the compiler.
8071 if (CI->getValue().uge(BitWidth))
8072 break;
8073
8074 if (CI->isZero())
8075 return getSCEV(BO->LHS); // shift by zero --> noop
8076
8077 uint64_t AShrAmt = CI->getZExtValue();
8078 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8079
8080 Operator *L = dyn_cast<Operator>(BO->LHS);
8081 const SCEV *AddTruncateExpr = nullptr;
8082 ConstantInt *ShlAmtCI = nullptr;
8083 const SCEV *AddConstant = nullptr;
8084
8085 if (L && L->getOpcode() == Instruction::Add) {
8086 // X = Shl A, n
8087 // Y = Add X, c
8088 // Z = AShr Y, m
8089 // n, c and m are constants.
8090
8091 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8092 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8093 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8094 if (AddOperandCI) {
8095 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8096 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8097 // Since we truncate to TruncTy, the AddConstant should be of the
8098 // same type, so create a new constant with the same type as TruncTy.
8099 // Also, the Add constant should be shifted right by the AShr amount.
8100 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8101 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8102 // We model the expression as sext(add(trunc(A), c << n)); since the
8103 // sext(trunc) part is already handled below, we create an
8104 // AddExpr(TruncExp) which will be used later.
8105 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8106 }
8107 }
8108 } else if (L && L->getOpcode() == Instruction::Shl) {
8109 // X = Shl A, n
8110 // Y = AShr X, m
8111 // Both n and m are constant.
8112
8113 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8114 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8115 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8116 }
8117
8118 if (AddTruncateExpr && ShlAmtCI) {
8119 // We can merge the two given cases into a single SCEV statement; in
8120 // case n = m, the mul expression becomes 2^0, so it resolves to a
8121 // simpler case. The following code handles the two cases:
8122 //
8123 // 1) For a two-shift sext-inreg, i.e. n = m,
8124 // use sext(trunc(x)) as the SCEV expression.
8125 //
8126 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
8127 // expression. We already checked that ShlAmt < BitWidth, so
8128 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8129 // ShlAmt - AShrAmt < BitWidth - AShrAmt.
8130 const APInt &ShlAmt = ShlAmtCI->getValue();
8131 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8132 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8133 ShlAmtCI->getZExtValue() - AShrAmt);
8134 const SCEV *CompositeExpr =
8135 getMulExpr(AddTruncateExpr, getConstant(Mul));
8136 if (L->getOpcode() != Instruction::Shl)
8137 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8138
8139 return getSignExtendExpr(CompositeExpr, OuterTy);
8140 }
8141 }
8142 break;
8143 }
8144 }
8145
8146 switch (U->getOpcode()) {
8147 case Instruction::Trunc:
8148 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8149
8150 case Instruction::ZExt:
8151 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8152
8153 case Instruction::SExt:
8154 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8155 dyn_cast<Instruction>(V))) {
8156 // The NSW flag of a subtract does not always survive the conversion to
8157 // A + (-1)*B. By pushing sign extension onto its operands we are much
8158 // more likely to preserve NSW and allow later AddRec optimisations.
8159 //
8160 // NOTE: This is effectively duplicating this logic from getSignExtend:
8161 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8162 // but by that point the NSW information has potentially been lost.
8163 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8164 Type *Ty = U->getType();
8165 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8166 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8167 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8168 }
8169 }
8170 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8171
8172 case Instruction::BitCast:
8173 // BitCasts are no-op casts so we just eliminate the cast.
8174 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8175 return getSCEV(U->getOperand(0));
8176 break;
8177
8178 case Instruction::PtrToAddr:
8179 return getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8180
8181 case Instruction::PtrToInt: {
8182 // A pointer-to-integer cast is straightforward, so we do model it.
8183 const SCEV *Op = getSCEV(U->getOperand(0));
8184 Type *DstIntTy = U->getType();
8185 // But only if effective SCEV (integer) type is wide enough to represent
8186 // all possible pointer values.
8187 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8188 if (isa<SCEVCouldNotCompute>(IntOp))
8189 return getUnknown(V);
8190 return IntOp;
8191 }
8192 case Instruction::IntToPtr:
8193 // Just don't deal with inttoptr casts.
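// (There is no lossless way to model the resulting pointer in SCEV, so it
// is treated as an opaque unknown.)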
8194 return getUnknown(V);
8195
8196 case Instruction::SDiv:
8197 // If both operands are non-negative, this is just an udiv.
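// (For non-negative operands signed and unsigned division agree, and SCEV
// only has a udiv expression.)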
8198 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8199 isKnownNonNegative(getSCEV(U->getOperand(1))))
8200 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8201 break;
8202
8203 case Instruction::SRem:
8204 // If both operands are non-negative, this is just an urem.
8205 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8206 isKnownNonNegative(getSCEV(U->getOperand(1))))
8207 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8208 break;
8209
8210 case Instruction::GetElementPtr:
8211 return createNodeForGEP(cast<GEPOperator>(U));
8212
8213 case Instruction::PHI:
8214 return createNodeForPHI(cast<PHINode>(U));
8215
8216 case Instruction::Select:
8217 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8218 U->getOperand(2));
8219
8220 case Instruction::Call:
8221 case Instruction::Invoke:
8222 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8223 return getSCEV(RV);
8224
8225 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8226 switch (II->getIntrinsicID()) {
8227 case Intrinsic::abs:
8228 return getAbsExpr(
8229 getSCEV(II->getArgOperand(0)),
8230 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8231 case Intrinsic::umax:
8232 LHS = getSCEV(II->getArgOperand(0));
8233 RHS = getSCEV(II->getArgOperand(1));
8234 return getUMaxExpr(LHS, RHS);
8235 case Intrinsic::umin:
8236 LHS = getSCEV(II->getArgOperand(0));
8237 RHS = getSCEV(II->getArgOperand(1));
8238 return getUMinExpr(LHS, RHS);
8239 case Intrinsic::smax:
8240 LHS = getSCEV(II->getArgOperand(0));
8241 RHS = getSCEV(II->getArgOperand(1));
8242 return getSMaxExpr(LHS, RHS);
8243 case Intrinsic::smin:
8244 LHS = getSCEV(II->getArgOperand(0));
8245 RHS = getSCEV(II->getArgOperand(1));
8246 return getSMinExpr(LHS, RHS);
8247 case Intrinsic::usub_sat: {
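// usub.sat(X, Y) == X - umin(X, Y): the plain difference when Y <= X,
// otherwise clamped to 0; the subtraction can never underflow, hence NUW.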
8248 const SCEV *X = getSCEV(II->getArgOperand(0));
8249 const SCEV *Y = getSCEV(II->getArgOperand(1));
8250 const SCEV *ClampedY = getUMinExpr(X, Y);
8251 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8252 }
8253 case Intrinsic::uadd_sat: {
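// uadd.sat(X, Y) == umin(X, ~Y) + Y: clamping X to ~Y (UINT_MAX - Y) makes
// the addition overflow-free (NUW) while still saturating at UINT_MAX.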
8254 const SCEV *X = getSCEV(II->getArgOperand(0));
8255 const SCEV *Y = getSCEV(II->getArgOperand(1));
8256 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8257 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8258 }
8259 case Intrinsic::start_loop_iterations:
8260 case Intrinsic::annotation:
8261 case Intrinsic::ptr_annotation:
8262 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8263 // just equivalent to the first operand for SCEV purposes.
8264 return getSCEV(II->getArgOperand(0));
8265 case Intrinsic::vscale:
8266 return getVScale(II->getType());
8267 default:
8268 break;
8269 }
8270 }
8271 break;
8272 }
8273
8274 return getUnknown(V);
8275}
8276
8277//===----------------------------------------------------------------------===//
8278// Iteration Count Computation Code
8279//
8280
8282 if (isa<SCEVCouldNotCompute>(ExitCount))
8283 return getCouldNotCompute();
8284
8285 auto *ExitCountType = ExitCount->getType();
8286 assert(ExitCountType->isIntegerTy());
8287 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8288 1 + ExitCountType->getScalarSizeInBits());
8289 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8290}
8291
8292const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8293 Type *EvalTy,
8294 const Loop *L) {
8295 if (isa<SCEVCouldNotCompute>(ExitCount))
8296 return getCouldNotCompute();
8297
8298 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8299 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8300
8301 auto CanAddOneWithoutOverflow = [&]() {
8302 ConstantRange ExitCountRange =
8303 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8304 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8305 return true;
8306
8307 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8308 getMinusOne(ExitCount->getType()));
8309 };
8310
8311 // If we need to zero extend the backedge count, check if we can add one to
8312 // it prior to zero extending without overflow. Provided this is safe, it
8313 // allows better simplification of the +1.
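// In other words, prefer zext(ExitCount + 1) over zext(ExitCount) + 1
// whenever the +1 provably cannot wrap in the narrow type.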
8314 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8315 return getZeroExtendExpr(
8316 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8317
8318 // Get the total trip count from the count by adding 1. This may wrap.
8319 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8320}
8321
8322static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8323 if (!ExitCount)
8324 return 0;
8325
8326 ConstantInt *ExitConst = ExitCount->getValue();
8327
8328 // Guard against huge trip counts.
8329 if (ExitConst->getValue().getActiveBits() > 32)
8330 return 0;
8331
8332 // In case of integer overflow, this returns 0, which is correct.
8333 return ((unsigned)ExitConst->getZExtValue()) + 1;
8334}
8335
8336unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8337 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8338 return getConstantTripCount(ExitCount);
8339}
8340
8341unsigned
8342ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8343 const BasicBlock *ExitingBlock) {
8344 assert(ExitingBlock && "Must pass a non-null exiting block!");
8345 assert(L->isLoopExiting(ExitingBlock) &&
8346 "Exiting block must actually branch out of the loop!");
8347 const SCEVConstant *ExitCount =
8348 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8349 return getConstantTripCount(ExitCount);
8350}
8351
8352unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8353 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8354
8355 const auto *MaxExitCount =
8356 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8357 : getConstantMaxBackedgeTakenCount(L);
8358 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8359}
8360
8361unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8362 SmallVector<BasicBlock *, 8> ExitingBlocks;
8363 L->getExitingBlocks(ExitingBlocks);
8364
8365 std::optional<unsigned> Res;
8366 for (auto *ExitingBB : ExitingBlocks) {
8367 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8368 if (!Res)
8369 Res = Multiple;
8370 Res = std::gcd(*Res, Multiple);
8371 }
8372 return Res.value_or(1);
8373}
8374
8375unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8376 const SCEV *ExitCount) {
8377 if (isa<SCEVCouldNotCompute>(ExitCount))
8378 return 1;
8379
8380 // Get the trip count
8381 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8382
8383 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8384 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8385 // the greatest power of 2 divisor less than 2^32.
8386 return Multiple.getActiveBits() > 32
8387 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8388 : (unsigned)Multiple.getZExtValue();
8389}
8390
8391/// Returns the largest constant divisor of the trip count of this loop as a
8392/// normal unsigned value, if possible. This means that the actual trip count is
8393/// always a multiple of the returned value (don't forget the trip count could
8394/// very well be zero as well!).
8395///
8396/// Returns 1 if the trip count is unknown or not guaranteed to be the
8397 /// multiple of a constant (which is also the case if the trip count is simply
8398 /// constant; use getSmallConstantTripCount for that case). It will also return
8399 /// 1 if the trip count is very large (>= 2^32).
8400///
8401/// As explained in the comments for getSmallConstantTripCount, this assumes
8402/// that control exits the loop via ExitingBlock.
8403unsigned
8404ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8405 const BasicBlock *ExitingBlock) {
8406 assert(ExitingBlock && "Must pass a non-null exiting block!");
8407 assert(L->isLoopExiting(ExitingBlock) &&
8408 "Exiting block must actually branch out of the loop!");
8409 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8410 return getSmallConstantTripMultiple(L, ExitCount);
8411}
8412
8413const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8414 const BasicBlock *ExitingBlock,
8415 ExitCountKind Kind) {
8416 switch (Kind) {
8417 case Exact:
8418 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8419 case SymbolicMaximum:
8420 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8421 case ConstantMaximum:
8422 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8423 };
8424 llvm_unreachable("Invalid ExitCountKind!");
8425}
8426
8427const SCEV *ScalarEvolution::getPredicatedExitCount(
8428 const Loop *L, const BasicBlock *ExitingBlock,
8429 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8430 switch (Kind) {
8431 case Exact:
8432 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8433 Predicates);
8434 case SymbolicMaximum:
8435 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8436 Predicates);
8437 case ConstantMaximum:
8438 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8439 Predicates);
8440 };
8441 llvm_unreachable("Invalid ExitCountKind!");
8442}
8443
8444const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8445 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8446 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8447}
8448
8449const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8450 ExitCountKind Kind) {
8451 switch (Kind) {
8452 case Exact:
8453 return getBackedgeTakenInfo(L).getExact(L, this);
8454 case ConstantMaximum:
8455 return getBackedgeTakenInfo(L).getConstantMax(this);
8456 case SymbolicMaximum:
8457 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8458 };
8459 llvm_unreachable("Invalid ExitCountKind!");
8460}
8461
8462const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8463 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8464 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8465}
8466
8467const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8468 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8469 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8470}
8471
8472bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8473 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8474}
8475
8476/// Push PHI nodes in the header of the given loop onto the given Worklist.
8477static void PushLoopPHIs(const Loop *L,
8478 SmallVectorImpl<Instruction *> &Worklist,
8479 SmallPtrSetImpl<Instruction *> &Visited) {
8480 BasicBlock *Header = L->getHeader();
8481
8482 // Push all Loop-header PHIs onto the Worklist stack.
8483 for (PHINode &PN : Header->phis())
8484 if (Visited.insert(&PN).second)
8485 Worklist.push_back(&PN);
8486}
8487
8488ScalarEvolution::BackedgeTakenInfo &
8489ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8490 auto &BTI = getBackedgeTakenInfo(L);
8491 if (BTI.hasFullInfo())
8492 return BTI;
8493
8494 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8495
8496 if (!Pair.second)
8497 return Pair.first->second;
8498
8499 BackedgeTakenInfo Result =
8500 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8501
8502 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8503}
8504
8505ScalarEvolution::BackedgeTakenInfo &
8506ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8507 // Initially insert an invalid entry for this loop. If the insertion
8508 // succeeds, proceed to actually compute a backedge-taken count and
8509 // update the value. The temporary CouldNotCompute value tells SCEV
8510 // code elsewhere that it shouldn't attempt to request a new
8511 // backedge-taken count, which could result in infinite recursion.
8512 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8513 BackedgeTakenCounts.try_emplace(L);
8514 if (!Pair.second)
8515 return Pair.first->second;
8516
8517 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8518 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8519 // must be cleared in this scope.
8520 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8521
8522 // Now that we know more about the trip count for this loop, forget any
8523 // existing SCEV values for PHI nodes in this loop since they are only
8524 // conservative estimates made without the benefit of trip count
8525 // information. This invalidation is not necessary for correctness, and is
8526 // only done to produce more precise results.
8527 if (Result.hasAnyInfo()) {
8528 // Invalidate any expression using an addrec in this loop.
8529 SmallVector<const SCEV *, 8> ToForget;
8530 auto LoopUsersIt = LoopUsers.find(L);
8531 if (LoopUsersIt != LoopUsers.end())
8532 append_range(ToForget, LoopUsersIt->second);
8533 forgetMemoizedResults(ToForget);
8534
8535 // Invalidate constant-evolved loop header phis.
8536 for (PHINode &PN : L->getHeader()->phis())
8537 ConstantEvolutionLoopExitValue.erase(&PN);
8538 }
8539
8540 // Re-lookup the insert position, since the call to
8541 // computeBackedgeTakenCount above could result in a
8542 // recursive call to getBackedgeTakenInfo (on a different
8543 // loop), which would invalidate the iterator computed
8544 // earlier.
8545 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8546}
8547
8548void ScalarEvolution::forgetAllLoops() {
8549 // This method is intended to forget all info about loops. It should
8550 // invalidate caches as if the following happened:
8551 // - The trip counts of all loops have changed arbitrarily
8552 // - Every llvm::Value has been updated in place to produce a different
8553 // result.
8554 BackedgeTakenCounts.clear();
8555 PredicatedBackedgeTakenCounts.clear();
8556 BECountUsers.clear();
8557 LoopPropertiesCache.clear();
8558 ConstantEvolutionLoopExitValue.clear();
8559 ValueExprMap.clear();
8560 ValuesAtScopes.clear();
8561 ValuesAtScopesUsers.clear();
8562 LoopDispositions.clear();
8563 BlockDispositions.clear();
8564 UnsignedRanges.clear();
8565 SignedRanges.clear();
8566 ExprValueMap.clear();
8567 HasRecMap.clear();
8568 ConstantMultipleCache.clear();
8569 PredicatedSCEVRewrites.clear();
8570 FoldCache.clear();
8571 FoldCacheUser.clear();
8572}
8573void ScalarEvolution::visitAndClearUsers(
8574 SmallVectorImpl<Instruction *> &Worklist,
8575 SmallPtrSetImpl<Instruction *> &Visited,
8576 SmallVectorImpl<const SCEV *> &ToForget) {
8577 while (!Worklist.empty()) {
8578 Instruction *I = Worklist.pop_back_val();
8579 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8580 continue;
8581
8582 ValueExprMapType::iterator It =
8583 ValueExprMap.find_as(static_cast<Value *>(I));
8584 if (It != ValueExprMap.end()) {
8585 eraseValueFromMap(It->first);
8586 ToForget.push_back(It->second);
8587 if (PHINode *PN = dyn_cast<PHINode>(I))
8588 ConstantEvolutionLoopExitValue.erase(PN);
8589 }
8590
8591 PushDefUseChildren(I, Worklist, Visited);
8592 }
8593}
8594
8595void ScalarEvolution::forgetLoop(const Loop *L) {
8596 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8597 SmallVector<Instruction *, 32> Worklist;
8598 SmallPtrSet<Instruction *, 16> Visited;
8599 SmallVector<const SCEV *, 16> ToForget;
8600
8601 // Iterate over all the loops and sub-loops to drop SCEV information.
8602 while (!LoopWorklist.empty()) {
8603 auto *CurrL = LoopWorklist.pop_back_val();
8604
8605 // Drop any stored trip count value.
8606 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8607 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8608
8609 // Drop information about predicated SCEV rewrites for this loop.
8610 for (auto I = PredicatedSCEVRewrites.begin();
8611 I != PredicatedSCEVRewrites.end();) {
8612 std::pair<const SCEV *, const Loop *> Entry = I->first;
8613 if (Entry.second == CurrL)
8614 PredicatedSCEVRewrites.erase(I++);
8615 else
8616 ++I;
8617 }
8618
8619 auto LoopUsersItr = LoopUsers.find(CurrL);
8620 if (LoopUsersItr != LoopUsers.end())
8621 llvm::append_range(ToForget, LoopUsersItr->second);
8622
8623 // Drop information about expressions based on loop-header PHIs.
8624 PushLoopPHIs(CurrL, Worklist, Visited);
8625 visitAndClearUsers(Worklist, Visited, ToForget);
8626
8627 LoopPropertiesCache.erase(CurrL);
8628 // Forget all contained loops too, to avoid dangling entries in the
8629 // ValuesAtScopes map.
8630 LoopWorklist.append(CurrL->begin(), CurrL->end());
8631 }
8632 forgetMemoizedResults(ToForget);
8633}
8634
8635void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8636 forgetLoop(L->getOutermostLoop());
8637}
8638
8639void ScalarEvolution::forgetValue(Value *V) {
8640 Instruction *I = dyn_cast<Instruction>(V);
8641 if (!I) return;
8642
8643 // Drop information about expressions based on loop-header PHIs.
8644 SmallVector<Instruction *, 16> Worklist;
8645 SmallPtrSet<Instruction *, 8> Visited;
8646 SmallVector<const SCEV *, 8> ToForget;
8647 Worklist.push_back(I);
8648 Visited.insert(I);
8649 visitAndClearUsers(Worklist, Visited, ToForget);
8650
8651 forgetMemoizedResults(ToForget);
8652}
8653
8654void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8655 if (!isSCEVable(V->getType()))
8656 return;
8657
8658 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8659 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8660 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8661 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8662 if (const SCEV *S = getExistingSCEV(V)) {
8663 struct InvalidationRootCollector {
8664 Loop *L;
8665 SmallVector<const SCEV *, 8> Roots;
8666
8667 InvalidationRootCollector(Loop *L) : L(L) {}
8668
8669 bool follow(const SCEV *S) {
8670 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8671 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8672 if (L->contains(I))
8673 Roots.push_back(S);
8674 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8675 if (L->contains(AddRec->getLoop()))
8676 Roots.push_back(S);
8677 }
8678 return true;
8679 }
8680 bool isDone() const { return false; }
8681 };
8682
8683 InvalidationRootCollector C(L);
8684 visitAll(S, C);
8685 forgetMemoizedResults(C.Roots);
8686 }
8687
8688 // Also perform the normal invalidation.
8689 forgetValue(V);
8690}
8691
8692void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8693
8694void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8695 // Unless a specific value is passed to invalidation, completely clear both
8696 // caches.
8697 if (!V) {
8698 BlockDispositions.clear();
8699 LoopDispositions.clear();
8700 return;
8701 }
8702
8703 if (!isSCEVable(V->getType()))
8704 return;
8705
8706 const SCEV *S = getExistingSCEV(V);
8707 if (!S)
8708 return;
8709
8710 // Invalidate the block and loop dispositions cached for S. Dispositions of
8711 // S's users may change if S's disposition changes (i.e. a user may change to
8712 // loop-invariant, if S changes to loop invariant), so also invalidate
8713 // dispositions of S's users recursively.
8714 SmallVector<const SCEV *, 8> Worklist = {S};
8715 SmallPtrSet<const SCEV *, 8> Seen = {S};
8716 while (!Worklist.empty()) {
8717 const SCEV *Curr = Worklist.pop_back_val();
8718 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8719 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8720 if (!LoopDispoRemoved && !BlockDispoRemoved)
8721 continue;
8722 auto Users = SCEVUsers.find(Curr);
8723 if (Users != SCEVUsers.end())
8724 for (const auto *User : Users->second)
8725 if (Seen.insert(User).second)
8726 Worklist.push_back(User);
8727 }
8728}
8729
8730/// Get the exact loop backedge taken count considering all loop exits. A
8731/// computable result can only be returned for loops with all exiting blocks
8732/// dominating the latch. howFarToZero assumes that the limit of each loop test
8733/// is never skipped. This is a valid assumption as long as the loop exits via
8734/// that test. For precise results, it is the caller's responsibility to specify
8735/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8736const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8737 const Loop *L, ScalarEvolution *SE,
8738 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8739 // If any exits were not computable, the loop is not computable.
8740 if (!isComplete() || ExitNotTaken.empty())
8741 return SE->getCouldNotCompute();
8742
8743 const BasicBlock *Latch = L->getLoopLatch();
8744 // All exiting blocks we have collected must dominate the only backedge.
8745 if (!Latch)
8746 return SE->getCouldNotCompute();
8747
8748 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8749 // count is simply a minimum out of all these calculated exit counts.
8750 SmallVector<const SCEV *, 2> Ops;
8751 for (const auto &ENT : ExitNotTaken) {
8752 const SCEV *BECount = ENT.ExactNotTaken;
8753 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8754 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8755 "We should only have known counts for exiting blocks that dominate "
8756 "latch!");
8757
8758 Ops.push_back(BECount);
8759
8760 if (Preds)
8761 append_range(*Preds, ENT.Predicates);
8762
8763 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8764 "Predicate should be always true!");
8765 }
8766
8767 // If an earlier exit exits on the first iteration (exit count zero), then
8768 // a later poison exit count should not propagate into the result. These are
8769 // exactly the semantics provided by umin_seq.
8770 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8771}
8772
8773const ScalarEvolution::ExitNotTakenInfo *
8774ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8775 const BasicBlock *ExitingBlock,
8776 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8777 for (const auto &ENT : ExitNotTaken)
8778 if (ENT.ExitingBlock == ExitingBlock) {
8779 if (ENT.hasAlwaysTruePredicate())
8780 return &ENT;
8781 else if (Predicates) {
8782 append_range(*Predicates, ENT.Predicates);
8783 return &ENT;
8784 }
8785 }
8786
8787 return nullptr;
8788}
8789
8790/// getConstantMax - Get the constant max backedge taken count for the loop.
8791const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8792 ScalarEvolution *SE,
8793 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8794 if (!getConstantMax())
8795 return SE->getCouldNotCompute();
8796
8797 for (const auto &ENT : ExitNotTaken)
8798 if (!ENT.hasAlwaysTruePredicate()) {
8799 if (!Predicates)
8800 return SE->getCouldNotCompute();
8801 append_range(*Predicates, ENT.Predicates);
8802 }
8803
8804 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8805 isa<SCEVConstant>(getConstantMax())) &&
8806 "No point in having a non-constant max backedge taken count!");
8807 return getConstantMax();
8808}
8809
8810const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8811 const Loop *L, ScalarEvolution *SE,
8812 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8813 if (!SymbolicMax) {
8814 // Form an expression for the maximum exit count possible for this loop. We
8815 // merge the max and exact information to approximate a version of
8816 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8817 // constants.
8818 SmallVector<const SCEV *, 4> ExitCounts;
8819
8820 for (const auto &ENT : ExitNotTaken) {
8821 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8822 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8823 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8824 "We should only have known counts for exiting blocks that "
8825 "dominate latch!");
8826 ExitCounts.push_back(ExitCount);
8827 if (Predicates)
8828 append_range(*Predicates, ENT.Predicates);
8829
8830 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8831 "Predicate should be always true!");
8832 }
8833 }
8834 if (ExitCounts.empty())
8835 SymbolicMax = SE->getCouldNotCompute();
8836 else
8837 SymbolicMax =
8838 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8839 }
8840 return SymbolicMax;
8841}
8842
8843bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8844 ScalarEvolution *SE) const {
8845 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8846 return !ENT.hasAlwaysTruePredicate();
8847 };
8848 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8849}
8850
8851ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
8852 : ExitLimit(E, E, E, false, {}) {}
8853
8854ScalarEvolution::ExitLimit::ExitLimit(
8855 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8856 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8857 ArrayRef<ArrayRef<const SCEVPredicate *>> PredLists)
8858 : ExactNotTaken(E), ConstantMaxNotTaken(ConstantMaxNotTaken),
8859 SymbolicMaxNotTaken(SymbolicMaxNotTaken), MaxOrZero(MaxOrZero) {
8860 // If we prove the max count is zero, so is the symbolic bound. This happens
8861 // in practice due to differences in a) how context sensitive we've chosen
8862 // to be and b) how we reason about bounds implied by UB.
8863 if (ConstantMaxNotTaken->isZero()) {
8864 this->ExactNotTaken = E = ConstantMaxNotTaken;
8865 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8866 }
8867
8868 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8869 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8870 "Exact is not allowed to be less precise than Constant Max");
8871 assert((isa<SCEVCouldNotCompute>(ExactNotTaken) ||
8872 !isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken)) &&
8873 "Exact is not allowed to be less precise than Symbolic Max");
8874 assert((isa<SCEVCouldNotCompute>(SymbolicMaxNotTaken) ||
8875 !isa<SCEVCouldNotCompute>(ConstantMaxNotTaken)) &&
8876 "Symbolic Max is not allowed to be less precise than Constant Max");
8877 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8878 isa<SCEVConstant>(ConstantMaxNotTaken)) &&
8879 "No point in having a non-constant max backedge taken count!");
8880 SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8881 for (const auto PredList : PredLists)
8882 for (const auto *P : PredList) {
8883 if (SeenPreds.contains(P))
8884 continue;
8885 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8886 SeenPreds.insert(P);
8887 Predicates.push_back(P);
8888 }
8889 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8890 "Backedge count should be int");
8891 assert((isa<SCEVCouldNotCompute>(ConstantMaxNotTaken) ||
8892 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8893 "Max backedge count should be int");
8894}
8895
8896ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E,
8897 const SCEV *ConstantMaxNotTaken,
8898 const SCEV *SymbolicMaxNotTaken,
8899 bool MaxOrZero,
8900 ArrayRef<const SCEVPredicate *> PredList)
8901 : ExitLimit(E, ConstantMaxNotTaken, SymbolicMaxNotTaken, MaxOrZero,
8902 ArrayRef({PredList})) {}
8903
8904/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8905/// computable exit into a persistent ExitNotTakenInfo array.
8906ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8907 ArrayRef<EdgeExitInfo> ExitCounts,
8908 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8909 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8910 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8911
8912 ExitNotTaken.reserve(ExitCounts.size());
8913 std::transform(ExitCounts.begin(), ExitCounts.end(),
8914 std::back_inserter(ExitNotTaken),
8915 [&](const EdgeExitInfo &EEI) {
8916 BasicBlock *ExitBB = EEI.first;
8917 const ExitLimit &EL = EEI.second;
8918 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8919 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8920 EL.Predicates);
8921 });
8922 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8923 isa<SCEVConstant>(ConstantMax)) &&
8924 "No point in having a non-constant max backedge taken count!");
8925}
8926
8927/// Compute the number of times the backedge of the specified loop will execute.
8928ScalarEvolution::BackedgeTakenInfo
8929ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8930 bool AllowPredicates) {
8931 SmallVector<BasicBlock *, 8> ExitingBlocks;
8932 L->getExitingBlocks(ExitingBlocks);
8933
8934 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8935
8936 SmallVector<EdgeExitInfo, 4> ExitCounts;
8937 bool CouldComputeBECount = true;
8938 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8939 const SCEV *MustExitMaxBECount = nullptr;
8940 const SCEV *MayExitMaxBECount = nullptr;
8941 bool MustExitMaxOrZero = false;
8942 bool IsOnlyExit = ExitingBlocks.size() == 1;
8943
8944 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8945 // and compute maxBECount.
8946 // Do a union of all the predicates here.
8947 for (BasicBlock *ExitBB : ExitingBlocks) {
8948 // We canonicalize untaken exits to br (constant) and ignore them so that
8949 // proving an exit untaken doesn't negatively impact our ability to reason
8950 // about the loop as a whole.
8951 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8952 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8953 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8954 if (ExitIfTrue == CI->isZero())
8955 continue;
8956 }
8957
8958 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8959
8960 assert((AllowPredicates || EL.Predicates.empty()) &&
8961 "Predicated exit limit when predicates are not allowed!");
8962
8963 // 1. For each exit that can be computed, add an entry to ExitCounts.
8964 // CouldComputeBECount is true only if all exits can be computed.
8965 if (EL.ExactNotTaken != getCouldNotCompute())
8966 ++NumExitCountsComputed;
8967 else
8968 // We couldn't compute an exact value for this exit, so
8969 // we won't be able to compute an exact value for the loop.
8970 CouldComputeBECount = false;
8971 // Remember exit count if either exact or symbolic is known. Because
8972 // Exact always implies symbolic, only check symbolic.
8973 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8974 ExitCounts.emplace_back(ExitBB, EL);
8975 else {
8976 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8977 "Exact is known but symbolic isn't?");
8978 ++NumExitCountsNotComputed;
8979 }
8980
8981 // 2. Derive the loop's MaxBECount from each exit's max number of
8982 // non-exiting iterations. Partition the loop exits into two kinds:
8983 // LoopMustExits and LoopMayExits.
8984 //
8985 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8986 // is a LoopMayExit. If any computable LoopMustExit is found, then
8987 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8988 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8989 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8990 // any
8991 // computable EL.ConstantMaxNotTaken.
8992 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8993 DT.dominates(ExitBB, Latch)) {
8994 if (!MustExitMaxBECount) {
8995 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8996 MustExitMaxOrZero = EL.MaxOrZero;
8997 } else {
8998 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8999 EL.ConstantMaxNotTaken);
9000 }
9001 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9002 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9003 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9004 else {
9005 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9006 EL.ConstantMaxNotTaken);
9007 }
9008 }
9009 }
9010 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9011 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9012 // The loop backedge will be taken the maximum or zero times if there's
9013 // a single exit that must be taken the maximum or zero times.
9014 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9015
9016 // Remember which SCEVs are used in exit limits for invalidation purposes.
9017 // We only care about non-constant SCEVs here, so we can ignore
9018 // EL.ConstantMaxNotTaken
9019 // and MaxBECount, which must be SCEVConstant.
9020 for (const auto &Pair : ExitCounts) {
9021 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9022 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9023 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9024 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9025 {L, AllowPredicates});
9026 }
9027 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9028 MaxBECount, MaxOrZero);
9029}
9030
9031ScalarEvolution::ExitLimit
9032ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9033 bool IsOnlyExit, bool AllowPredicates) {
9034 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9035 // If our exiting block does not dominate the latch, then its connection with
9036 // loop's exit limit may be far from trivial.
9037 const BasicBlock *Latch = L->getLoopLatch();
9038 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9039 return getCouldNotCompute();
9040
9041 Instruction *Term = ExitingBlock->getTerminator();
9042 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
9043 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9044 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9045 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9046 "It should have one successor in loop and one exit block!");
9047 // Proceed to the next level to examine the exit condition expression.
9048 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9049 /*ControlsOnlyExit=*/IsOnlyExit,
9050 AllowPredicates);
9051 }
9052
9053 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9054 // For switch, make sure that there is a single exit from the loop.
9055 BasicBlock *Exit = nullptr;
9056 for (auto *SBB : successors(ExitingBlock))
9057 if (!L->contains(SBB)) {
9058 if (Exit) // Multiple exit successors.
9059 return getCouldNotCompute();
9060 Exit = SBB;
9061 }
9062 assert(Exit && "Exiting block must have at least one exit");
9063 return computeExitLimitFromSingleExitSwitch(
9064 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9065 }
9066
9067 return getCouldNotCompute();
9068}
9069
9070ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9071 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9072 bool AllowPredicates) {
9073 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9074 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9075 ControlsOnlyExit, AllowPredicates);
9076}
9077
9078std::optional<ScalarEvolution::ExitLimit>
9079ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9080 bool ExitIfTrue, bool ControlsOnlyExit,
9081 bool AllowPredicates) {
9082 (void)this->L;
9083 (void)this->ExitIfTrue;
9084 (void)this->AllowPredicates;
9085
9086 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9087 this->AllowPredicates == AllowPredicates &&
9088 "Variance in assumed invariant key components!");
9089 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9090 if (Itr == TripCountMap.end())
9091 return std::nullopt;
9092 return Itr->second;
9093}
9094
9095void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9096 bool ExitIfTrue,
9097 bool ControlsOnlyExit,
9098 bool AllowPredicates,
9099 const ExitLimit &EL) {
9100 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9101 this->AllowPredicates == AllowPredicates &&
9102 "Variance in assumed invariant key components!");
9103
9104 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9105 assert(InsertResult.second && "Expected successful insertion!");
9106 (void)InsertResult;
9107 (void)ExitIfTrue;
9108}
9109
9110ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9111 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9112 bool ControlsOnlyExit, bool AllowPredicates) {
9113
9114 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9115 AllowPredicates))
9116 return *MaybeEL;
9117
9118 ExitLimit EL = computeExitLimitFromCondImpl(
9119 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9120 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9121 return EL;
9122}
9123
9124ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9125 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9126 bool ControlsOnlyExit, bool AllowPredicates) {
9127 // Handle BinOp conditions (And, Or).
9128 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9129 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9130 return *LimitFromBinOp;
9131
9132 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9133 // Proceed to the next level to examine the icmp.
9134 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9135 ExitLimit EL =
9136 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9137 if (EL.hasFullInfo() || !AllowPredicates)
9138 return EL;
9139
9140 // Try again, but use SCEV predicates this time.
9141 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9142 ControlsOnlyExit,
9143 /*AllowPredicates=*/true);
9144 }
9145
9146 // Check for a constant condition. These are normally stripped out by
9147 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9148 // preserve the CFG and is temporarily leaving constant conditions
9149 // in place.
9150 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9151 if (ExitIfTrue == !CI->getZExtValue())
9152 // The backedge is always taken.
9153 return getCouldNotCompute();
9154 // The backedge is never taken.
9155 return getZero(CI->getType());
9156 }
9157
9158 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9159 // with a constant step, we can form an equivalent icmp predicate and figure
9160 // out how many iterations will be taken before we exit.
9161 const WithOverflowInst *WO;
9162 const APInt *C;
9163 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9164 match(WO->getRHS(), m_APInt(C))) {
9165 ConstantRange NWR =
9166 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9167 WO->getNoWrapKind());
9168 CmpInst::Predicate Pred;
9169 APInt NewRHSC, Offset;
9170 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9171 if (!ExitIfTrue)
9172 Pred = ICmpInst::getInversePredicate(Pred);
9173 auto *LHS = getSCEV(WO->getLHS());
9174 if (Offset != 0)
9175 LHS = getAddExpr(LHS, getConstant(Offset));
9176 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9177 ControlsOnlyExit, AllowPredicates);
9178 if (EL.hasAnyInfo())
9179 return EL;
9180 }
9181
9182 // If it's not an integer or pointer comparison then compute it the hard way.
9183 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9184}
9185
9186std::optional<ScalarEvolution::ExitLimit>
9187ScalarEvolution::computeExitLimitFromCondFromBinOp(
9188 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9189 bool ControlsOnlyExit, bool AllowPredicates) {
9190 // Check if the controlling expression for this loop is an And or Or.
9191 Value *Op0, *Op1;
9192 bool IsAnd = false;
9193 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9194 IsAnd = true;
9195 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9196 IsAnd = false;
9197 else
9198 return std::nullopt;
9199
9200 // EitherMayExit is true in these two cases:
9201 // br (and Op0 Op1), loop, exit
9202 // br (or Op0 Op1), exit, loop
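// In both forms a single condition suffices to leave the loop, so the loop
// keeps running only while every condition says "continue" and the exit
// count is the minimum of the individual exit counts (computed below).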
9203 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9204 ExitLimit EL0 = computeExitLimitFromCondCached(
9205 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9206 AllowPredicates);
9207 ExitLimit EL1 = computeExitLimitFromCondCached(
9208 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9209 AllowPredicates);
9210
9211 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9212 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9213 if (isa<ConstantInt>(Op1))
9214 return Op1 == NeutralElement ? EL0 : EL1;
9215 if (isa<ConstantInt>(Op0))
9216 return Op0 == NeutralElement ? EL1 : EL0;
9217
9218 const SCEV *BECount = getCouldNotCompute();
9219 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9220 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9221 if (EitherMayExit) {
9222 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9223 // The loop keeps executing only while neither condition causes an exit.
9224 // Choose the less conservative count.
9225 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9226 EL1.ExactNotTaken != getCouldNotCompute()) {
9227 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9228 UseSequentialUMin);
9229 }
9230 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9231 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9232 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9233 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9234 else
9235 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9236 EL1.ConstantMaxNotTaken);
9237 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9238 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9239 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9240 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9241 else
9242 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9243 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9244 } else {
9245 // The loop exits only when both conditions take their exit value at the
9246 // same time. For now, be conservative.
9246 // For now, be conservative.
9247 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9248 BECount = EL0.ExactNotTaken;
9249 }
9250
9251 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9252 // to be more aggressive when computing BECount than when computing
9253 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9254 // and
9255 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9256 // EL1.ConstantMaxNotTaken to not.
9257 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9258 !isa<SCEVCouldNotCompute>(BECount))
9259 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9260 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9261 SymbolicMaxBECount =
9262 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9263 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9264 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9265}
9266
9267ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9268 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9269 bool AllowPredicates) {
9270 // If the condition was exit on true, convert the condition to exit on false
9271 CmpPredicate Pred;
9272 if (!ExitIfTrue)
9273 Pred = ExitCond->getCmpPredicate();
9274 else
9275 Pred = ExitCond->getInverseCmpPredicate();
9276 const ICmpInst::Predicate OriginalPred = Pred;
9277
9278 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9279 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9280
9281 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9282 AllowPredicates);
9283 if (EL.hasAnyInfo())
9284 return EL;
9285
9286 auto *ExhaustiveCount =
9287 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9288
9289 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9290 return ExhaustiveCount;
9291
9292 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9293 ExitCond->getOperand(1), L, OriginalPred);
9294}
9295ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9296 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9297 bool ControlsOnlyExit, bool AllowPredicates) {
9298
9299 // Try to evaluate any dependencies out of the loop.
9300 LHS = getSCEVAtScope(LHS, L);
9301 RHS = getSCEVAtScope(RHS, L);
9302
9303 // At this point, we would like to compute how many iterations of the
9304 // loop the predicate will return true for these inputs.
9305 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9306 // If there is a loop-invariant, force it into the RHS.
9307 std::swap(LHS, RHS);
9308 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9309 }
9310
9311 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9312 loopIsFiniteByAssumption(L);
9313 // Simplify the operands before analyzing them.
9314 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9315
9316 // If we have a comparison of a chrec against a constant, try to use value
9317 // ranges to answer this query.
9318 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9319 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9320 if (AddRec->getLoop() == L) {
9321 // Form the constant range.
9322 ConstantRange CompRange =
9323 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9324
9325 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9326 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9327 }
9328
9329 // If this loop must exit based on this condition (or execute undefined
9330 // behaviour), see if we can improve wrap flags. This is essentially
9331 // a must execute style proof.
9332 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9333 // If we can prove the test sequence produced must repeat the same values
9334 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9335 // because if it did, we'd have an infinite (undefined) loop.
9336 // TODO: We can peel off any functions which are invertible *in L*. Loop
9337 // invariant terms are effectively constants for our purposes here.
9338 auto *InnerLHS = LHS;
9339 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9340 InnerLHS = ZExt->getOperand();
9341 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9342 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9343 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9344 /*OrNegative=*/true)) {
9345 auto Flags = AR->getNoWrapFlags();
9346 Flags = setFlags(Flags, SCEV::FlagNW);
9347 SmallVector<const SCEV *> Operands{AR->operands()};
9348 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9349 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9350 }
9351
9352 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9353 // From no-self-wrap, this follows trivially from the fact that every
9354 // (un)signed-wrapped, but not self-wrapped value must be less than the
9355 // last value before the (un)signed wrap. Since we know that last value
9356 // didn't cause an exit, no smaller value will either.
9357 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9358 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9359 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9360 AR && AR->getLoop() == L && AR->isAffine() &&
9361 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9362 isKnownPositive(AR->getStepRecurrence(*this))) {
9363 auto Flags = AR->getNoWrapFlags();
9364 Flags = setFlags(Flags, WrapType);
9365 SmallVector<const SCEV*> Operands{AR->operands()};
9366 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9367 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9368 }
9369 }
9370 }
9371
9372 switch (Pred) {
9373 case ICmpInst::ICMP_NE: { // while (X != Y)
9374 // Convert to: while (X-Y != 0)
9375 if (LHS->getType()->isPointerTy()) {
9376 LHS = getLosslessPtrToIntExpr(LHS);
9377 if (isa<SCEVCouldNotCompute>(LHS))
9378 return LHS;
9379 }
9380 if (RHS->getType()->isPointerTy()) {
9381 RHS = getLosslessPtrToIntExpr(RHS);
9382 if (isa<SCEVCouldNotCompute>(RHS))
9383 return RHS;
9384 }
9385 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9386 AllowPredicates);
9387 if (EL.hasAnyInfo())
9388 return EL;
9389 break;
9390 }
9391 case ICmpInst::ICMP_EQ: { // while (X == Y)
9392 // Convert to: while (X-Y == 0)
9393 if (LHS->getType()->isPointerTy()) {
9394 LHS = getLosslessPtrToIntExpr(LHS);
9395 if (isa<SCEVCouldNotCompute>(LHS))
9396 return LHS;
9397 }
9398 if (RHS->getType()->isPointerTy()) {
9399 RHS = getLosslessPtrToIntExpr(RHS);
9400 if (isa<SCEVCouldNotCompute>(RHS))
9401 return RHS;
9402 }
9403 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9404 if (EL.hasAnyInfo()) return EL;
9405 break;
9406 }
9407 case ICmpInst::ICMP_SLE:
9408 case ICmpInst::ICMP_ULE:
9409 // Since the loop is finite, an invariant RHS cannot include the boundary
9410 // value, otherwise it would loop forever.
9411 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9412 !isLoopInvariant(RHS, L)) {
9413 // Otherwise, perform the addition in a wider type, to avoid overflow.
9414 // If the LHS is an addrec with the appropriate nowrap flag, the
9415 // extension will be sunk into it and the exit count can be analyzed.
9416 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9417 if (!OldType)
9418 break;
9419 // Prefer doubling the bitwidth over adding a single bit to make it more
9420 // likely that we use a legal type.
9421 auto *NewType =
9422 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9423 if (ICmpInst::isSigned(Pred)) {
9424 LHS = getSignExtendExpr(LHS, NewType);
9425 RHS = getSignExtendExpr(RHS, NewType);
9426 } else {
9427 LHS = getZeroExtendExpr(LHS, NewType);
9428 RHS = getZeroExtendExpr(RHS, NewType);
9429 }
9430 }
9431 RHS = getAddExpr(getOne(RHS->getType()), RHS);
9432 [[fallthrough]];
9433 case ICmpInst::ICMP_SLT:
9434 case ICmpInst::ICMP_ULT: { // while (X < Y)
9435 bool IsSigned = ICmpInst::isSigned(Pred);
9436 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9437 AllowPredicates);
9438 if (EL.hasAnyInfo())
9439 return EL;
9440 break;
9441 }
9442 case ICmpInst::ICMP_SGE:
9443 case ICmpInst::ICMP_UGE:
9444 // Since the loop is finite, an invariant RHS cannot include the boundary
9445 // value, otherwise it would loop forever.
9446 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9447 !isLoopInvariant(RHS, L))
9448 break;
9449 RHS = getAddExpr(getMinusOne(RHS->getType()), RHS);
9450 [[fallthrough]];
9451 case ICmpInst::ICMP_SGT:
9452 case ICmpInst::ICMP_UGT: { // while (X > Y)
9453 bool IsSigned = ICmpInst::isSigned(Pred);
9454 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9455 AllowPredicates);
9456 if (EL.hasAnyInfo())
9457 return EL;
9458 break;
9459 }
9460 default:
9461 break;
9462 }
9463
9464 return getCouldNotCompute();
9465}
9466
9467ScalarEvolution::ExitLimit
9468ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9469 SwitchInst *Switch,
9470 BasicBlock *ExitingBlock,
9471 bool ControlsOnlyExit) {
9472 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9473
9474 // Give up if the exit is the default dest of a switch.
9475 if (Switch->getDefaultDest() == ExitingBlock)
9476 return getCouldNotCompute();
9477
9478 assert(L->contains(Switch->getDefaultDest()) &&
9479 "Default case must not exit the loop!");
9480 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9481 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9482
9483 // while (X != Y) --> while (X-Y != 0)
9484 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9485 if (EL.hasAnyInfo())
9486 return EL;
9487
9488 return getCouldNotCompute();
9489}
9490
9491static ConstantInt *
9492EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9493 ScalarEvolution &SE) {
9494 const SCEV *InVal = SE.getConstant(C);
9495 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9496 assert(isa<SCEVConstant>(Val) &&
9497 "Evaluation of SCEV at constant didn't fold correctly?");
9498 return cast<SCEVConstant>(Val)->getValue();
9499}
9500
9501ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9502 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9503 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9504 if (!RHS)
9505 return getCouldNotCompute();
9506
9507 const BasicBlock *Latch = L->getLoopLatch();
9508 if (!Latch)
9509 return getCouldNotCompute();
9510
9511 const BasicBlock *Predecessor = L->getLoopPredecessor();
9512 if (!Predecessor)
9513 return getCouldNotCompute();
9514
9515 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9516 // Return LHS in OutLHS and shift_op in OutOpCode.
9517 auto MatchPositiveShift =
9518 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9519
9520 using namespace PatternMatch;
9521
9522 ConstantInt *ShiftAmt;
9523 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9524 OutOpCode = Instruction::LShr;
9525 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9526 OutOpCode = Instruction::AShr;
9527 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9528 OutOpCode = Instruction::Shl;
9529 else
9530 return false;
9531
9532 return ShiftAmt->getValue().isStrictlyPositive();
9533 };
9534
9535 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9536 //
9537 // loop:
9538 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9539 // %iv.shifted = lshr i32 %iv, <positive constant>
9540 //
9541 // Return true on a successful match. Return the corresponding PHI node (%iv
9542 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9543 auto MatchShiftRecurrence =
9544 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9545 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9546
9547 {
9548 Instruction::BinaryOps OpC;
9549 Value *V;
9550
9551 // If we encounter a shift instruction, "peel off" the shift operation,
9552 // and remember that we did so. Later when we inspect %iv's backedge
9553 // value, we will make sure that the backedge value uses the same
9554 // operation.
9555 //
9556 // Note: the peeled shift operation does not have to be the same
9557 // instruction as the one feeding into the PHI's backedge value. We only
9558 // really care about it being the same *kind* of shift instruction --
9559 // that's all that is required for our later inferences to hold.
9560 if (MatchPositiveShift(LHS, V, OpC)) {
9561 PostShiftOpCode = OpC;
9562 LHS = V;
9563 }
9564 }
9565
9566 PNOut = dyn_cast<PHINode>(LHS);
9567 if (!PNOut || PNOut->getParent() != L->getHeader())
9568 return false;
9569
9570 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9571 Value *OpLHS;
9572
9573 return
9574 // The backedge value for the PHI node must be a shift by a positive
9575 // amount
9576 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9577
9578 // of the PHI node itself
9579 OpLHS == PNOut &&
9580
9581 // and the kind of shift should match the kind of shift we peeled
9582 // off, if any.
9583 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9584 };
9585
9586 PHINode *PN;
9587 Instruction::BinaryOps OpCode;
9588 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9589 return getCouldNotCompute();
9590
9591 const DataLayout &DL = getDataLayout();
9592
9593 // The key rationale for this optimization is that for some kinds of shift
9594 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9595 // within a finite number of iterations. If the condition guarding the
9596 // backedge (in the sense that the backedge is taken if the condition is true)
9597 // is false for the value the shift recurrence stabilizes to, then we know
9598 // that the backedge is taken only a finite number of times.
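// For example, the lshr recurrence {8,lshr,1} takes the values 8, 4, 2, 1,
// 0, 0, ... and stabilizes to 0; once the stable value fails the exit
// condition, the backedge can no longer be taken.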
9599
9600 ConstantInt *StableValue = nullptr;
9601 switch (OpCode) {
9602 default:
9603 llvm_unreachable("Impossible case!");
9604
9605 case Instruction::AShr: {
9606 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9607 // bitwidth(K) iterations.
9608 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9609 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9610 Predecessor->getTerminator(), &DT);
9611 auto *Ty = cast<IntegerType>(RHS->getType());
9612 if (Known.isNonNegative())
9613 StableValue = ConstantInt::get(Ty, 0);
9614 else if (Known.isNegative())
9615 StableValue = ConstantInt::get(Ty, -1, true);
9616 else
9617 return getCouldNotCompute();
9618
9619 break;
9620 }
9621 case Instruction::LShr:
9622 case Instruction::Shl:
9623 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9624 // stabilize to 0 in at most bitwidth(K) iterations.
9625 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9626 break;
9627 }
9628
9629 auto *Result =
9630 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9631 assert(Result->getType()->isIntegerTy(1) &&
9632 "Otherwise cannot be an operand to a branch instruction");
9633
9634 if (Result->isZeroValue()) {
9635 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9636 const SCEV *UpperBound =
9637 getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9638 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9639 }
9640
9641 return getCouldNotCompute();
9642}
9643
9644/// Return true if we can constant fold an instruction of the specified type,
9645/// assuming that all operands were constants.
9646static bool CanConstantFold(const Instruction *I) {
9647 if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9648 isa<CastInst>(I) || isa<GetElementPtrInst>(I) ||
9649 isa<LoadInst>(I) || isa<ExtractValueInst>(I))
9650 return true;
9651
9652 if (const CallInst *CI = dyn_cast<CallInst>(I))
9653 if (const Function *F = CI->getCalledFunction())
9654 return canConstantFoldCallTo(CI, F);
9655 return false;
9656}
9657
9658/// Determine whether this instruction can constant evolve within this loop
9659/// assuming its operands can all constant evolve.
9660static bool canConstantEvolve(Instruction *I, const Loop *L) {
9661 // An instruction outside of the loop can't be derived from a loop PHI.
9662 if (!L->contains(I)) return false;
9663
9664 if (isa<PHINode>(I)) {
9665 // We don't currently keep track of the control flow needed to evaluate
9666 // PHIs, so we cannot handle PHIs inside of loops.
9667 return L->getHeader() == I->getParent();
9668 }
9669
9670 // If we won't be able to constant fold this expression even if the operands
9671 // are constants, bail early.
9672 return CanConstantFold(I);
9673}
9674
9675/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9676/// recursing through each instruction operand until reaching a loop header phi.
9677static PHINode *
9678 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9679 DenseMap<Instruction *, PHINode *> &PHIMap,
9680 unsigned Depth) {
9681 if (Depth > MaxConstantEvolvingDepth)
9682 return nullptr;
9683
9684 // Otherwise, we can evaluate this instruction if all of its operands are
9685 // constant or derived from a PHI node themselves.
9686 PHINode *PHI = nullptr;
9687 for (Value *Op : UseInst->operands()) {
9688 if (isa<Constant>(Op)) continue;
9689
9690 Instruction *OpInst = dyn_cast<Instruction>(Op);
9691 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9692
9693 PHINode *P = dyn_cast<PHINode>(OpInst);
9694 if (!P)
9695 // If this operand is already visited, reuse the prior result.
9696 // We may have P != PHI if this is the deepest point at which the
9697 // inconsistent paths meet.
9698 P = PHIMap.lookup(OpInst);
9699 if (!P) {
9700 // Recurse and memoize the results, whether a phi is found or not.
9701 // This recursive call invalidates pointers into PHIMap.
9702 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9703 PHIMap[OpInst] = P;
9704 }
9705 if (!P)
9706 return nullptr; // Not evolving from PHI
9707 if (PHI && PHI != P)
9708 return nullptr; // Evolving from multiple different PHIs.
9709 PHI = P;
9710 }
9711 // This is an expression evolving from a constant PHI!
9712 return PHI;
9713}
9714
9715/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9716/// in the loop that V is derived from. We allow arbitrary operations along the
9717/// way, but the operands of an operation must either be constants or a value
9718/// derived from a constant PHI. If this expression does not fit with these
9719/// constraints, return null.
9720 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9721 Instruction *I = dyn_cast<Instruction>(V);
9722 if (!I || !canConstantEvolve(I, L)) return nullptr;
9723
9724 if (PHINode *PN = dyn_cast<PHINode>(I))
9725 return PN;
9726
9727 // Record non-constant instructions contained by the loop.
9728 DenseMap<Instruction *, PHINode *> PHIMap;
9729 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9730}
9731
9732/// EvaluateExpression - Given an expression that passes the
9733/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9734/// in the loop has the value PHIVal. If we can't fold this expression for some
9735/// reason, return null.
9736 static Constant *EvaluateExpression(Value *V, const Loop *L,
9737 DenseMap<Instruction *, Constant *> &Vals,
9738 const DataLayout &DL,
9739 const TargetLibraryInfo *TLI) {
9740 // Convenient constant check, but redundant for recursive calls.
9741 if (Constant *C = dyn_cast<Constant>(V)) return C;
9742 Instruction *I = dyn_cast<Instruction>(V);
9743 if (!I) return nullptr;
9744
9745 if (Constant *C = Vals.lookup(I)) return C;
9746
9747 // An instruction inside the loop depends on a value outside the loop that we
9748 // weren't given a mapping for, or a value such as a call inside the loop.
9749 if (!canConstantEvolve(I, L)) return nullptr;
9750
9751 // An unmapped PHI can be due to a branch or another loop inside this loop,
9752 // or due to this not being the initial iteration through a loop where we
9753 // couldn't compute the evolution of this particular PHI last time.
9754 if (isa<PHINode>(I)) return nullptr;
9755
9756 std::vector<Constant*> Operands(I->getNumOperands());
9757
9758 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9759 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9760 if (!Operand) {
9761 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9762 if (!Operands[i]) return nullptr;
9763 continue;
9764 }
9765 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9766 Vals[Operand] = C;
9767 if (!C) return nullptr;
9768 Operands[i] = C;
9769 }
9770
9771 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9772 /*AllowNonDeterministic=*/false);
9773}
9774
9775
9776// If every incoming value to PN except the one for BB is a specific Constant,
9777// return that, else return nullptr.
9778 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9779 Constant *IncomingVal = nullptr;
9780
9781 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9782 if (PN->getIncomingBlock(i) == BB)
9783 continue;
9784
9785 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9786 if (!CurrentVal)
9787 return nullptr;
9788
9789 if (IncomingVal != CurrentVal) {
9790 if (IncomingVal)
9791 return nullptr;
9792 IncomingVal = CurrentVal;
9793 }
9794 }
9795
9796 return IncomingVal;
9797}
9798
9799/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9800/// in the header of its containing loop, we know the loop executes a
9801/// constant number of times, and the PHI node is just a recurrence
9802/// involving constants, fold it.
9803Constant *
9804ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9805 const APInt &BEs,
9806 const Loop *L) {
9807 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9808 if (!Inserted)
9809 return I->second;
9810
9811 if (BEs.ugt(MaxBruteForceIterations))
9812 return nullptr; // Not going to evaluate it.
9813
9814 Constant *&RetVal = I->second;
9815
9816 DenseMap<Instruction *, Constant *> CurrentIterVals;
9817 BasicBlock *Header = L->getHeader();
9818 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9819
9820 BasicBlock *Latch = L->getLoopLatch();
9821 if (!Latch)
9822 return nullptr;
9823
9824 for (PHINode &PHI : Header->phis()) {
9825 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9826 CurrentIterVals[&PHI] = StartCST;
9827 }
9828 if (!CurrentIterVals.count(PN))
9829 return RetVal = nullptr;
9830
9831 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9832
9833 // Execute the loop symbolically to determine the exit value.
9834 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9835 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9836
9837 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9838 unsigned IterationNum = 0;
9839 const DataLayout &DL = getDataLayout();
9840 for (; ; ++IterationNum) {
9841 if (IterationNum == NumIterations)
9842 return RetVal = CurrentIterVals[PN]; // Got exit value!
9843
9844 // Compute the value of the PHIs for the next iteration.
9845 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9846 DenseMap<Instruction *, Constant *> NextIterVals;
9847 Constant *NextPHI =
9848 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9849 if (!NextPHI)
9850 return nullptr; // Couldn't evaluate!
9851 NextIterVals[PN] = NextPHI;
9852
9853 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9854
9855 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9856 // cease to be able to evaluate one of them or if they stop evolving,
9857 // because that doesn't necessarily prevent us from computing PN.
9858 SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9859 for (const auto &I : CurrentIterVals) {
9860 PHINode *PHI = dyn_cast<PHINode>(I.first);
9861 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9862 PHIsToCompute.emplace_back(PHI, I.second);
9863 }
9864 // We use two distinct loops because EvaluateExpression may invalidate any
9865 // iterators into CurrentIterVals.
9866 for (const auto &I : PHIsToCompute) {
9867 PHINode *PHI = I.first;
9868 Constant *&NextPHI = NextIterVals[PHI];
9869 if (!NextPHI) { // Not already computed.
9870 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9871 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9872 }
9873 if (NextPHI != I.second)
9874 StoppedEvolving = false;
9875 }
9876
9877 // If all entries in CurrentIterVals == NextIterVals then we can stop
9878 // iterating; the loop can't continue to change.
9879 if (StoppedEvolving)
9880 return RetVal = CurrentIterVals[PN];
9881
9882 CurrentIterVals.swap(NextIterVals);
9883 }
9884}
9885
9886const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9887 Value *Cond,
9888 bool ExitWhen) {
9889 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9890 if (!PN) return getCouldNotCompute();
9891
9892 // If the loop is canonicalized, the PHI will have exactly two entries.
9893 // That's the only form we support here.
9894 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9895
9896 DenseMap<Instruction *, Constant *> CurrentIterVals;
9897 BasicBlock *Header = L->getHeader();
9898 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9899
9900 BasicBlock *Latch = L->getLoopLatch();
9901 assert(Latch && "Should follow from NumIncomingValues == 2!");
9902
9903 for (PHINode &PHI : Header->phis()) {
9904 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9905 CurrentIterVals[&PHI] = StartCST;
9906 }
9907 if (!CurrentIterVals.count(PN))
9908 return getCouldNotCompute();
9909
9910 // Okay, we found a PHI node that defines the trip count of this loop. Execute
9911 // the loop symbolically to determine when the condition gets a value of
9912 // "ExitWhen".
9913 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9914 const DataLayout &DL = getDataLayout();
9915 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9916 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9917 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9918
9919 // Couldn't symbolically evaluate.
9920 if (!CondVal) return getCouldNotCompute();
9921
9922 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9923 ++NumBruteForceTripCountsComputed;
9924 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9925 }
9926
9927 // Update all the PHI nodes for the next iteration.
9928 DenseMap<Instruction *, Constant *> NextIterVals;
9929
9930 // Create a list of which PHIs we need to compute. We want to do this before
9931 // calling EvaluateExpression on them because that may invalidate iterators
9932 // into CurrentIterVals.
9933 SmallVector<PHINode *, 8> PHIsToCompute;
9934 for (const auto &I : CurrentIterVals) {
9935 PHINode *PHI = dyn_cast<PHINode>(I.first);
9936 if (!PHI || PHI->getParent() != Header) continue;
9937 PHIsToCompute.push_back(PHI);
9938 }
9939 for (PHINode *PHI : PHIsToCompute) {
9940 Constant *&NextPHI = NextIterVals[PHI];
9941 if (NextPHI) continue; // Already computed!
9942
9943 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9944 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9945 }
9946 CurrentIterVals.swap(NextIterVals);
9947 }
9948
9949 // Too many iterations were needed to evaluate.
9950 return getCouldNotCompute();
9951}
9952
9953const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9954 SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9955 ValuesAtScopes[V];
9956 // Check to see if we've folded this expression at this loop before.
9957 for (auto &LS : Values)
9958 if (LS.first == L)
9959 return LS.second ? LS.second : V;
9960
9961 Values.emplace_back(L, nullptr);
9962
9963 // Otherwise compute it.
9964 const SCEV *C = computeSCEVAtScope(V, L);
9965 for (auto &LS : reverse(ValuesAtScopes[V]))
9966 if (LS.first == L) {
9967 LS.second = C;
9968 if (!isa<SCEVConstant>(C))
9969 ValuesAtScopesUsers[C].push_back({L, V});
9970 break;
9971 }
9972 return C;
9973}
9974
9975/// This builds up a Constant using the ConstantExpr interface. That way, we
9976/// will return Constants for objects which aren't represented by a
9977/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9978/// Returns NULL if the SCEV isn't representable as a Constant.
9979 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9980 switch (V->getSCEVType()) {
9981 case scCouldNotCompute:
9982 case scAddRecExpr:
9983 case scVScale:
9984 return nullptr;
9985 case scConstant:
9986 return cast<SCEVConstant>(V)->getValue();
9987 case scUnknown:
9988 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9989 case scPtrToAddr: {
9991 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9992 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
9993
9994 return nullptr;
9995 }
9996 case scPtrToInt: {
9997 const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9998 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9999 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10000
10001 return nullptr;
10002 }
10003 case scTruncate: {
10004 const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
10005 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10006 return ConstantExpr::getTrunc(CastOp, ST->getType());
10007 return nullptr;
10008 }
10009 case scAddExpr: {
10010 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10011 Constant *C = nullptr;
10012 for (const SCEV *Op : SA->operands()) {
10013 Constant *OpC = BuildConstantFromSCEV(Op);
10014 if (!OpC)
10015 return nullptr;
10016 if (!C) {
10017 C = OpC;
10018 continue;
10019 }
10020 assert(!C->getType()->isPointerTy() &&
10021 "Can only have one pointer, and it must be last");
10022 if (OpC->getType()->isPointerTy()) {
10023 // The offsets have been converted to bytes. We can add bytes using
10024 // an i8 GEP.
10025 C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
10026 OpC, C);
10027 } else {
10028 C = ConstantExpr::getAdd(C, OpC);
10029 }
10030 }
10031 return C;
10032 }
10033 case scMulExpr:
10034 case scSignExtend:
10035 case scZeroExtend:
10036 case scUDivExpr:
10037 case scSMaxExpr:
10038 case scUMaxExpr:
10039 case scSMinExpr:
10040 case scUMinExpr:
10041 case scSequentialUMinExpr:
10042 return nullptr;
10043 }
10044 llvm_unreachable("Unknown SCEV kind!");
10045}
10046
10047const SCEV *
10048ScalarEvolution::getWithOperands(const SCEV *S,
10049 SmallVectorImpl<const SCEV *> &NewOps) {
10050 switch (S->getSCEVType()) {
10051 case scTruncate:
10052 case scZeroExtend:
10053 case scSignExtend:
10054 case scPtrToAddr:
10055 case scPtrToInt:
10056 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10057 case scAddRecExpr: {
10058 auto *AddRec = cast<SCEVAddRecExpr>(S);
10059 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10060 }
10061 case scAddExpr:
10062 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10063 case scMulExpr:
10064 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10065 case scUDivExpr:
10066 return getUDivExpr(NewOps[0], NewOps[1]);
10067 case scUMaxExpr:
10068 case scSMaxExpr:
10069 case scUMinExpr:
10070 case scSMinExpr:
10071 return getMinMaxExpr(S->getSCEVType(), NewOps);
10072 case scSequentialUMinExpr:
10073 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10074 case scConstant:
10075 case scVScale:
10076 case scUnknown:
10077 return S;
10078 case scCouldNotCompute:
10079 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10080 }
10081 llvm_unreachable("Unknown SCEV kind!");
10082}
10083
10084const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10085 switch (V->getSCEVType()) {
10086 case scConstant:
10087 case scVScale:
10088 return V;
10089 case scAddRecExpr: {
10090 // If this is a loop recurrence for a loop that does not contain L, then we
10091 // are dealing with the final value computed by the loop.
10092 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10093 // First, attempt to evaluate each operand.
10094 // Avoid performing the look-up in the common case where the specified
10095 // expression has no loop-variant portions.
10096 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10097 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10098 if (OpAtScope == AddRec->getOperand(i))
10099 continue;
10100
10101 // Okay, at least one of these operands is loop variant but might be
10102 // foldable. Build a new instance of the folded commutative expression.
10103 SmallVector<const SCEV *, 8> NewOps;
10104 NewOps.reserve(AddRec->getNumOperands());
10105 append_range(NewOps, AddRec->operands().take_front(i));
10106 NewOps.push_back(OpAtScope);
10107 for (++i; i != e; ++i)
10108 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10109
10110 const SCEV *FoldedRec = getAddRecExpr(
10111 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10112 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10113 // The addrec may be folded to a nonrecurrence, for example, if the
10114 // induction variable is multiplied by zero after constant folding. Go
10115 // ahead and return the folded value.
10116 if (!AddRec)
10117 return FoldedRec;
10118 break;
10119 }
10120
10121 // If the scope is outside the addrec's loop, evaluate it by using the
10122 // loop exit value of the addrec.
10123 if (!AddRec->getLoop()->contains(L)) {
10124 // To evaluate this recurrence, we need to know how many times the AddRec
10125 // loop iterates. Compute this now.
10126 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10127 if (BackedgeTakenCount == getCouldNotCompute())
10128 return AddRec;
10129
10130 // Then, evaluate the AddRec.
10131 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10132 }
10133
10134 return AddRec;
10135 }
10136 case scTruncate:
10137 case scZeroExtend:
10138 case scSignExtend:
10139 case scPtrToAddr:
10140 case scPtrToInt:
10141 case scAddExpr:
10142 case scMulExpr:
10143 case scUDivExpr:
10144 case scUMaxExpr:
10145 case scSMaxExpr:
10146 case scUMinExpr:
10147 case scSMinExpr:
10148 case scSequentialUMinExpr: {
10149 ArrayRef<const SCEV *> Ops = V->operands();
10150 // Avoid performing the look-up in the common case where the specified
10151 // expression has no loop-variant portions.
10152 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10153 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10154 if (OpAtScope != Ops[i]) {
10155 // Okay, at least one of these operands is loop variant but might be
10156 // foldable. Build a new instance of the folded commutative expression.
10157 SmallVector<const SCEV *, 8> NewOps;
10158 NewOps.reserve(Ops.size());
10159 append_range(NewOps, Ops.take_front(i));
10160 NewOps.push_back(OpAtScope);
10161
10162 for (++i; i != e; ++i) {
10163 OpAtScope = getSCEVAtScope(Ops[i], L);
10164 NewOps.push_back(OpAtScope);
10165 }
10166
10167 return getWithOperands(V, NewOps);
10168 }
10169 }
10170 // If we got here, all operands are loop invariant.
10171 return V;
10172 }
10173 case scUnknown: {
10174 // If this instruction is evolved from a constant-evolving PHI, compute the
10175 // exit value from the loop without using SCEVs.
10176 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10177 Instruction *I = dyn_cast<Instruction>(SU->getValue());
10178 if (!I)
10179 return V; // This is some other type of SCEVUnknown, just return it.
10180
10181 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10182 const Loop *CurrLoop = this->LI[I->getParent()];
10183 // Looking for loop exit value.
10184 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10185 PN->getParent() == CurrLoop->getHeader()) {
10186 // Okay, there is no closed form solution for the PHI node. Check
10187 // to see if the loop that contains it has a known backedge-taken
10188 // count. If so, we may be able to force computation of the exit
10189 // value.
10190 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10191 // This trivial case can show up in some degenerate cases where
10192 // the incoming IR has not yet been fully simplified.
10193 if (BackedgeTakenCount->isZero()) {
10194 Value *InitValue = nullptr;
10195 bool MultipleInitValues = false;
10196 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10197 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10198 if (!InitValue)
10199 InitValue = PN->getIncomingValue(i);
10200 else if (InitValue != PN->getIncomingValue(i)) {
10201 MultipleInitValues = true;
10202 break;
10203 }
10204 }
10205 }
10206 if (!MultipleInitValues && InitValue)
10207 return getSCEV(InitValue);
10208 }
10209 // Do we have a loop invariant value flowing around the backedge
10210 // for a loop which must execute the backedge?
10211 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10212 isKnownNonZero(BackedgeTakenCount) &&
10213 PN->getNumIncomingValues() == 2) {
10214
10215 unsigned InLoopPred =
10216 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10217 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10218 if (CurrLoop->isLoopInvariant(BackedgeVal))
10219 return getSCEV(BackedgeVal);
10220 }
10221 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10222 // Okay, we know how many times the containing loop executes. If
10223 // this is a constant evolving PHI node, get the final value at
10224 // the specified iteration number.
10225 Constant *RV =
10226 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10227 if (RV)
10228 return getSCEV(RV);
10229 }
10230 }
10231 }
10232
10233 // Okay, this is an expression that we cannot symbolically evaluate
10234 // into a SCEV. Check to see if it's possible to symbolically evaluate
10235 // the arguments into constants, and if so, try to constant propagate the
10236 // result. This is particularly useful for computing loop exit values.
10237 if (!CanConstantFold(I))
10238 return V; // This is some other type of SCEVUnknown, just return it.
10239
10240 SmallVector<Constant *, 4> Operands;
10241 Operands.reserve(I->getNumOperands());
10242 bool MadeImprovement = false;
10243 for (Value *Op : I->operands()) {
10244 if (Constant *C = dyn_cast<Constant>(Op)) {
10245 Operands.push_back(C);
10246 continue;
10247 }
10248
10249 // If any of the operands is non-constant and if they are
10250 // non-integer and non-pointer, don't even try to analyze them
10251 // with scev techniques.
10252 if (!isSCEVable(Op->getType()))
10253 return V;
10254
10255 const SCEV *OrigV = getSCEV(Op);
10256 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10257 MadeImprovement |= OrigV != OpV;
10258
10259 Constant *C = BuildConstantFromSCEV(OpV);
10260 if (!C)
10261 return V;
10262 assert(C->getType() == Op->getType() && "Type mismatch");
10263 Operands.push_back(C);
10264 }
10265
10266 // Check to see if getSCEVAtScope actually made an improvement.
10267 if (!MadeImprovement)
10268 return V; // This is some other type of SCEVUnknown, just return it.
10269
10270 Constant *C = nullptr;
10271 const DataLayout &DL = getDataLayout();
10272 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10273 /*AllowNonDeterministic=*/false);
10274 if (!C)
10275 return V;
10276 return getSCEV(C);
10277 }
10278 case scCouldNotCompute:
10279 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10280 }
10281 llvm_unreachable("Unknown SCEV type!");
10282}
10283
10284 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10285 return getSCEVAtScope(getSCEV(V), L);
10286}
10287
10288const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10289 if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10290 return stripInjectiveFunctions(ZExt->getOperand());
10291 if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10292 return stripInjectiveFunctions(SExt->getOperand());
10293 return S;
10294}
10295
10296/// Finds the minimum unsigned root of the following equation:
10297///
10298/// A * X = B (mod N)
10299///
10300/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10301/// A and B isn't important.
10302///
10303/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
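///
/// For example, with BW = 8, 6 * X = 4 (mod 256) gives D = gcd(6, 256) = 2,
/// which divides 4; the inverse of 6/2 = 3 modulo 256/2 = 128 is 43, so the
/// minimum unsigned root is X = 43 * (4/2) mod 128 = 86.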
10304static const SCEV *
10305 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10306 SmallVectorImpl<const SCEVPredicate *> *Predicates,
10307 ScalarEvolution &SE, const Loop *L) {
10308 uint32_t BW = A.getBitWidth();
10309 assert(BW == SE.getTypeSizeInBits(B->getType()));
10310 assert(A != 0 && "A must be non-zero.");
10311
10312 // 1. D = gcd(A, N)
10313 //
10314 // The gcd of A and N may have only one prime factor: 2. The number of
10315 // trailing zeros in A is its multiplicity
10316 uint32_t Mult2 = A.countr_zero();
10317 // D = 2^Mult2
10318
10319 // 2. Check if B is divisible by D.
10320 //
10321 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10322 // is not less than multiplicity of this prime factor for D.
10323 unsigned MinTZ = SE.getMinTrailingZeros(B);
10324 // Try again with the terminator of the loop predecessor for a context-specific
10325 // result, if MinTZ is too small.
10326 if (MinTZ < Mult2 && L->getLoopPredecessor())
10327 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10328 if (MinTZ < Mult2) {
10329 // Check if we can prove there's no remainder using URem.
10330 const SCEV *URem =
10331 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10332 const SCEV *Zero = SE.getZero(B->getType());
10333 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10334 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10335 if (!Predicates)
10336 return SE.getCouldNotCompute();
10337
10338 // Avoid adding a predicate that is known to be false.
10339 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10340 return SE.getCouldNotCompute();
10341 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10342 }
10343 }
10344
10345 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10346 // modulo (N / D).
10347 //
10348 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10349 // (N / D) in general. The inverse itself always fits into BW bits, though,
10350 // so we immediately truncate it.
10351 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10352 APInt I = AD.multiplicativeInverse().zext(BW);
10353
10354 // 4. Compute the minimum unsigned root of the equation:
10355 // I * (B / D) mod (N / D)
10356 // To simplify the computation, we factor out the divide by D:
10357 // (I * B mod N) / D
10358 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10359 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10360}
10361
10362/// For a given quadratic addrec, generate coefficients of the corresponding
10363/// quadratic equation, multiplied by a common value to ensure that they are
10364/// integers.
10365/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10366/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10367/// were multiplied by, and BitWidth is the bit width of the original addrec
10368/// coefficients.
10369/// This function returns std::nullopt if the addrec coefficients are not
10370 /// compile-time constants.
10371static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10372 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10373 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10374 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10375 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10376 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10377 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10378 << *AddRec << '\n');
10379
10380 // We currently can only solve this if the coefficients are constants.
10381 if (!LC || !MC || !NC) {
10382 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10383 return std::nullopt;
10384 }
10385
10386 APInt L = LC->getAPInt();
10387 APInt M = MC->getAPInt();
10388 APInt N = NC->getAPInt();
10389 assert(!N.isZero() && "This is not a quadratic addrec");
10390
10391 unsigned BitWidth = LC->getAPInt().getBitWidth();
10392 unsigned NewWidth = BitWidth + 1;
10393 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10394 << BitWidth << '\n');
10395 // The sign-extension (as opposed to a zero-extension) here matches the
10396 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10397 N = N.sext(NewWidth);
10398 M = M.sext(NewWidth);
10399 L = L.sext(NewWidth);
10400
10401 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10402 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10403 // L+M, L+2M+N, L+3M+3N, ...
10404 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10405 //
10406 // The equation Acc = 0 is then
10407 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10408 // In a quadratic form it becomes:
10409 // N n^2 + (2M-N) n + 2L = 0.
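//
// For example, for {0,+,1,+,2} the accumulated value after n iterations is
// n^2, and the generated equation (multiplied by 2) is 2 n^2 + 0 n + 0 = 0.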
10410
10411 APInt A = N;
10412 APInt B = 2 * M - A;
10413 APInt C = 2 * L;
10414 APInt T = APInt(NewWidth, 2);
10415 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10416 << "x + " << C << ", coeff bw: " << NewWidth
10417 << ", multiplied by " << T << '\n');
10418 return std::make_tuple(A, B, C, T, BitWidth);
10419}
10420
10421/// Helper function to compare optional APInts:
10422/// (a) if X and Y both exist, return min(X, Y),
10423/// (b) if neither X nor Y exist, return std::nullopt,
10424/// (c) if exactly one of X and Y exists, return that value.
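/// For example, MinOptional(3, std::nullopt) is 3, and MinOptional(-2, 5) is
/// -2 (the comparison is signed).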
10425static std::optional<APInt> MinOptional(std::optional<APInt> X,
10426 std::optional<APInt> Y) {
10427 if (X && Y) {
10428 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10429 APInt XW = X->sext(W);
10430 APInt YW = Y->sext(W);
10431 return XW.slt(YW) ? *X : *Y;
10432 }
10433 if (!X && !Y)
10434 return std::nullopt;
10435 return X ? *X : *Y;
10436}
10437
10438/// Helper function to truncate an optional APInt to a given BitWidth.
10439/// When solving addrec-related equations, it is preferable to return a value
10440/// that has the same bit width as the original addrec's coefficients. If the
10441/// solution fits in the original bit width, truncate it (except for i1).
10442/// Returning a value of a different bit width may inhibit some optimizations.
10443///
10444/// In general, a solution to a quadratic equation generated from an addrec
10445/// may require BW+1 bits, where BW is the bit width of the addrec's
10446/// coefficients. The reason is that the coefficients of the quadratic
10447/// equation are BW+1 bits wide (to avoid truncation when converting from
10448/// the addrec to the equation).
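///
/// For example, a 9-bit solution with value 6 found for an addrec with 8-bit
/// coefficients is returned truncated to 8 bits, while a 9-bit value of 300
/// is returned unchanged.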
10449static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10450 unsigned BitWidth) {
10451 if (!X)
10452 return std::nullopt;
10453 unsigned W = X->getBitWidth();
10454 if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10455 return X->trunc(BitWidth);
10456 return X;
10457}
10458
10459/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10460/// iterations. The values L, M, N are assumed to be signed, and they
10461/// should all have the same bit widths.
10462/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10463/// where BW is the bit width of the addrec's coefficients.
10464/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10465/// returned as such, otherwise the bit width of the returned value may
10466/// be greater than BW.
10467///
10468/// This function returns std::nullopt if
10469/// (a) the addrec coefficients are not constant, or
10470/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10471/// like x^2 = 5, no integer solutions exist, in other cases an integer
10472/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
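///
/// For example, for the chrec {-4,+,1,+,2} we have c(n) = n^2 - 4, so the
/// least n with c(n) = 0 is n = 2.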
10473static std::optional<APInt>
10474 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10475 APInt A, B, C, M;
10476 unsigned BitWidth;
10477 auto T = GetQuadraticEquation(AddRec);
10478 if (!T)
10479 return std::nullopt;
10480
10481 std::tie(A, B, C, M, BitWidth) = *T;
10482 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10483 std::optional<APInt> X =
10484 APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10485 if (!X)
10486 return std::nullopt;
10487
10488 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10489 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10490 if (!V->isZero())
10491 return std::nullopt;
10492
10493 return TruncIfPossible(X, BitWidth);
10494}
10495
10496/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10497/// iterations. The values M, N are assumed to be signed, and they
10498/// should all have the same bit widths.
10499/// Find the least n such that c(n) does not belong to the given range,
10500/// while c(n-1) does.
10501///
10502/// This function returns std::nullopt if
10503/// (a) the addrec coefficients are not constant, or
10504/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10505/// bounds of the range.
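///
/// For example, for the chrec {0,+,1,+,2}, whose value after n iterations is
/// n^2, and the range [0, 10), the result is n = 4: c(3) = 9 is still inside
/// the range while c(4) = 16 is not.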
10506static std::optional<APInt>
10507 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10508 const ConstantRange &Range, ScalarEvolution &SE) {
10509 assert(AddRec->getOperand(0)->isZero() &&
10510 "Starting value of addrec should be 0");
10511 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10512 << Range << ", addrec " << *AddRec << '\n');
10513 // This case is handled in getNumIterationsInRange. Here we can assume that
10514 // we start in the range.
10515 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10516 "Addrec's initial value should be in range");
10517
10518 APInt A, B, C, M;
10519 unsigned BitWidth;
10520 auto T = GetQuadraticEquation(AddRec);
10521 if (!T)
10522 return std::nullopt;
10523
10524 // Be careful about the return value: there can be two reasons for not
10525 // returning an actual number. First, if no solutions to the equations
10526 // were found, and second, if the solutions don't leave the given range.
10527 // The first case means that the actual solution is "unknown", the second
10528 // means that it's known, but not valid. If the solution is unknown, we
10529 // cannot make any conclusions.
10530 // Return a pair: the optional solution and a flag indicating if the
10531 // solution was found.
10532 auto SolveForBoundary =
10533 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10534 // Solve for signed overflow and unsigned overflow, pick the lower
10535 // solution.
10536 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10537 << Bound << " (before multiplying by " << M << ")\n");
10538 Bound *= M; // The quadratic equation multiplier.
10539
10540 std::optional<APInt> SO;
10541 if (BitWidth > 1) {
10542 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10543 "signed overflow\n");
10545 }
10546 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10547 "unsigned overflow\n");
10548 std::optional<APInt> UO =
10549 APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10550
10551 auto LeavesRange = [&] (const APInt &X) {
10552 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10553 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10554 if (Range.contains(V0->getValue()))
10555 return false;
10556 // X should be at least 1, so X-1 is non-negative.
10557 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10558 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10559 if (Range.contains(V1->getValue()))
10560 return true;
10561 return false;
10562 };
10563
10564 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10565 // can be a solution, but the function failed to find it. We cannot treat it
10566 // as "no solution".
10567 if (!SO || !UO)
10568 return {std::nullopt, false};
10569
10570 // Check the smaller value first to see if it leaves the range.
10571 // At this point, both SO and UO must have values.
10572 std::optional<APInt> Min = MinOptional(SO, UO);
10573 if (LeavesRange(*Min))
10574 return { Min, true };
10575 std::optional<APInt> Max = Min == SO ? UO : SO;
10576 if (LeavesRange(*Max))
10577 return { Max, true };
10578
10579 // Solutions were found, but were eliminated, hence the "true".
10580 return {std::nullopt, true};
10581 };
10582
10583 std::tie(A, B, C, M, BitWidth) = *T;
10584 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10585 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10586 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10587 auto SL = SolveForBoundary(Lower);
10588 auto SU = SolveForBoundary(Upper);
10589 // If any of the solutions was unknown, no meaningful conclusions can
10590 // be made.
10591 if (!SL.second || !SU.second)
10592 return std::nullopt;
10593
10594 // Claim: The correct solution is not some value between Min and Max.
10595 //
10596 // Justification: Assuming that Min and Max are different values, one of
10597 // them is when the first signed overflow happens, the other is when the
10598 // first unsigned overflow happens. Crossing the range boundary is only
10599 // possible via an overflow (treating 0 as a special case of it, modeling
10600 // an overflow as crossing k*2^W for some k).
10601 //
10602 // The interesting case here is when Min was eliminated as an invalid
10603 // solution, but Max was not. The argument is that if there was another
10604 // overflow between Min and Max, it would also have been eliminated if
10605 // it was considered.
10606 //
10607 // For a given boundary, it is possible to have two overflows of the same
10608 // type (signed/unsigned) without having the other type in between: this
10609 // can happen when the vertex of the parabola is between the iterations
10610 // corresponding to the overflows. This is only possible when the two
10611 // overflows cross k*2^W for the same k. In such case, if the second one
10612 // left the range (and was the first one to do so), the first overflow
10613 // would have to enter the range, which would mean that either we had left
10614 // the range before or that we started outside of it. Both of these cases
10615 // are contradictions.
10616 //
10617 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10618 // solution is not some value between the Max for this boundary and the
10619 // Min of the other boundary.
10620 //
10621 // Justification: Assume that we had such Max_A and Min_B corresponding
10622 // to range boundaries A and B and such that Max_A < Min_B. If there was
10623 // a solution between Max_A and Min_B, it would have to be caused by an
10624 // overflow corresponding to either A or B. It cannot correspond to B,
10625 // since Min_B is the first occurrence of such an overflow. If it
10626 // corresponded to A, it would have to be either a signed or an unsigned
10627 // overflow that is larger than both eliminated overflows for A. But
10628 // between the eliminated overflows and this overflow, the values would
10629 // cover the entire value space, thus crossing the other boundary, which
10630 // is a contradiction.
10631
10632 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10633}
10634
10635ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10636 const Loop *L,
10637 bool ControlsOnlyExit,
10638 bool AllowPredicates) {
10639
10640 // This is only used for loops with a "x != y" exit test. The exit condition
10641 // is now expressed as a single expression, V = x-y. So the exit test is
10642 // effectively V != 0. We know and take advantage of the fact that this
10643 // expression is only used in a comparison-with-zero context.
10644
10645 SmallVector<const SCEVPredicate *, 4> Predicates;
10646 // If the value is a constant
10647 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10648 // If the value is already zero, the branch will execute zero times.
10649 if (C->getValue()->isZero()) return C;
10650 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10651 }
10652
10653 const SCEVAddRecExpr *AddRec =
10654 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10655
10656 if (!AddRec && AllowPredicates)
10657 // Try to make this an AddRec using runtime tests, in the first X
10658 // iterations of this loop, where X is the SCEV expression found by the
10659 // algorithm below.
10660 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10661
10662 if (!AddRec || AddRec->getLoop() != L)
10663 return getCouldNotCompute();
10664
10665 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10666 // the quadratic equation to solve it.
10667 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10668 // We can only use this value if the chrec ends up with an exact zero
10669 // value at this index. When solving for "X*X != 5", for example, we
10670 // should not accept a root of 2.
10671 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10672 const auto *R = cast<SCEVConstant>(getConstant(*S));
10673 return ExitLimit(R, R, R, false, Predicates);
10674 }
10675 return getCouldNotCompute();
10676 }
10677
10678 // Otherwise we can only handle this if it is affine.
10679 if (!AddRec->isAffine())
10680 return getCouldNotCompute();
10681
10682 // If this is an affine expression, the execution count of this branch is
10683 // the minimum unsigned root of the following equation:
10684 //
10685 // Start + Step*N = 0 (mod 2^BW)
10686 //
10687 // equivalent to:
10688 //
10689 // Step*N = -Start (mod 2^BW)
10690 //
10691 // where BW is the common bit width of Start and Step.
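//
// For example, with i8 operands, {6,+,-2} reaches zero when
// 6 + (-2)*N = 0 (mod 256), i.e. after N = 3 iterations.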
10692
10693 // Get the initial value for the loop.
10694 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10695 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10696
10697 if (!isLoopInvariant(Step, L))
10698 return getCouldNotCompute();
10699
10700 LoopGuards Guards = LoopGuards::collect(L, *this);
10701 // Specialize step for this loop so we get context sensitive facts below.
10702 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10703
10704 // For positive steps (counting up until unsigned overflow):
10705 // N = -Start/Step (as unsigned)
10706 // For negative steps (counting down to zero):
10707 // N = Start/-Step
10708 // First compute the unsigned distance from zero in the direction of Step.
10709 bool CountDown = isKnownNegative(StepWLG);
10710 if (!CountDown && !isKnownNonNegative(StepWLG))
10711 return getCouldNotCompute();
10712
10713 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10714 // Handle unitary steps, which cannot wraparound.
10715 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10716 // N = Distance (as unsigned)
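// E.g. for {7,+,-1} the distance to zero is 7, so the backedge is taken
// exactly 7 times.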
10717
10718 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10719 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10720 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10721
10722 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10723 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10724 // case, and see if we can improve the bound.
10725 //
10726 // Explicitly handling this here is necessary because getUnsignedRange
10727 // isn't context-sensitive; it doesn't know that we only care about the
10728 // range inside the loop.
10729 const SCEV *Zero = getZero(Distance->getType());
10730 const SCEV *One = getOne(Distance->getType());
10731 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10732 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10733 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10734 // as "unsigned_max(Distance + 1) - 1".
10735 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10736 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10737 }
10738 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10739 Predicates);
10740 }
10741
10742 // If the condition controls loop exit (the loop exits only if the expression
10743 // is true) and the addition is no-wrap we can use unsigned divide to
10744 // compute the backedge count. In this case, the step may not divide the
10745 // distance, but we don't care because if the condition is "missed" the loop
10746 // will have undefined behavior due to wrapping.
10747 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10748 loopHasNoAbnormalExits(AddRec->getLoop())) {
10749
10750 // If the stride is zero and the start is non-zero, the loop must be
10751 // infinite. In C++, most loops are finite by assumption, in which case the
10752 // step being zero implies UB must execute if the loop is entered.
10753 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10754 !isKnownNonZero(StepWLG))
10755 return getCouldNotCompute();
10756
10757 const SCEV *Exact =
10758 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10759 const SCEV *ConstantMax = getCouldNotCompute();
10760 if (Exact != getCouldNotCompute()) {
10761 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10762 ConstantMax =
10764 }
10765 const SCEV *SymbolicMax =
10766 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10767 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10768 }
10769
10770 // Solve the general equation.
10771 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10772 if (!StepC || StepC->getValue()->isZero())
10773 return getCouldNotCompute();
10774 const SCEV *E = SolveLinEquationWithOverflow(
10775 StepC->getAPInt(), getNegativeSCEV(Start),
10776 AllowPredicates ? &Predicates : nullptr, *this, L);
10777
10778 const SCEV *M = E;
10779 if (E != getCouldNotCompute()) {
10780 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10781 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10782 }
10783 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10784 return ExitLimit(E, M, S, false, Predicates);
10785}
10786
10787ScalarEvolution::ExitLimit
10788ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10789 // Loops that look like: while (X == 0) are very strange indeed. We don't
10790 // handle them yet except for the trivial case. This could be expanded in the
10791 // future as needed.
10792
10793 // If the value is a constant, check to see if it is known to be non-zero
10794 // already. If so, the backedge will execute zero times.
10795 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10796 if (!C->getValue()->isZero())
10797 return getZero(C->getType());
10798 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10799 }
10800
10801 // We could implement others, but I really doubt anyone writes loops like
10802 // this, and if they did, they would already be constant folded.
10803 return getCouldNotCompute();
10804}
10805
10806std::pair<const BasicBlock *, const BasicBlock *>
10807ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10808 const {
10809 // If the block has a unique predecessor, then there is no path from the
10810 // predecessor to the block that does not go through the direct edge
10811 // from the predecessor to the block.
10812 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10813 return {Pred, BB};
10814
10815 // A loop's header is defined to be a block that dominates the loop.
10816 // If the header has a unique predecessor outside the loop, it must be
10817 // a block that has exactly one successor that can reach the loop.
10818 if (const Loop *L = LI.getLoopFor(BB))
10819 return {L->getLoopPredecessor(), L->getHeader()};
10820
10821 return {nullptr, BB};
10822}
10823
10824/// SCEV structural equivalence is usually sufficient for testing whether two
10825/// expressions are equal, however for the purposes of looking for a condition
10826/// guarding a loop, it can be useful to be a little more general, since a
10827/// front-end may have replicated the controlling expression.
10828static bool HasSameValue(const SCEV *A, const SCEV *B) {
10829 // Quick check to see if they are the same SCEV.
10830 if (A == B) return true;
10831
10832 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10833 // Not all instructions that are "identical" compute the same value. For
10834 // instance, two distinct alloca instructions allocating the same type are
10835 // identical and do not read memory, but compute distinct values.
10836 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10837 };
10838
10839 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10840 // two different instructions with the same value. Check for this case.
10841 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10842 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10843 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10844 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10845 if (ComputesEqualValues(AI, BI))
10846 return true;
10847
10848 // Otherwise assume they may have a different value.
10849 return false;
10850}
10851
10852static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10853 const SCEV *Op0, *Op1;
10854 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10855 return false;
10856 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10857 LHS = Op1;
10858 return true;
10859 }
10860 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10861 LHS = Op0;
10862 return true;
10863 }
10864 return false;
10865}
10866
10868 const SCEV *&RHS, unsigned Depth) {
10869 bool Changed = false;
10870 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10871 // '0 != 0'.
10872 auto TrivialCase = [&](bool TriviallyTrue) {
10873 LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10874 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10875 return true;
10876 };
10877 // If we hit the max recursion limit bail out.
10878 if (Depth >= 3)
10879 return false;
10880
10881 const SCEV *NewLHS, *NewRHS;
10882 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10883 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10884 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10885 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10886
10887 // (X * vscale) pred (Y * vscale) ==> X pred Y
10888 // when both multiples are NSW.
10889 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10890 // when both multiples are NUW.
10891 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10892 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10893 !ICmpInst::isSigned(Pred))) {
10894 LHS = NewLHS;
10895 RHS = NewRHS;
10896 Changed = true;
10897 }
10898 }
10899
10900 // Canonicalize a constant to the right side.
10901 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10902 // Check for both operands constant.
10903 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10904 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10905 return TrivialCase(false);
10906 return TrivialCase(true);
10907 }
10908 // Otherwise swap the operands to put the constant on the right.
10909 std::swap(LHS, RHS);
10911 Changed = true;
10912 }
10913
10914 // If we're comparing an addrec with a value which is loop-invariant in the
10915 // addrec's loop, put the addrec on the left. Also make a dominance check,
10916 // as both operands could be addrecs loop-invariant in each other's loop.
10917 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10918 const Loop *L = AR->getLoop();
10919 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10920 std::swap(LHS, RHS);
10922 Changed = true;
10923 }
10924 }
10925
10926 // If there's a constant operand, canonicalize comparisons with boundary
10927 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
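// For example, X >=u 1 is canonicalized to X >u 0, and X <u 0 is folded to
// a trivially false comparison.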
10928 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10929 const APInt &RA = RC->getAPInt();
10930
10931 bool SimplifiedByConstantRange = false;
10932
10933 if (!ICmpInst::isEquality(Pred)) {
10934 ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10935 if (ExactCR.isFullSet())
10936 return TrivialCase(true);
10937 if (ExactCR.isEmptySet())
10938 return TrivialCase(false);
10939
10940 APInt NewRHS;
10941 CmpInst::Predicate NewPred;
10942 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10943 ICmpInst::isEquality(NewPred)) {
10944 // We were able to convert an inequality to an equality.
10945 Pred = NewPred;
10946 RHS = getConstant(NewRHS);
10947 Changed = SimplifiedByConstantRange = true;
10948 }
10949 }
10950
10951 if (!SimplifiedByConstantRange) {
10952 switch (Pred) {
10953 default:
10954 break;
10955 case ICmpInst::ICMP_EQ:
10956 case ICmpInst::ICMP_NE:
10957 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10958 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10959 Changed = true;
10960 break;
10961
10962 // The "Should have been caught earlier!" messages refer to the fact
10963 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10964 // should have fired on the corresponding cases, and canonicalized the
10965 // check to a trivial case.
10966
10967 case ICmpInst::ICMP_UGE:
10968 assert(!RA.isMinValue() && "Should have been caught earlier!");
10969 Pred = ICmpInst::ICMP_UGT;
10970 RHS = getConstant(RA - 1);
10971 Changed = true;
10972 break;
10973 case ICmpInst::ICMP_ULE:
10974 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10975 Pred = ICmpInst::ICMP_ULT;
10976 RHS = getConstant(RA + 1);
10977 Changed = true;
10978 break;
10979 case ICmpInst::ICMP_SGE:
10980 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10981 Pred = ICmpInst::ICMP_SGT;
10982 RHS = getConstant(RA - 1);
10983 Changed = true;
10984 break;
10985 case ICmpInst::ICMP_SLE:
10986 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10987 Pred = ICmpInst::ICMP_SLT;
10988 RHS = getConstant(RA + 1);
10989 Changed = true;
10990 break;
10991 }
10992 }
10993 }
10994
10995 // Check for obvious equality.
10996 if (HasSameValue(LHS, RHS)) {
10997 if (ICmpInst::isTrueWhenEqual(Pred))
10998 return TrivialCase(true);
10999 if (ICmpInst::isFalseWhenEqual(Pred))
11000 return TrivialCase(false);
11001 }
11002
11003 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11004 // adding or subtracting 1 from one of the operands.
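 // For example, (%a s<= %b) becomes (%a s< (%b + 1)) when %b cannot be
 // SINT_MAX, or ((%a + (-1)) s< %b) when %a cannot be SINT_MIN.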
11005 switch (Pred) {
11006 case ICmpInst::ICMP_SLE:
11007 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11008 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11009 SCEV::FlagNSW);
11010 Pred = ICmpInst::ICMP_SLT;
11011 Changed = true;
11012 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11013 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11014 SCEV::FlagNSW);
11015 Pred = ICmpInst::ICMP_SLT;
11016 Changed = true;
11017 }
11018 break;
11019 case ICmpInst::ICMP_SGE:
11020 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11021 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11022 SCEV::FlagNSW);
11023 Pred = ICmpInst::ICMP_SGT;
11024 Changed = true;
11025 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11026 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11027 SCEV::FlagNSW);
11028 Pred = ICmpInst::ICMP_SGT;
11029 Changed = true;
11030 }
11031 break;
11032 case ICmpInst::ICMP_ULE:
11033 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11034 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11035 SCEV::FlagNUW);
11036 Pred = ICmpInst::ICMP_ULT;
11037 Changed = true;
11038 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11039 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11040 Pred = ICmpInst::ICMP_ULT;
11041 Changed = true;
11042 }
11043 break;
11044 case ICmpInst::ICMP_UGE:
11045 // If RHS is an op we can fold the -1, try that first.
11046 // Otherwise prefer LHS to preserve the nuw flag.
11047 if ((isa<SCEVConstant>(RHS) ||
11048 (isa<SCEVAddExpr>(RHS) &&
11049 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11050 !getUnsignedRangeMin(RHS).isMinValue()) {
11051 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11052 Pred = ICmpInst::ICMP_UGT;
11053 Changed = true;
11054 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11055 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11056 SCEV::FlagNUW);
11057 Pred = ICmpInst::ICMP_UGT;
11058 Changed = true;
11059 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11060 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11061 Pred = ICmpInst::ICMP_UGT;
11062 Changed = true;
11063 }
11064 break;
11065 default:
11066 break;
11067 }
11068
11069 // TODO: More simplifications are possible here.
11070
11071 // Recursively simplify until we either hit a recursion limit or nothing
11072 // changes.
11073 if (Changed)
11074 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11075
11076 return Changed;
11077}
11078
11079bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11080 return getSignedRangeMax(S).isNegative();
11081}
11082
11083bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11084 return isKnownNonZero(S) && isKnownNonNegative(S);
11085}
11086
11087bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11088 return !getSignedRangeMin(S).isNegative();
11089}
11090
11091bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11092 return !getSignedRangeMax(S).isStrictlyPositive();
11093}
11094
11095bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11096 // Push the query down through sign extensions for cases where the unsigned
11097 // range alone is not sufficient.
11098 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11099 return isKnownNonZero(SExt->getOperand(0));
11100 return getUnsignedRangeMin(S) != 0;
11101}
11102
11103bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11104 bool OrNegative) {
11105 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11106 if (auto *C = dyn_cast<SCEVConstant>(S))
11107 return C->getAPInt().isPowerOf2() ||
11108 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11109
11110 // The vscale_range indicates vscale is a power-of-two.
11111 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11112 };
11113
11114 if (NonRecursive(S))
11115 return true;
11116
11117 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11118 if (!Mul)
11119 return false;
11120 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11121}
11122
11123bool ScalarEvolution::isKnownMultipleOf(
11124 const SCEV *S, uint64_t M,
11125 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11126 if (M == 0)
11127 return false;
11128 if (M == 1)
11129 return true;
11130
11131 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11132 // starts with a multiple of M and at every iteration step S only adds
11133 // multiples of M.
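 // For example, {8,+,4}<%loop> is a known multiple of 4: the start (8) and the
 // step (4) are both multiples of 4, so every value of the recurrence is too.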
11134 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11135 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11136 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11137
11138 // For a constant, check that "S % M == 0".
11139 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11140 APInt C = Cst->getAPInt();
11141 return C.urem(M) == 0;
11142 }
11143
11144 // TODO: Also check other SCEV expressions, e.g., SCEVAddRecExpr, etc.
11145
11146 // Basic tests have failed.
11147 // Check "S % M == 0" at compile time and record runtime Assumptions.
11148 auto *STy = dyn_cast<IntegerType>(S->getType());
11149 const SCEV *SmodM =
11150 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11151 const SCEV *Zero = getZero(STy);
11152
11153 // Check whether "S % M == 0" is known at compile time.
11154 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11155 return true;
11156
11157 // Check whether "S % M != 0" is known at compile time.
11158 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11159 return false;
11160
11161 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero);
11162
11163 // Detect redundant predicates.
11164 for (auto *A : Assumptions)
11165 if (A->implies(P, *this))
11166 return true;
11167
11168 // Only record non-redundant predicates.
11169 Assumptions.push_back(P);
11170 return true;
11171}
11172
11173bool ScalarEvolution::haveSameSign(const SCEV *S1, const SCEV *S2) {
11174 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11175 (isKnownNegative(S1) && isKnownNegative(S2)));
11176}
11177
11178std::pair<const SCEV *, const SCEV *>
11179ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11180 // Compute SCEV on entry of loop L.
11181 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11182 if (Start == getCouldNotCompute())
11183 return { Start, Start };
11184 // Compute post increment SCEV for loop L.
11185 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11186 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11187 return { Start, PostInc };
11188}
11189
11190bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11191 const SCEV *RHS) {
11192 // First collect all loops.
11193 SmallPtrSet<const Loop *, 8> LoopsUsed;
11194 getUsedLoops(LHS, LoopsUsed);
11195 getUsedLoops(RHS, LoopsUsed);
11196
11197 if (LoopsUsed.empty())
11198 return false;
11199
11200 // Domination relationship must be a linear order on collected loops.
11201#ifndef NDEBUG
11202 for (const auto *L1 : LoopsUsed)
11203 for (const auto *L2 : LoopsUsed)
11204 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11205 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11206 "Domination relationship is not a linear order");
11207#endif
11208
11209 const Loop *MDL =
11210 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11211 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11212 });
11213
11214 // Get init and post increment value for LHS.
11215 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11216 // If LHS contains an unknown non-invariant SCEV then bail out.
11217 if (SplitLHS.first == getCouldNotCompute())
11218 return false;
11219 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11220 // Get init and post increment value for RHS.
11221 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11222 // If RHS contains an unknown non-invariant SCEV then bail out.
11223 if (SplitRHS.first == getCouldNotCompute())
11224 return false;
11225 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11226 // It is possible that init SCEV contains an invariant load but it does
11227 // not dominate MDL and is not available at MDL loop entry, so we should
11228 // check it here.
11229 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11230 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11231 return false;
11232
11233 // The backedge guard check tends to be cheaper than the loop entry check, so
11234 // checking it first can short-circuit and speed up the whole query.
11235 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11236 SplitRHS.second) &&
11237 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11238}
11239
11240bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11241 const SCEV *RHS) {
11242 // Canonicalize the inputs first.
11243 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11244
11245 if (isKnownViaInduction(Pred, LHS, RHS))
11246 return true;
11247
11248 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11249 return true;
11250
11251 // Otherwise see what can be done with some simple reasoning.
11252 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11253}
11254
11255std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11256 const SCEV *LHS,
11257 const SCEV *RHS) {
11258 if (isKnownPredicate(Pred, LHS, RHS))
11259 return true;
11260 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11261 return false;
11262 return std::nullopt;
11263}
11264
11265bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11266 const SCEV *RHS,
11267 const Instruction *CtxI) {
11268 // TODO: Analyze guards and assumes from Context's block.
11269 return isKnownPredicate(Pred, LHS, RHS) ||
11270 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11271}
11272
11273std::optional<bool>
11274ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11275 const SCEV *RHS, const Instruction *CtxI) {
11276 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11277 if (KnownWithoutContext)
11278 return KnownWithoutContext;
11279
11280 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11281 return true;
11282 if (isBasicBlockEntryGuardedByCond(
11283 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11284 return false;
11285 return std::nullopt;
11286}
11287
11288bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11289 const SCEVAddRecExpr *LHS,
11290 const SCEV *RHS) {
11291 const Loop *L = LHS->getLoop();
11292 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11293 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11294}
11295
11296std::optional<ScalarEvolution::MonotonicPredicateType>
11297ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11298 ICmpInst::Predicate Pred) {
11299 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11300
11301#ifndef NDEBUG
11302 // Verify an invariant: inverting the predicate should turn a monotonically
11303 // increasing change to a monotonically decreasing one, and vice versa.
11304 if (Result) {
11305 auto ResultSwapped =
11306 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11307
11308 assert(*ResultSwapped != *Result &&
11309 "monotonicity should flip as we flip the predicate");
11310 }
11311#endif
11312
11313 return Result;
11314}
11315
11316std::optional<ScalarEvolution::MonotonicPredicateType>
11317ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11318 ICmpInst::Predicate Pred) {
11319 // A zero step value for LHS means the induction variable is essentially a
11320 // loop invariant value. We don't really depend on the predicate actually
11321 // flipping from false to true (for increasing predicates, and the other way
11322 // around for decreasing predicates), all we care about is that *if* the
11323 // predicate changes then it only changes from false to true.
11324 //
11325 // A zero step value in itself is not very useful, but there may be places
11326 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11327 // as general as possible.
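 // For example, {%n,+,0}<nsw> s< %m never changes value while the loop runs,
 // so classifying it as monotonic is still sound: the predicate trivially
 // never flips from true back to false.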
11328
11329 // Only handle LE/LT/GE/GT predicates.
11330 if (!ICmpInst::isRelational(Pred))
11331 return std::nullopt;
11332
11333 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11334 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11335 "Should be greater or less!");
11336
11337 // Check that AR does not wrap.
11338 if (ICmpInst::isUnsigned(Pred)) {
11339 if (!LHS->hasNoUnsignedWrap())
11340 return std::nullopt;
11341 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11342 }
11343 assert(ICmpInst::isSigned(Pred) &&
11344 "Relational predicate is either signed or unsigned!");
11345 if (!LHS->hasNoSignedWrap())
11346 return std::nullopt;
11347
11348 const SCEV *Step = LHS->getStepRecurrence(*this);
11349
11350 if (isKnownNonNegative(Step))
11351 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11352
11353 if (isKnownNonPositive(Step))
11354 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11355
11356 return std::nullopt;
11357}
11358
11359std::optional<ScalarEvolution::LoopInvariantPredicate>
11360ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11361 const SCEV *RHS, const Loop *L,
11362 const Instruction *CtxI) {
11363 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11364 if (!isLoopInvariant(RHS, L)) {
11365 if (!isLoopInvariant(LHS, L))
11366 return std::nullopt;
11367
11368 std::swap(LHS, RHS);
11369 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11370 }
11371
11372 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11373 if (!ArLHS || ArLHS->getLoop() != L)
11374 return std::nullopt;
11375
11376 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11377 if (!MonotonicType)
11378 return std::nullopt;
11379 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11380 // true as the loop iterates, and the backedge is control dependent on
11381 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11382 //
11383 // * if the predicate was false in the first iteration then the predicate
11384 // is never evaluated again, since the loop exits without taking the
11385 // backedge.
11386 // * if the predicate was true in the first iteration then it will
11387 // continue to be true for all future iterations since it is
11388 // monotonically increasing.
11389 //
11390 // For both the above possibilities, we can replace the loop varying
11391 // predicate with its value on the first iteration of the loop (which is
11392 // loop invariant).
11393 //
11394 // A similar reasoning applies for a monotonically decreasing predicate, by
11395 // replacing true with false and false with true in the above two bullets.
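 // For example, with ArLHS = {0,+,1}<nuw><%L> and a loop-invariant %n, if the
 // backedge is guarded by ({0,+,1} u< %n), then the loop-variant test can be
 // replaced by its first-iteration value (0 u< %n).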
11396 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11397 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11398
11399 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11400 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11401 RHS);
11402
11403 if (!CtxI)
11404 return std::nullopt;
11405 // Try to prove via context.
11406 // TODO: Support other cases.
11407 switch (Pred) {
11408 default:
11409 break;
11410 case ICmpInst::ICMP_ULE:
11411 case ICmpInst::ICMP_ULT: {
11412 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11413 // Given preconditions
11414 // (1) ArLHS does not cross the border of positive and negative parts of
11415 // range because of:
11416 // - Positive step; (TODO: lift this limitation)
11417 // - nuw - does not cross zero boundary;
11418 // - nsw - does not cross SINT_MAX boundary;
11419 // (2) ArLHS <s RHS
11420 // (3) RHS >=s 0
11421 // we can replace the loop variant ArLHS <u RHS condition with loop
11422 // invariant Start(ArLHS) <u RHS.
11423 //
11424 // Because of (1) there are two options:
11425 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11426 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11427 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11428 // Because of (2) ArLHS <u RHS is trivially true.
11429 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11430 // We can strengthen this to Start(ArLHS) <u RHS.
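 // For example, for ArLHS = {0,+,1}<nuw><nsw> with a positive step, knowing
 // (ArLHS s< %n) at CtxI together with (%n s>= 0) lets us replace the
 // loop-variant (ArLHS u< %n) with the invariant (0 u< %n).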
11431 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11432 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11433 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11434 isKnownNonNegative(RHS) &&
11435 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11436 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11437 RHS);
11438 }
11439 }
11440
11441 return std::nullopt;
11442}
11443
11444std::optional<ScalarEvolution::LoopInvariantPredicate>
11445ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11446 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11447 const Instruction *CtxI, const SCEV *MaxIter) {
11448 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11449 Pred, LHS, RHS, L, CtxI, MaxIter))
11450 return LIP;
11451 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11452 // Number of iterations expressed as UMIN isn't always great for expressing
11453 // the value on the last iteration. If the straightforward approach didn't
11454 // work, try the following trick: if a predicate is invariant for X, it
11455 // is also invariant for umin(X, ...). So try to find something that works
11456 // among subexpressions of MaxIter expressed as umin.
11457 for (auto *Op : UMin->operands())
11458 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11459 Pred, LHS, RHS, L, CtxI, Op))
11460 return LIP;
11461 return std::nullopt;
11462}
11463
11464std::optional<ScalarEvolution::LoopInvariantPredicate>
11465ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11466 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11467 const Instruction *CtxI, const SCEV *MaxIter) {
11468 // Try to prove the following set of facts:
11469 // - The predicate is monotonic in the iteration space.
11470 // - If the check does not fail on the 1st iteration:
11471 // - No overflow will happen during first MaxIter iterations;
11472 // - It will not fail on the MaxIter'th iteration.
11473 // If the check does fail on the 1st iteration, we leave the loop and no
11474 // other checks matter.
11475
11476 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11477 if (!isLoopInvariant(RHS, L)) {
11478 if (!isLoopInvariant(LHS, L))
11479 return std::nullopt;
11480
11481 std::swap(LHS, RHS);
11482 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11483 }
11484
11485 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11486 if (!AR || AR->getLoop() != L)
11487 return std::nullopt;
11488
11489 // The predicate must be relational (i.e. <, <=, >=, >).
11490 if (!ICmpInst::isRelational(Pred))
11491 return std::nullopt;
11492
11493 // TODO: Support steps other than +/- 1.
11494 const SCEV *Step = AR->getStepRecurrence(*this);
11495 auto *One = getOne(Step->getType());
11496 auto *MinusOne = getNegativeSCEV(One);
11497 if (Step != One && Step != MinusOne)
11498 return std::nullopt;
11499
11500 // Type mismatch here means that MaxIter is potentially larger than max
11501 // unsigned value in the start type, which means we cannot prove no-wrap for the
11502 // indvar.
11503 if (AR->getType() != MaxIter->getType())
11504 return std::nullopt;
11505
11506 // Value of IV on suggested last iteration.
11507 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11508 // Does it still meet the requirement?
11509 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11510 return std::nullopt;
11511 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11512 // not exceed max unsigned value of this type), this effectively proves
11513 // that there is no wrap during the iteration. To prove that there is no
11514 // signed/unsigned wrap, we need to check that
11515 // Start <= Last for step = 1 or Start >= Last for step = -1.
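 // For example, for AR = {%start,+,1} and MaxIter = %n of the same type,
 // Last = %start + %n; proving (%start <= Last) with the signedness of Pred
 // shows the IV cannot wrap within the first %n iterations.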
11516 ICmpInst::Predicate NoOverflowPred =
11517 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11518 if (Step == MinusOne)
11519 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11520 const SCEV *Start = AR->getStart();
11521 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11522 return std::nullopt;
11523
11524 // Everything is fine.
11525 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11526}
11527
11528bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11529 const SCEV *LHS,
11530 const SCEV *RHS) {
11531 if (HasSameValue(LHS, RHS))
11532 return ICmpInst::isTrueWhenEqual(Pred);
11533
11534 auto CheckRange = [&](bool IsSigned) {
11535 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11536 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11537 return RangeLHS.icmp(Pred, RangeRHS);
11538 };
11539
11540 // The check at the top of the function catches the case where the values are
11541 // known to be equal.
11542 if (Pred == CmpInst::ICMP_EQ)
11543 return false;
11544
11545 if (Pred == CmpInst::ICMP_NE) {
11546 if (CheckRange(true) || CheckRange(false))
11547 return true;
11548 auto *Diff = getMinusSCEV(LHS, RHS);
11549 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11550 }
11551
11552 return CheckRange(CmpInst::isSigned(Pred));
11553}
11554
11555bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11556 const SCEV *LHS,
11557 const SCEV *RHS) {
11558 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11559 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11560 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11561 // OutC1 and OutC2.
11562 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11563 APInt &OutC1, APInt &OutC2,
11564 SCEV::NoWrapFlags ExpectedFlags) {
11565 const SCEV *XNonConstOp, *XConstOp;
11566 const SCEV *YNonConstOp, *YConstOp;
11567 SCEV::NoWrapFlags XFlagsPresent;
11568 SCEV::NoWrapFlags YFlagsPresent;
11569
11570 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11571 XConstOp = getZero(X->getType());
11572 XNonConstOp = X;
11573 XFlagsPresent = ExpectedFlags;
11574 }
11575 if (!isa<SCEVConstant>(XConstOp))
11576 return false;
11577
11578 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11579 YConstOp = getZero(Y->getType());
11580 YNonConstOp = Y;
11581 YFlagsPresent = ExpectedFlags;
11582 }
11583
11584 if (YNonConstOp != XNonConstOp)
11585 return false;
11586
11587 if (!isa<SCEVConstant>(YConstOp))
11588 return false;
11589
11590 // When matching ADDs with NUW flags (and unsigned predicates), only the
11591 // second ADD (with the larger constant) requires NUW.
11592 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11593 return false;
11594 if (ExpectedFlags != SCEV::FlagNUW &&
11595 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11596 return false;
11597 }
11598
11599 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11600 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11601
11602 return true;
11603 };
11604
11605 APInt C1;
11606 APInt C2;
11607
11608 switch (Pred) {
11609 default:
11610 break;
11611
11612 case ICmpInst::ICMP_SGE:
11613 std::swap(LHS, RHS);
11614 [[fallthrough]];
11615 case ICmpInst::ICMP_SLE:
11616 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11617 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11618 return true;
11619
11620 break;
11621
11622 case ICmpInst::ICMP_SGT:
11623 std::swap(LHS, RHS);
11624 [[fallthrough]];
11625 case ICmpInst::ICMP_SLT:
11626 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11627 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11628 return true;
11629
11630 break;
11631
11632 case ICmpInst::ICMP_UGE:
11633 std::swap(LHS, RHS);
11634 [[fallthrough]];
11635 case ICmpInst::ICMP_ULE:
11636 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11637 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11638 return true;
11639
11640 break;
11641
11642 case ICmpInst::ICMP_UGT:
11643 std::swap(LHS, RHS);
11644 [[fallthrough]];
11645 case ICmpInst::ICMP_ULT:
11646 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11647 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11648 return true;
11649 break;
11650 }
11651
11652 return false;
11653}
11654
11655bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11656 const SCEV *LHS,
11657 const SCEV *RHS) {
11658 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11659 return false;
11660
11661 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting
11662 // on the stack can result in exponential time complexity.
11663 SaveAndRestore Restore(ProvingSplitPredicate, true);
11664
11665 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11666 //
11667 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11668 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11669 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11670 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11671 // use isKnownPredicate later if needed.
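 // For example, to show (%i u< %len) it suffices to prove (%len s>= 0),
 // (%i s>= 0) and (%i s< %len); each signed fact is often directly available
 // from loop guards or ranges even when the unsigned fact is not.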
11672 return isKnownNonNegative(RHS) &&
11673 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11674 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11675}
11676
11677bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11678 const SCEV *LHS, const SCEV *RHS) {
11679 // No need to even try if we know the module has no guards.
11680 if (!HasGuards)
11681 return false;
11682
11683 return any_of(*BB, [&](const Instruction &I) {
11684 using namespace llvm::PatternMatch;
11685
11686 Value *Condition;
11687 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11688 m_Value(Condition))) &&
11689 isImpliedCond(Pred, LHS, RHS, Condition, false);
11690 });
11691}
11692
11693/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11694/// protected by a conditional between LHS and RHS. This is used to
11695/// eliminate casts.
11696bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11697 CmpPredicate Pred,
11698 const SCEV *LHS,
11699 const SCEV *RHS) {
11700 // Interpret a null as meaning no loop, where there is obviously no guard
11701 // (interprocedural conditions notwithstanding). Do not bother about
11702 // unreachable loops.
11703 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11704 return true;
11705
11706 if (VerifyIR)
11707 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11708 "This cannot be done on broken IR!");
11709
11710
11711 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11712 return true;
11713
11714 BasicBlock *Latch = L->getLoopLatch();
11715 if (!Latch)
11716 return false;
11717
11718 BranchInst *LoopContinuePredicate =
11719 dyn_cast<BranchInst>(Latch->getTerminator());
11720 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11721 isImpliedCond(Pred, LHS, RHS,
11722 LoopContinuePredicate->getCondition(),
11723 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11724 return true;
11725
11726 // We don't want more than one activation of the following loops on the stack
11727 // -- that can lead to O(n!) time complexity.
11728 if (WalkingBEDominatingConds)
11729 return false;
11730
11731 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11732
11733 // See if we can exploit a trip count to prove the predicate.
11734 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11735 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11736 if (LatchBECount != getCouldNotCompute()) {
11737 // We know that Latch branches back to the loop header exactly
11738 // LatchBECount times. This means the backedge condition at Latch is
11739 // equivalent to "{0,+,1} u< LatchBECount".
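 // For example, if LatchBECount is 7, then every time the backedge is taken
 // the canonical IV {0,+,1} holds one of the values 0..6, i.e. it is u< 7.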
11740 Type *Ty = LatchBECount->getType();
11741 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11742 const SCEV *LoopCounter =
11743 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11744 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11745 LatchBECount))
11746 return true;
11747 }
11748
11749 // Check conditions due to any @llvm.assume intrinsics.
11750 for (auto &AssumeVH : AC.assumptions()) {
11751 if (!AssumeVH)
11752 continue;
11753 auto *CI = cast<CallInst>(AssumeVH);
11754 if (!DT.dominates(CI, Latch->getTerminator()))
11755 continue;
11756
11757 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11758 return true;
11759 }
11760
11761 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11762 return true;
11763
11764 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11765 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11766 assert(DTN && "should reach the loop header before reaching the root!");
11767
11768 BasicBlock *BB = DTN->getBlock();
11769 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11770 return true;
11771
11772 BasicBlock *PBB = BB->getSinglePredecessor();
11773 if (!PBB)
11774 continue;
11775
11776 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11777 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11778 continue;
11779
11780 Value *Condition = ContinuePredicate->getCondition();
11781
11782 // If we have an edge `E` within the loop body that dominates the only
11783 // latch, the condition guarding `E` also guards the backedge. This
11784 // reasoning works only for loops with a single latch.
11785
11786 BasicBlockEdge DominatingEdge(PBB, BB);
11787 if (DominatingEdge.isSingleEdge()) {
11788 // We're constructively (and conservatively) enumerating edges within the
11789 // loop body that dominate the latch. The dominator tree had better agree
11790 // with us on this:
11791 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11792
11793 if (isImpliedCond(Pred, LHS, RHS, Condition,
11794 BB != ContinuePredicate->getSuccessor(0)))
11795 return true;
11796 }
11797 }
11798
11799 return false;
11800}
11801
11803 CmpPredicate Pred,
11804 const SCEV *LHS,
11805 const SCEV *RHS) {
11806 // Do not bother proving facts for unreachable code.
11807 if (!DT.isReachableFromEntry(BB))
11808 return true;
11809 if (VerifyIR)
11810 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11811 "This cannot be done on broken IR!");
11812
11813 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11814 // the facts (a >= b && a != b) separately. A typical situation is when the
11815 // non-strict comparison is known from ranges and non-equality is known from
11816 // dominating predicates. If we are proving strict comparison, we always try
11817 // to prove non-equality and non-strict comparison separately.
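 // For example, (%a u> %b) can be proven by combining (%a u>= %b) from
 // constant ranges with (%a != %b) from a dominating branch condition.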
11818 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11819 const bool ProvingStrictComparison =
11820 Pred != NonStrictPredicate.dropSameSign();
11821 bool ProvedNonStrictComparison = false;
11822 bool ProvedNonEquality = false;
11823
11824 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11825 if (!ProvedNonStrictComparison)
11826 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11827 if (!ProvedNonEquality)
11828 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11829 if (ProvedNonStrictComparison && ProvedNonEquality)
11830 return true;
11831 return false;
11832 };
11833
11834 if (ProvingStrictComparison) {
11835 auto ProofFn = [&](CmpPredicate P) {
11836 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11837 };
11838 if (SplitAndProve(ProofFn))
11839 return true;
11840 }
11841
11842 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11843 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11844 const Instruction *CtxI = &BB->front();
11845 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11846 return true;
11847 if (ProvingStrictComparison) {
11848 auto ProofFn = [&](CmpPredicate P) {
11849 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11850 };
11851 if (SplitAndProve(ProofFn))
11852 return true;
11853 }
11854 return false;
11855 };
11856
11857 // Starting at the block's predecessor, climb up the predecessor chain, as
11858 // long as we keep finding predecessors that have a unique successor leading
11859 // to the original block.
11860 const Loop *ContainingLoop = LI.getLoopFor(BB);
11861 const BasicBlock *PredBB;
11862 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11863 PredBB = ContainingLoop->getLoopPredecessor();
11864 else
11865 PredBB = BB->getSinglePredecessor();
11866 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11867 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11868 const BranchInst *BlockEntryPredicate =
11869 dyn_cast<BranchInst>(Pair.first->getTerminator());
11870 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11871 continue;
11872
11873 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11874 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11875 return true;
11876 }
11877
11878 // Check conditions due to any @llvm.assume intrinsics.
11879 for (auto &AssumeVH : AC.assumptions()) {
11880 if (!AssumeVH)
11881 continue;
11882 auto *CI = cast<CallInst>(AssumeVH);
11883 if (!DT.dominates(CI, BB))
11884 continue;
11885
11886 if (ProveViaCond(CI->getArgOperand(0), false))
11887 return true;
11888 }
11889
11890 // Check conditions due to any @llvm.experimental.guard intrinsics.
11891 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11892 F.getParent(), Intrinsic::experimental_guard);
11893 if (GuardDecl)
11894 for (const auto *GU : GuardDecl->users())
11895 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11896 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11897 if (ProveViaCond(Guard->getArgOperand(0), false))
11898 return true;
11899 return false;
11900}
11901
11902bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11903 const SCEV *LHS,
11904 const SCEV *RHS) {
11905 // Interpret a null as meaning no loop, where there is obviously no guard
11906 // (interprocedural conditions notwithstanding).
11907 if (!L)
11908 return false;
11909
11910 // Both LHS and RHS must be available at loop entry.
11911 assert(isAvailableAtLoopEntry(LHS, L) &&
11912 "LHS is not available at Loop Entry");
11913 assert(isAvailableAtLoopEntry(RHS, L) &&
11914 "RHS is not available at Loop Entry");
11915
11916 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11917 return true;
11918
11919 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11920}
11921
11922bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11923 const SCEV *RHS,
11924 const Value *FoundCondValue, bool Inverse,
11925 const Instruction *CtxI) {
11926 // A false condition implies anything. Do not bother analyzing it further.
11927 if (FoundCondValue ==
11928 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11929 return true;
11930
11931 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11932 return false;
11933
11934 llvm::scope_exit ClearOnExit(
11935 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
11936
11937 // Recursively handle And and Or conditions.
11938 const Value *Op0, *Op1;
11939 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11940 if (!Inverse)
11941 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11942 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11943 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11944 if (Inverse)
11945 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11946 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11947 }
11948
11949 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11950 if (!ICI) return false;
11951
11952 // We have found a conditional branch that dominates the loop or controls the
11953 // loop latch. Check to see if it is the comparison we are looking for.
11954 CmpPredicate FoundPred;
11955 if (Inverse)
11956 FoundPred = ICI->getInverseCmpPredicate();
11957 else
11958 FoundPred = ICI->getCmpPredicate();
11959
11960 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11961 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11962
11963 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11964}
11965
11966bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11967 const SCEV *RHS, CmpPredicate FoundPred,
11968 const SCEV *FoundLHS, const SCEV *FoundRHS,
11969 const Instruction *CtxI) {
11970 // Balance the types.
11971 if (getTypeSizeInBits(LHS->getType()) <
11972 getTypeSizeInBits(FoundLHS->getType())) {
11973 // For unsigned and equality predicates, try to prove that both found
11974 // operands fit into narrow unsigned range. If so, try to prove facts in
11975 // narrow types.
11976 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11977 !FoundRHS->getType()->isPointerTy()) {
11978 auto *NarrowType = LHS->getType();
11979 auto *WideType = FoundLHS->getType();
11980 auto BitWidth = getTypeSizeInBits(NarrowType);
11981 const SCEV *MaxValue = getZeroExtendExpr(
11982 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11983 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11984 MaxValue) &&
11985 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11986 MaxValue)) {
11987 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11988 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11989 // We cannot preserve samesign after truncation.
11990 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11991 TruncFoundLHS, TruncFoundRHS, CtxI))
11992 return true;
11993 }
11994 }
11995
11996 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11997 return false;
11998 if (CmpInst::isSigned(Pred)) {
11999 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12000 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12001 } else {
12002 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12003 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12004 }
12005 } else if (getTypeSizeInBits(LHS->getType()) >
12006 getTypeSizeInBits(FoundLHS->getType())) {
12007 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12008 return false;
12009 if (CmpInst::isSigned(FoundPred)) {
12010 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12011 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12012 } else {
12013 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12014 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12015 }
12016 }
12017 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12018 FoundRHS, CtxI);
12019}
12020
12021bool ScalarEvolution::isImpliedCondBalancedTypes(
12022 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12023 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
12024 assert(getTypeSizeInBits(LHS->getType()) ==
12025 getTypeSizeInBits(FoundLHS->getType()) &&
12026 "Types should be balanced!");
12027 // Canonicalize the query to match the way instcombine will have
12028 // canonicalized the comparison.
12029 if (SimplifyICmpOperands(Pred, LHS, RHS))
12030 if (LHS == RHS)
12031 return CmpInst::isTrueWhenEqual(Pred);
12032 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12033 if (FoundLHS == FoundRHS)
12034 return CmpInst::isFalseWhenEqual(FoundPred);
12035
12036 // Check to see if we can make the LHS or RHS match.
12037 if (LHS == FoundRHS || RHS == FoundLHS) {
12038 if (isa<SCEVConstant>(RHS)) {
12039 std::swap(FoundLHS, FoundRHS);
12040 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12041 } else {
12042 std::swap(LHS, RHS);
12043 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12044 }
12045 }
12046
12047 // Check whether the found predicate is the same as the desired predicate.
12048 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12049 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12050
12051 // Check whether swapping the found predicate makes it the same as the
12052 // desired predicate.
12053 if (auto P = CmpPredicate::getMatching(
12054 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12055 // We can write the implication
12056 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12057 // using one of the following ways:
12058 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12059 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12060 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12061 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12062 // Forms 1. and 2. require swapping the operands of one condition. Don't
12063 // do this if it would break canonical constant/addrec ordering.
12064 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12065 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12066 LHS, FoundLHS, FoundRHS, CtxI);
12067 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12068 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12069
12070 // There's no clear preference between forms 3. and 4., try both. Avoid
12071 // forming getNotSCEV of pointer values as the resulting subtract is
12072 // not legal.
12073 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12074 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12075 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12076 FoundRHS, CtxI))
12077 return true;
12078
12079 if (!FoundLHS->getType()->isPointerTy() &&
12080 !FoundRHS->getType()->isPointerTy() &&
12081 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12082 getNotSCEV(FoundRHS), CtxI))
12083 return true;
12084
12085 return false;
12086 }
12087
12088 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12089 CmpInst::Predicate P2) {
12090 assert(P1 != P2 && "Handled earlier!");
12091 return CmpInst::isRelational(P2) &&
12092 P1 == ICmpInst::getFlippedSignednessPredicate(P2);
12093 };
12094 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12095 // Unsigned comparison is the same as signed comparison when both the
12096 // operands are non-negative or negative.
12097 if (haveSameSign(FoundLHS, FoundRHS))
12098 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12099 // Create local copies that we can freely swap and canonicalize our
12100 // conditions to "le/lt".
12101 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12102 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12103 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12104 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12105 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12106 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12107 std::swap(CanonicalLHS, CanonicalRHS);
12108 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12109 }
12110 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12111 "Must be!");
12112 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12113 ICmpInst::isLE(CanonicalFoundPred)) &&
12114 "Must be!");
12115 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12116 // Use implication:
12117 // x <u y && y >=s 0 --> x <s y.
12118 // If we can prove the left part, the right part is also proven.
12119 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12120 CanonicalRHS, CanonicalFoundLHS,
12121 CanonicalFoundRHS);
12122 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12123 // Use implication:
12124 // x <s y && y <s 0 --> x <u y.
12125 // If we can prove the left part, the right part is also proven.
12126 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12127 CanonicalRHS, CanonicalFoundLHS,
12128 CanonicalFoundRHS);
12129 }
12130
12131 // Check if we can make progress by sharpening ranges.
12132 if (FoundPred == ICmpInst::ICMP_NE &&
12133 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12134
12135 const SCEVConstant *C = nullptr;
12136 const SCEV *V = nullptr;
12137
12138 if (isa<SCEVConstant>(FoundLHS)) {
12139 C = cast<SCEVConstant>(FoundLHS);
12140 V = FoundRHS;
12141 } else {
12142 C = cast<SCEVConstant>(FoundRHS);
12143 V = FoundLHS;
12144 }
12145
12146 // The guarding predicate tells us that C != V. If the known range
12147 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12148 // range we consider has to correspond to same signedness as the
12149 // predicate we're interested in folding.
12150
12151 APInt Min = ICmpInst::isSigned(Pred) ?
12152 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12153
12154 if (Min == C->getAPInt()) {
12155 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12156 // This is true even if (Min + 1) wraps around -- in case of
12157 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
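 // For example, with an unsigned predicate, if the unsigned minimum of V is 0
 // and the guard tells us V != 0, then V u>= 1 may be assumed below.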
12158
12159 APInt SharperMin = Min + 1;
12160
12161 switch (Pred) {
12162 case ICmpInst::ICMP_SGE:
12163 case ICmpInst::ICMP_UGE:
12164 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12165 // RHS, we're done.
12166 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12167 CtxI))
12168 return true;
12169 [[fallthrough]];
12170
12171 case ICmpInst::ICMP_SGT:
12172 case ICmpInst::ICMP_UGT:
12173 // We know from the range information that (V `Pred` Min ||
12174 // V == Min). We know from the guarding condition that !(V
12175 // == Min). This gives us
12176 //
12177 // V `Pred` Min || V == Min && !(V == Min)
12178 // => V `Pred` Min
12179 //
12180 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12181
12182 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12183 return true;
12184 break;
12185
12186 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12187 case ICmpInst::ICMP_SLE:
12188 case ICmpInst::ICMP_ULE:
12189 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12190 LHS, V, getConstant(SharperMin), CtxI))
12191 return true;
12192 [[fallthrough]];
12193
12194 case ICmpInst::ICMP_SLT:
12195 case ICmpInst::ICMP_ULT:
12196 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12197 LHS, V, getConstant(Min), CtxI))
12198 return true;
12199 break;
12200
12201 default:
12202 // No change
12203 break;
12204 }
12205 }
12206 }
12207
12208 // Check whether the actual condition is beyond sufficient.
12209 if (FoundPred == ICmpInst::ICMP_EQ)
12210 if (ICmpInst::isTrueWhenEqual(Pred))
12211 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12212 return true;
12213 if (Pred == ICmpInst::ICMP_NE)
12214 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12215 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12216 return true;
12217
12218 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12219 return true;
12220
12221 // Otherwise assume the worst.
12222 return false;
12223}
12224
12225bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12226 const SCEV *&L, const SCEV *&R,
12227 SCEV::NoWrapFlags &Flags) {
12228 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12229 return false;
12230
12231 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12232 return true;
12233}
12234
12235std::optional<APInt>
12236ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12237 // We avoid subtracting expressions here because this function is usually
12238 // fairly deep in the call stack (i.e. is called many times).
12239
12240 unsigned BW = getTypeSizeInBits(More->getType());
12241 APInt Diff(BW, 0);
12242 APInt DiffMul(BW, 1);
12243 // Try various simplifications to reduce the difference to a constant. Limit
12244 // the number of allowed simplifications to keep compile-time low.
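 // For example, (%x + 5) and (%x + 3) reduce to the constant difference 2, and
 // {%p + 4,+,1} vs {%p,+,1} (equal steps) reduce to their starts, again giving
 // a constant difference of 4.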
12245 for (unsigned I = 0; I < 8; ++I) {
12246 if (More == Less)
12247 return Diff;
12248
12249 // Reduce addrecs with identical steps to their start value.
12250 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12251 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12252 const auto *MAR = cast<SCEVAddRecExpr>(More);
12253
12254 if (LAR->getLoop() != MAR->getLoop())
12255 return std::nullopt;
12256
12257 // We look at affine expressions only; not for correctness but to keep
12258 // getStepRecurrence cheap.
12259 if (!LAR->isAffine() || !MAR->isAffine())
12260 return std::nullopt;
12261
12262 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12263 return std::nullopt;
12264
12265 Less = LAR->getStart();
12266 More = MAR->getStart();
12267 continue;
12268 }
12269
12270 // Try to match a common constant multiply.
12271 auto MatchConstMul =
12272 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12273 const APInt *C;
12274 const SCEV *Op;
12275 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12276 return {{Op, *C}};
12277 return std::nullopt;
12278 };
12279 if (auto MatchedMore = MatchConstMul(More)) {
12280 if (auto MatchedLess = MatchConstMul(Less)) {
12281 if (MatchedMore->second == MatchedLess->second) {
12282 More = MatchedMore->first;
12283 Less = MatchedLess->first;
12284 DiffMul *= MatchedMore->second;
12285 continue;
12286 }
12287 }
12288 }
12289
12290 // Try to cancel out common factors in two add expressions.
12291 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12292 auto Add = [&](const SCEV *S, int Mul) {
12293 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12294 if (Mul == 1) {
12295 Diff += C->getAPInt() * DiffMul;
12296 } else {
12297 assert(Mul == -1);
12298 Diff -= C->getAPInt() * DiffMul;
12299 }
12300 } else
12301 Multiplicity[S] += Mul;
12302 };
12303 auto Decompose = [&](const SCEV *S, int Mul) {
12304 if (isa<SCEVAddExpr>(S)) {
12305 for (const SCEV *Op : S->operands())
12306 Add(Op, Mul);
12307 } else
12308 Add(S, Mul);
12309 };
12310 Decompose(More, 1);
12311 Decompose(Less, -1);
12312
12313 // Check whether all the non-constants cancel out, or reduce to new
12314 // More/Less values.
12315 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12316 for (const auto &[S, Mul] : Multiplicity) {
12317 if (Mul == 0)
12318 continue;
12319 if (Mul == 1) {
12320 if (NewMore)
12321 return std::nullopt;
12322 NewMore = S;
12323 } else if (Mul == -1) {
12324 if (NewLess)
12325 return std::nullopt;
12326 NewLess = S;
12327 } else
12328 return std::nullopt;
12329 }
12330
12331 // Values stayed the same, no point in trying further.
12332 if (NewMore == More || NewLess == Less)
12333 return std::nullopt;
12334
12335 More = NewMore;
12336 Less = NewLess;
12337
12338 // Reduced to constant.
12339 if (!More && !Less)
12340 return Diff;
12341
12342 // Left with variable on only one side, bail out.
12343 if (!More || !Less)
12344 return std::nullopt;
12345 }
12346
12347 // Did not reduce to constant.
12348 return std::nullopt;
12349}
12350
12351bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12352 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12353 const SCEV *FoundRHS, const Instruction *CtxI) {
12354 // Try to recognize the following pattern:
12355 //
12356 // FoundRHS = ...
12357 // ...
12358 // loop:
12359 // FoundLHS = {Start,+,W}
12360 // context_bb: // Basic block from the same loop
12361 // known(Pred, FoundLHS, FoundRHS)
12362 //
12363 // If some predicate is known in the context of a loop, it is also known on
12364 // each iteration of this loop, including the first iteration. Therefore, in
12365 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12366 // prove the original pred using this fact.
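 // For example, if ({%start,+,1} u< %bound) is known inside the loop body on
 // the first iteration, then (%start u< %bound) already holds before the loop
 // and can be used to prove the queried predicate.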
12367 if (!CtxI)
12368 return false;
12369 const BasicBlock *ContextBB = CtxI->getParent();
12370 // Make sure AR varies in the context block.
12371 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12372 const Loop *L = AR->getLoop();
12373 // Make sure that context belongs to the loop and executes on 1st iteration
12374 // (if it ever executes at all).
12375 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12376 return false;
12377 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12378 return false;
12379 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12380 }
12381
12382 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12383 const Loop *L = AR->getLoop();
12384 // Make sure that context belongs to the loop and executes on 1st iteration
12385 // (if it ever executes at all).
12386 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12387 return false;
12388 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12389 return false;
12390 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12391 }
12392
12393 return false;
12394}
12395
12396bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12397 const SCEV *LHS,
12398 const SCEV *RHS,
12399 const SCEV *FoundLHS,
12400 const SCEV *FoundRHS) {
12401 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12402 return false;
12403
12404 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12405 if (!AddRecLHS)
12406 return false;
12407
12408 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12409 if (!AddRecFoundLHS)
12410 return false;
12411
12412 // We'd like to let SCEV reason about control dependencies, so we constrain
12413 // both the inequalities to be about add recurrences on the same loop. This
12414 // way we can use isLoopEntryGuardedByCond later.
12415
12416 const Loop *L = AddRecFoundLHS->getLoop();
12417 if (L != AddRecLHS->getLoop())
12418 return false;
12419
12420 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12421 //
12422 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12423 // ... (2)
12424 //
12425 // Informal proof for (2), assuming (1) [*]:
12426 //
12427 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12428 //
12429 // Then
12430 //
12431 // FoundLHS s< FoundRHS s< INT_MIN - C
12432 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12433 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12434 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12435 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12436 // <=> FoundLHS + C s< FoundRHS + C
12437 //
12438 // [*]: (1) can be proved by ruling out overflow.
12439 //
12440 // [**]: This can be proved by analyzing all the four possibilities:
12441 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12442 // (A s>= 0, B s>= 0).
12443 //
12444 // Note:
12445 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12446 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12447 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12448 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12449 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12450 // C)".
12451
12452 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12453 if (!LDiff)
12454 return false;
12455 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12456 if (!RDiff || *LDiff != *RDiff)
12457 return false;
12458
12459 if (LDiff->isMinValue())
12460 return true;
12461
12462 APInt FoundRHSLimit;
12463
12464 if (Pred == CmpInst::ICMP_ULT) {
12465 FoundRHSLimit = -(*RDiff);
12466 } else {
12467 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12468 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12469 }
12470
12471 // Try to prove (1) or (2), as needed.
12472 return isAvailableAtLoopEntry(FoundRHS, L) &&
12473 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12474 getConstant(FoundRHSLimit));
12475}
12476
12477bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12478 const SCEV *RHS, const SCEV *FoundLHS,
12479 const SCEV *FoundRHS, unsigned Depth) {
12480 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12481
12482 llvm::scope_exit ClearOnExit([&]() {
12483 if (LPhi) {
12484 bool Erased = PendingMerges.erase(LPhi);
12485 assert(Erased && "Failed to erase LPhi!");
12486 (void)Erased;
12487 }
12488 if (RPhi) {
12489 bool Erased = PendingMerges.erase(RPhi);
12490 assert(Erased && "Failed to erase RPhi!");
12491 (void)Erased;
12492 }
12493 });
12494
12495 // Find the respective Phis and check that they are not currently pending.
12496 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12497 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12498 if (!PendingMerges.insert(Phi).second)
12499 return false;
12500 LPhi = Phi;
12501 }
12502 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12503 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12504 // If we detect a loop of Phi nodes being processed by this method, for
12505 // example:
12506 //
12507 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12508 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12509 //
12510 // we don't want to deal with a case that complex, so return conservative
12511 // answer false.
12512 if (!PendingMerges.insert(Phi).second)
12513 return false;
12514 RPhi = Phi;
12515 }
12516
12517 // If none of LHS, RHS is a Phi, nothing to do here.
12518 if (!LPhi && !RPhi)
12519 return false;
12520
12521 // If there is a SCEVUnknown Phi we are interested in, make it left.
12522 if (!LPhi) {
12523 std::swap(LHS, RHS);
12524 std::swap(FoundLHS, FoundRHS);
12525 std::swap(LPhi, RPhi);
12526 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12527 }
12528
12529 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12530 const BasicBlock *LBB = LPhi->getParent();
12531 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12532
12533 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12534 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12535 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12536 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12537 };
12538
12539 if (RPhi && RPhi->getParent() == LBB) {
12540 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12541 // If we compare two Phis from the same block, and for each entry block
12542 // the predicate is true for incoming values from this block, then the
12543 // predicate is also true for the Phis.
12544 for (const BasicBlock *IncBB : predecessors(LBB)) {
12545 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12546 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12547 if (!ProvedEasily(L, R))
12548 return false;
12549 }
12550 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12551 // Case two: RHS is an AddRec whose loop's header is LBB, the block of LPhi.
12552 // It means that the loop has both an AddRec and an Unknown PHI, so we can
12553 // compare the incoming values of the AddRec from above the loop and from the
12554 // latch with the respective incoming values of LPhi.
12555 // TODO: Generalize to handle loops with many inputs in a header.
12556 if (LPhi->getNumIncomingValues() != 2) return false;
12557
12558 auto *RLoop = RAR->getLoop();
12559 auto *Predecessor = RLoop->getLoopPredecessor();
12560 assert(Predecessor && "Loop with AddRec with no predecessor?");
12561 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12562 if (!ProvedEasily(L1, RAR->getStart()))
12563 return false;
12564 auto *Latch = RLoop->getLoopLatch();
12565 assert(Latch && "Loop with AddRec with no latch?");
12566 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12567 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12568 return false;
12569 } else {
12570 // In all other cases go over inputs of LHS and compare each of them to RHS,
12571 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12572 // At this point RHS is either a non-Phi, or it is a Phi from some block
12573 // different from LBB.
12574 for (const BasicBlock *IncBB : predecessors(LBB)) {
12575 // Check that RHS is available in this block.
12576 if (!dominates(RHS, IncBB))
12577 return false;
12578 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12579 // Make sure L does not refer to a value from a potentially previous
12580 // iteration of a loop.
12581 if (!properlyDominates(L, LBB))
12582 return false;
12583 // Addrecs are considered to properly dominate their loop, so are missed
12584 // by the previous check. Discard any values that have computable
12585 // evolution in this loop.
12586 if (auto *Loop = LI.getLoopFor(LBB))
12587 if (hasComputableLoopEvolution(L, Loop))
12588 return false;
12589 if (!ProvedEasily(L, RHS))
12590 return false;
12591 }
12592 }
12593 return true;
12594}
12595
12596bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12597 const SCEV *LHS,
12598 const SCEV *RHS,
12599 const SCEV *FoundLHS,
12600 const SCEV *FoundRHS) {
12601 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12602 // sure that we are dealing with the same LHS.
12603 if (RHS == FoundRHS) {
12604 std::swap(LHS, RHS);
12605 std::swap(FoundLHS, FoundRHS);
12606 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12607 }
12608 if (LHS != FoundLHS)
12609 return false;
12610
12611 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12612 if (!SUFoundRHS)
12613 return false;
12614
12615 Value *Shiftee, *ShiftValue;
12616
12617 using namespace PatternMatch;
12618 if (match(SUFoundRHS->getValue(),
12619 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12620 auto *ShifteeS = getSCEV(Shiftee);
12621 // Prove one of the following:
12622 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12623 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12624 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12625 // ---> LHS <s RHS
12626 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12627 // ---> LHS <=s RHS
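// All four cases rely on (shiftee >> shiftvalue) <=u shiftee; for the signed
// variants, shiftee >=s 0 additionally gives (shiftee >> shiftvalue) <=s shiftee.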
12628 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12629 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12630 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12631 if (isKnownNonNegative(ShifteeS))
12632 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12633 }
12634
12635 return false;
12636}
12637
12638bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12639 const SCEV *RHS,
12640 const SCEV *FoundLHS,
12641 const SCEV *FoundRHS,
12642 const Instruction *CtxI) {
12643 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12644 FoundRHS) ||
12645 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12646 FoundRHS) ||
12647 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12648 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12649 CtxI) ||
12650 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12651}
12652
12653/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12654template <typename MinMaxExprType>
12655static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12656 const SCEV *Candidate) {
12657 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12658 if (!MinMaxExpr)
12659 return false;
12660
12661 return is_contained(MinMaxExpr->operands(), Candidate);
12662}
12663
12664 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12665 CmpPredicate Pred, const SCEV *LHS,
12666 const SCEV *RHS) {
12667 // If both sides are affine addrecs for the same loop, with equal
12668 // steps, and we know the recurrences don't wrap, then we only
12669 // need to check the predicate on the starting values.
12670
12671 if (!ICmpInst::isRelational(Pred))
12672 return false;
12673
12674 const SCEV *LStart, *RStart, *Step;
12675 const Loop *L;
12676 if (!match(LHS,
12677 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12678 !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12679 m_SpecificLoop(L))))
12680 return false;
12681
12682 auto *LAR = cast<SCEVAddRecExpr>(LHS);
12683 auto *RAR = cast<SCEVAddRecExpr>(RHS);
12684 SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ? SCEV::FlagNSW : SCEV::FlagNUW;
12685 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12686 return false;
12687
12688 return SE.isKnownPredicate(Pred, LStart, RStart);
12689}
12690
12691/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12692 /// expression?
12693 static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12694 const SCEV *LHS, const SCEV *RHS) {
12695 switch (Pred) {
12696 default:
12697 return false;
12698
12699 case ICmpInst::ICMP_SGE:
12700 std::swap(LHS, RHS);
12701 [[fallthrough]];
12702 case ICmpInst::ICMP_SLE:
12703 return
12704 // min(A, ...) <= A
12705 IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12706 // A <= max(A, ...)
12707 IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12708
12709 case ICmpInst::ICMP_UGE:
12710 std::swap(LHS, RHS);
12711 [[fallthrough]];
12712 case ICmpInst::ICMP_ULE:
12713 return
12714 // min(A, ...) <= A
12715 // FIXME: what about umin_seq?
12716 IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12717 // A <= max(A, ...)
12718 IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12719 }
12720
12721 llvm_unreachable("covered switch fell through?!");
12722}
12723
12724bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12725 const SCEV *RHS,
12726 const SCEV *FoundLHS,
12727 const SCEV *FoundRHS,
12728 unsigned Depth) {
12729 assert(getTypeSizeInBits(LHS->getType()) ==
12730 getTypeSizeInBits(RHS->getType()) &&
12731 "LHS and RHS have different sizes?");
12732 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12733 getTypeSizeInBits(FoundRHS->getType()) &&
12734 "FoundLHS and FoundRHS have different sizes?");
12735 // We want to avoid hurting the compile time with analysis of too big trees.
12736 if (Depth > MaxSCEVOperationsImplicationDepth)
12737 return false;
12738
12739 // We only want to work with GT comparison so far.
12740 if (ICmpInst::isLT(Pred)) {
12741 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12742 std::swap(LHS, RHS);
12743 std::swap(FoundLHS, FoundRHS);
12744 }
12745
12746 CmpInst::Predicate P = Pred.getPreferredSignedPredicate();
12747
12748 // For unsigned, try to reduce it to corresponding signed comparison.
12749 if (P == ICmpInst::ICMP_UGT)
12750 // We can replace unsigned predicate with its signed counterpart if all
12751 // involved values are non-negative.
12752 // TODO: We could have better support for unsigned.
12753 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12754 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12755 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12756 // use this fact to prove that LHS and RHS are non-negative.
12757 const SCEV *MinusOne = getMinusOne(LHS->getType());
12758 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12759 FoundRHS) &&
12760 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12761 FoundRHS))
12762 P = ICmpInst::ICMP_SGT;
12763 }
12764
12765 if (P != ICmpInst::ICMP_SGT)
12766 return false;
12767
12768 auto GetOpFromSExt = [&](const SCEV *S) {
12769 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12770 return Ext->getOperand();
12771 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12772 // the constant in some cases.
12773 return S;
12774 };
12775
12776 // Acquire values from extensions.
12777 auto *OrigLHS = LHS;
12778 auto *OrigFoundLHS = FoundLHS;
12779 LHS = GetOpFromSExt(LHS);
12780 FoundLHS = GetOpFromSExt(FoundLHS);
12781
12782 // Can the SGT predicate be proved trivially or by using the found context?
12783 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12784 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12785 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12786 FoundRHS, Depth + 1);
12787 };
12788
12789 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12790 // We want to avoid creation of any new non-constant SCEV. Since we are
12791 // going to compare the operands to RHS, we should be certain that we don't
12792 // need any size extensions for this. So let's decline all cases when the
12793 // sizes of types of LHS and RHS do not match.
12794 // TODO: Maybe try to get RHS from sext to catch more cases?
12795 if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12796 return false;
12797
12798 // Should not overflow.
12799 if (!LHSAddExpr->hasNoSignedWrap())
12800 return false;
12801
12802 auto *LL = LHSAddExpr->getOperand(0);
12803 auto *LR = LHSAddExpr->getOperand(1);
12804 auto *MinusOne = getMinusOne(RHS->getType());
12805
12806 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12807 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12808 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12809 };
12810 // Try to prove the following rule:
12811 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12812 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
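// Both follow because LL + LR is at least LR (resp. LL) when the other addend
// is non-negative and the addition does not sign-wrap (checked above).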
12813 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12814 return true;
12815 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12816 Value *LL, *LR;
12817 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12818
12819 using namespace llvm::PatternMatch;
12820
12821 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12822 // Rules for division.
12823 // We are going to perform some comparisons with Denominator and its
12824 // derivative expressions. In general case, creating a SCEV for it may
12825 // lead to a complex analysis of the entire graph, and in particular it
12826 // can request trip count recalculation for the same loop. This would
12827 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12828 // this, we only want to create SCEVs that are constants in this section.
12829 // So we bail if Denominator is not a constant.
12830 if (!isa<ConstantInt>(LR))
12831 return false;
12832
12833 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12834
12835 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12836 // then a SCEV for the numerator already exists and matches with FoundLHS.
12837 auto *Numerator = getExistingSCEV(LL);
12838 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12839 return false;
12840
12841 // Make sure that the numerator matches with FoundLHS and the denominator
12842 // is positive.
12843 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12844 return false;
12845
12846 auto *DTy = Denominator->getType();
12847 auto *FRHSTy = FoundRHS->getType();
12848 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12849 // One of types is a pointer and another one is not. We cannot extend
12850 // them properly to a wider type, so let us just reject this case.
12851 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12852 // to avoid this check.
12853 return false;
12854
12855 // Given that:
12856 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12857 auto *WTy = getWiderType(DTy, FRHSTy);
12858 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12859 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12860
12861 // Try to prove the following rule:
12862 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12863 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12864 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12865 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12866 if (isKnownNonPositive(RHS) &&
12867 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12868 return true;
12869
12870 // Try to prove the following rule:
12871 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12872 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12873 // If we divide it by Denominator > 2, then:
12874 // 1. If FoundLHS is negative, then the result is 0.
12875 // 2. If FoundLHS is non-negative, then the result is non-negative.
12876 // Anyways, the result is non-negative.
12877 auto *MinusOne = getMinusOne(WTy);
12878 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12879 if (isKnownNegative(RHS) &&
12880 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12881 return true;
12882 }
12883 }
12884
12885 // If our expression contained SCEVUnknown Phis, and we split it down and now
12886 // need to prove something for them, try to prove the predicate for every
12887 // possible incoming value of those Phis.
12888 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12889 return true;
12890
12891 return false;
12892}
12893
12894 static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12895 const SCEV *RHS) {
12896 // zext x u<= sext x, sext x s<= zext x
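// For example, for i8 x = -1 widened to i16: zext x = 255 while sext x = -1
// (0xFFFF, i.e. 65535 when viewed unsigned), so zext x u<= sext x and
// sext x s<= zext x.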
12897 const SCEV *Op;
12898 switch (Pred) {
12899 case ICmpInst::ICMP_SGE:
12900 std::swap(LHS, RHS);
12901 [[fallthrough]];
12902 case ICmpInst::ICMP_SLE: {
12903 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12904 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12905 match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12906 }
12907 case ICmpInst::ICMP_UGE:
12908 std::swap(LHS, RHS);
12909 [[fallthrough]];
12910 case ICmpInst::ICMP_ULE: {
12911 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12912 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12913 match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12914 }
12915 default:
12916 return false;
12917 };
12918 llvm_unreachable("unhandled case");
12919}
12920
12921bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12922 const SCEV *LHS,
12923 const SCEV *RHS) {
12924 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12925 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12926 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12927 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12928 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12929}
12930
12931bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12932 const SCEV *LHS,
12933 const SCEV *RHS,
12934 const SCEV *FoundLHS,
12935 const SCEV *FoundRHS) {
12936 switch (Pred) {
12937 default:
12938 llvm_unreachable("Unexpected CmpPredicate value!");
12939 case ICmpInst::ICMP_EQ:
12940 case ICmpInst::ICMP_NE:
12941 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12942 return true;
12943 break;
12944 case ICmpInst::ICMP_SLT:
12945 case ICmpInst::ICMP_SLE:
12946 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12947 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12948 return true;
12949 break;
12950 case ICmpInst::ICMP_SGT:
12951 case ICmpInst::ICMP_SGE:
12952 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12953 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12954 return true;
12955 break;
12956 case ICmpInst::ICMP_ULT:
12957 case ICmpInst::ICMP_ULE:
12958 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12959 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12960 return true;
12961 break;
12962 case ICmpInst::ICMP_UGT:
12963 case ICmpInst::ICMP_UGE:
12964 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12965 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12966 return true;
12967 break;
12968 }
12969
12970 // Maybe it can be proved via operations?
12971 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12972 return true;
12973
12974 return false;
12975}
12976
12977bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12978 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12979 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12980 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12981 // The restriction on `FoundRHS` could be lifted easily -- it exists only to
12982 // reduce the compile time impact of this optimization.
12983 return false;
12984
12985 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12986 if (!Addend)
12987 return false;
12988
12989 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12990
12991 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12992 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12993 ConstantRange FoundLHSRange =
12994 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12995
12996 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12997 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12998
12999 // We can also compute the range of values for `LHS` that satisfy the
13000 // consequent, "`LHS` `Pred` `RHS`":
13001 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13002 // The antecedent implies the consequent if every value of `LHS` that
13003 // satisfies the antecedent also satisfies the consequent.
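// For example, from "FoundLHS <u 8" and LHS = FoundLHS + 2, LHSRange is
// [2, 10), which is enough to prove "LHS <u 12" but not "LHS <u 9".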
13004 return LHSRange.icmp(Pred, ConstRHS);
13005}
13006
13007bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13008 bool IsSigned) {
13009 assert(isKnownPositive(Stride) && "Positive stride expected!");
13010
13011 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13012 const SCEV *One = getOne(Stride->getType());
13013
13014 if (IsSigned) {
13015 APInt MaxRHS = getSignedRangeMax(RHS);
13016 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13017 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13018
13019 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
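// For example, for i8: MaxRHS = 120 and MaxStrideMinusOne = 15 give
// 127 - 15 = 112 <s 120, so an IV still below 120 could step past 127 and wrap.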
13020 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13021 }
13022
13023 APInt MaxRHS = getUnsignedRangeMax(RHS);
13024 APInt MaxValue = APInt::getMaxValue(BitWidth);
13025 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13026
13027 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13028 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13029}
13030
13031bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13032 bool IsSigned) {
13033
13034 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13035 const SCEV *One = getOne(Stride->getType());
13036
13037 if (IsSigned) {
13038 APInt MinRHS = getSignedRangeMin(RHS);
13039 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13040 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13041
13042 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
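// For example, for i8: MinRHS = -120 and MaxStrideMinusOne = 15 give
// -128 + 15 = -113 >s -120, so an IV still above -120 could step below -128 and wrap.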
13043 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13044 }
13045
13046 APInt MinRHS = getUnsignedRangeMin(RHS);
13047 APInt MinValue = APInt::getMinValue(BitWidth);
13048 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13049
13050 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13051 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13052}
13053
13054 const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
13055 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13056 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13057 // expression fixes the case of N=0.
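// For example, N = 7, D = 3: umin(7, 1) + (7 - 1) /u 3 = 1 + 2 = 3 = ceil(7/3);
// for N = 0 both terms are zero, giving the expected result of 0.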
13058 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13059 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13060 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13061}
13062
13063const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13064 const SCEV *Stride,
13065 const SCEV *End,
13066 unsigned BitWidth,
13067 bool IsSigned) {
13068 // The logic in this function assumes we can represent a positive stride.
13069 // If we can't, the backedge-taken count must be zero.
13070 if (IsSigned && BitWidth == 1)
13071 return getZero(Stride->getType());
13072
13073 // The code below has only been closely audited for negative strides in the
13074 // unsigned comparison case; it may be correct for signed comparisons, but
13075 // that needs to be established.
13076 if (IsSigned && isKnownNegative(Stride))
13077 return getCouldNotCompute();
13078
13079 // Calculate the maximum backedge count based on the range of values
13080 // permitted by Start, End, and Stride.
13081 APInt MinStart =
13082 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13083
13084 APInt MinStride =
13085 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13086
13087 // We assume either the stride is positive, or the backedge-taken count
13088 // is zero. So force StrideForMaxBECount to be at least one.
13089 APInt One(BitWidth, 1);
13090 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13091 : APIntOps::umax(One, MinStride);
13092
13093 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13094 : APInt::getMaxValue(BitWidth);
13095 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13096
13097 // Although End can be a MAX expression we estimate MaxEnd considering only
13098 // the case End = RHS of the loop termination condition. This is safe because
13099 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13100 // taken count.
13101 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13102 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13103
13104 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
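// For instance, MinStart = 0, MaxEnd = 10 and StrideForMaxBECount = 3 yield
// ceil(10 / 3) = 4, a conservative upper bound on the backedge-taken count.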
13105 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13106 : APIntOps::umax(MaxEnd, MinStart);
13107
13108 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13109 getConstant(StrideForMaxBECount) /* Step */);
13110}
13111
13112 ScalarEvolution::ExitLimit
13113ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13114 const Loop *L, bool IsSigned,
13115 bool ControlsOnlyExit, bool AllowPredicates) {
13116 SmallVector<const SCEVPredicate *, 4> Predicates;
13117
13118 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13119 bool PredicatedIV = false;
13120 if (!IV) {
13121 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13122 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13123 if (AR && AR->getLoop() == L && AR->isAffine()) {
13124 auto canProveNUW = [&]() {
13125 // We can use the comparison to infer no-wrap flags only if it fully
13126 // controls the loop exit.
13127 if (!ControlsOnlyExit)
13128 return false;
13129
13130 if (!isLoopInvariant(RHS, L))
13131 return false;
13132
13133 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13134 // We need the sequence defined by AR to strictly increase in the
13135 // unsigned integer domain for the logic below to hold.
13136 return false;
13137
13138 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13139 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13140 // If RHS <=u Limit, then there must exist a value V in the sequence
13141 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13142 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13143 // overflow occurs. This limit also implies that a signed comparison
13144 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13145 // the high bits on both sides must be zero.
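// For instance, for an i8 AR with StrideMax = 4, Limit = 255 - 3 = 252: if
// RHS <=u 252, the last in-loop value is at most 251 and its successor at most
// 255, so the AR must exit before it can wrap.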
13146 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13147 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13148 Limit = Limit.zext(OuterBitWidth);
13149 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13150 };
13151 auto Flags = AR->getNoWrapFlags();
13152 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13153 Flags = setFlags(Flags, SCEV::FlagNUW);
13154
13155 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13156 if (AR->hasNoUnsignedWrap()) {
13157 // Emulate what getZeroExtendExpr would have done during construction
13158 // if we'd been able to infer the fact just above at that time.
13159 const SCEV *Step = AR->getStepRecurrence(*this);
13160 Type *Ty = ZExt->getType();
13161 auto *S = getAddRecExpr(
13162 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
13163 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13164 IV = dyn_cast<SCEVAddRecExpr>(S);
13165 }
13166 }
13167 }
13168 }
13169
13170
13171 if (!IV && AllowPredicates) {
13172 // Try to make this an AddRec using runtime tests, in the first X
13173 // iterations of this loop, where X is the SCEV expression found by the
13174 // algorithm below.
13175 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13176 PredicatedIV = true;
13177 }
13178
13179 // Avoid weird loops
13180 if (!IV || IV->getLoop() != L || !IV->isAffine())
13181 return getCouldNotCompute();
13182
13183 // A precondition of this method is that the condition being analyzed
13184 // reaches an exiting branch which dominates the latch. Given that, we can
13185 // assume that an increment which violates the nowrap specification and
13186 // produces poison must cause undefined behavior when the resulting poison
13187 // value is branched upon and thus we can conclude that the backedge is
13188 // taken no more often than would be required to produce that poison value.
13189 // Note that a well defined loop can exit on the iteration which violates
13190 // the nowrap specification if there is another exit (either explicit or
13191 // implicit/exceptional) which causes the loop to execute before the
13192 // exiting instruction we're analyzing would trigger UB.
13193 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13194 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13195 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13196
13197 const SCEV *Stride = IV->getStepRecurrence(*this);
13198
13199 bool PositiveStride = isKnownPositive(Stride);
13200
13201 // Avoid negative or zero stride values.
13202 if (!PositiveStride) {
13203 // We can compute the correct backedge taken count for loops with unknown
13204 // strides if we can prove that the loop is not an infinite loop with side
13205 // effects. Here's the loop structure we are trying to handle -
13206 //
13207 // i = start
13208 // do {
13209 // A[i] = i;
13210 // i += s;
13211 // } while (i < end);
13212 //
13213 // The backedge taken count for such loops is evaluated as -
13214 // (max(end, start + stride) - start - 1) /u stride
13215 //
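// For instance, start = 0, end = 10, stride = 3 gives
// (max(10, 3) - 0 - 1) /u 3 = 9 /u 3 = 3: the backedge is taken for the
// post-increment values 3, 6 and 9.
//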
13216 // The additional preconditions that we need to check to prove correctness
13217 // of the above formula are as follows -
13218 //
13219 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13220 // NoWrap flag).
13221 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13222 // no side effects within the loop)
13223 // c) loop has a single static exit (with no abnormal exits)
13224 //
13225 // Precondition a) implies that if the stride is negative, this is a single
13226 // trip loop. The backedge taken count formula reduces to zero in this case.
13227 //
13228 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13229 // then a zero stride means the backedge can't be taken without executing
13230 // undefined behavior.
13231 //
13232 // The positive stride case is the same as isKnownPositive(Stride) returning
13233 // true (original behavior of the function).
13234 //
13235 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13236 !loopHasNoAbnormalExits(L))
13237 return getCouldNotCompute();
13238
13239 if (!isKnownNonZero(Stride)) {
13240 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13241 // if it might eventually be greater than start and if so, on which
13242 // iteration. We can't even produce a useful upper bound.
13243 if (!isLoopInvariant(RHS, L))
13244 return getCouldNotCompute();
13245
13246 // We allow a potentially zero stride, but we need to divide by stride
13247 // below. Since the loop can't be infinite and this check must control
13248 // the sole exit, we can infer the exit must be taken on the first
13249 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13250 // we know the numerator in the divides below must be zero, so we can
13251 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13252 // and produce the right result.
13253 // FIXME: Handle the case where Stride is poison?
13254 auto wouldZeroStrideBeUB = [&]() {
13255 // Proof by contradiction. Suppose the stride were zero. If we can
13256 // prove that the backedge *is* taken on the first iteration, then since
13257 // we know this condition controls the sole exit, we must have an
13258 // infinite loop. We can't have a (well defined) infinite loop per
13259 // check just above.
13260 // Note: The (Start - Stride) term is used to get the start' term from
13261 // (start' + stride,+,stride). Remember that we only care about the
13262 // result of this expression when stride == 0 at runtime.
13263 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13264 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13265 };
13266 if (!wouldZeroStrideBeUB()) {
13267 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13268 }
13269 }
13270 } else if (!NoWrap) {
13271 // Avoid proven overflow cases: this will ensure that the backedge taken
13272 // count will not generate any unsigned overflow.
13273 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13274 return getCouldNotCompute();
13275 }
13276
13277 // On all paths just preceding, we established the following invariant:
13278 // IV can be assumed not to overflow up to and including the exiting
13279 // iteration. We proved this in one of two ways:
13280 // 1) We can show overflow doesn't occur before the exiting iteration
13281 // 1a) canIVOverflowOnLT, and b) step of one
13282 // 2) We can show that if overflow occurs, the loop must execute UB
13283 // before any possible exit.
13284 // Note that we have not yet proved RHS invariant (in general).
13285
13286 const SCEV *Start = IV->getStart();
13287
13288 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13289 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13290 // Use integer-typed versions for actual computation; we can't subtract
13291 // pointers in general.
13292 const SCEV *OrigStart = Start;
13293 const SCEV *OrigRHS = RHS;
13294 if (Start->getType()->isPointerTy()) {
13295 Start = getLosslessPtrToIntExpr(Start);
13296 if (isa<SCEVCouldNotCompute>(Start))
13297 return Start;
13298 }
13299 if (RHS->getType()->isPointerTy()) {
13300 RHS = getLosslessPtrToIntExpr(RHS);
13301 if (isa<SCEVCouldNotCompute>(RHS))
13302 return RHS;
13303 }
13304
13305 const SCEV *End = nullptr, *BECount = nullptr,
13306 *BECountIfBackedgeTaken = nullptr;
13307 if (!isLoopInvariant(RHS, L)) {
13308 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13309 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13310 RHSAddRec->getNoWrapFlags()) {
13311 // The structure of loop we are trying to calculate backedge count of:
13312 //
13313 // left = left_start
13314 // right = right_start
13315 //
13316 // while(left < right){
13317 // ... do something here ...
13318 // left += s1; // stride of left is s1 (s1 > 0)
13319 // right += s2; // stride of right is s2 (s2 < 0)
13320 // }
13321 //
13322
13323 const SCEV *RHSStart = RHSAddRec->getStart();
13324 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13325
13326 // If Stride - RHSStride is positive and does not overflow, we can write
13327 // backedge count as ->
13328 // ceil((End - Start) /u (Stride - RHSStride))
13329 // Where, End = max(RHSStart, Start)
13330
13331 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13332 if (isKnownNegative(RHSStride) &&
13333 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13334 RHSStride)) {
13335
13336 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13337 if (isKnownPositive(Denominator)) {
13338 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13339 : getUMaxExpr(RHSStart, Start);
13340
13341 // We can do this because End >= Start, as End = max(RHSStart, Start)
13342 const SCEV *Delta = getMinusSCEV(End, Start);
13343
13344 BECount = getUDivCeilSCEV(Delta, Denominator);
13345 BECountIfBackedgeTaken =
13346 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13347 }
13348 }
13349 }
13350 if (BECount == nullptr) {
13351 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13352 // given the start, stride and max value for the end bound of the
13353 // loop (RHS), and the fact that IV does not overflow (which is
13354 // checked above).
13355 const SCEV *MaxBECount = computeMaxBECountForLT(
13356 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13357 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13358 MaxBECount, false /*MaxOrZero*/, Predicates);
13359 }
13360 } else {
13361 // We use the expression (max(End,Start)-Start)/Stride to describe the
13362 // backedge count, as if the backedge is taken at least once
13363 // max(End,Start) is End and so the result is as above, and if not
13364 // max(End,Start) is Start so we get a backedge count of zero.
13365 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13366 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13367 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13368 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13369 // Can we prove max(RHS, Start) > Start - Stride?
13370 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13371 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13372 // In this case, we can use a refined formula for computing backedge
13373 // taken count. The general formula remains:
13374 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13375 // We want to use the alternate formula:
13376 // "((End - 1) - (Start - Stride)) /u Stride"
13377 // Let's do a quick case analysis to show these are equivalent under
13378 // our precondition that max(RHS,Start) > Start - Stride.
13379 // * For RHS <= Start, the backedge-taken count must be zero.
13380 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13381 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13382 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13383 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13384 // reducing this to the stride of 1 case.
13385 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13386 // Stride".
13387 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13388 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13389 // "((RHS - (Start - Stride) - 1) /u Stride".
13390 // Our preconditions trivially imply no overflow in that form.
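// For instance, Start = 3, Stride = 3, RHS = 10 (unsigned): the guard
// Start - Stride = 0 <u Start and 0 <u RHS holds, and
// ((10 - 1) - (3 - 3)) /u 3 = 9 /u 3 = 3 = ceil((10 - 3) / 3).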
13391 const SCEV *MinusOne = getMinusOne(Stride->getType());
13392 const SCEV *Numerator =
13393 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13394 BECount = getUDivExpr(Numerator, Stride);
13395 }
13396
13397 if (!BECount) {
13398 auto canProveRHSGreaterThanEqualStart = [&]() {
13399 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13400 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13401 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13402
13403 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13404 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13405 return true;
13406
13407 // (RHS > Start - 1) implies RHS >= Start.
13408 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13409 // "Start - 1" doesn't overflow.
13410 // * For signed comparison, if Start - 1 does overflow, it's equal
13411 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13412 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13413 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13414 //
13415 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13416 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13417 auto *StartMinusOne =
13418 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13419 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13420 };
13421
13422 // If we know that RHS >= Start in the context of loop, then we know
13423 // that max(RHS, Start) = RHS at this point.
13424 if (canProveRHSGreaterThanEqualStart()) {
13425 End = RHS;
13426 } else {
13427 // If RHS < Start, the backedge will be taken zero times. So in
13428 // general, we can write the backedge-taken count as:
13429 //
13430 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13431 //
13432 // We convert it to the following to make it more convenient for SCEV:
13433 //
13434 // ceil(max(RHS, Start) - Start) / Stride
13435 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13436
13437 // See what would happen if we assume the backedge is taken. This is
13438 // used to compute MaxBECount.
13439 BECountIfBackedgeTaken =
13440 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13441 }
13442
13443 // At this point, we know:
13444 //
13445 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13446 // 2. The index variable doesn't overflow.
13447 //
13448 // Therefore, we know N exists such that
13449 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13450 // doesn't overflow.
13451 //
13452 // Using this information, try to prove whether the addition in
13453 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13454 const SCEV *One = getOne(Stride->getType());
13455 bool MayAddOverflow = [&] {
13456 if (isKnownToBeAPowerOfTwo(Stride)) {
13457 // Suppose Stride is a power of two, and Start/End are unsigned
13458 // integers. Let UMAX be the largest representable unsigned
13459 // integer.
13460 //
13461 // By the preconditions of this function, we know
13462 // "(Start + Stride * N) >= End", and this doesn't overflow.
13463 // As a formula:
13464 //
13465 // End <= (Start + Stride * N) <= UMAX
13466 //
13467 // Subtracting Start from all the terms:
13468 //
13469 // End - Start <= Stride * N <= UMAX - Start
13470 //
13471 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13472 //
13473 // End - Start <= Stride * N <= UMAX
13474 //
13475 // Stride * N is a multiple of Stride. Therefore,
13476 //
13477 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13478 //
13479 // Since Stride is a power of two, UMAX + 1 is divisible by
13480 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13481 // write:
13482 //
13483 // End - Start <= Stride * N <= UMAX - Stride - 1
13484 //
13485 // Dropping the middle term:
13486 //
13487 // End - Start <= UMAX - Stride - 1
13488 //
13489 // Adding Stride - 1 to both sides:
13490 //
13491 // (End - Start) + (Stride - 1) <= UMAX
13492 //
13493 // In other words, the addition doesn't have unsigned overflow.
13494 //
13495 // A similar proof works if we treat Start/End as signed values.
13496 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13497 // to use signed max instead of unsigned max. Note that we're
13498 // trying to prove a lack of unsigned overflow in either case.
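// As a concrete 8-bit example with Stride = 4: UMAX mod Stride = 255 mod 4 =
// 3 = Stride - 1, so End - Start <= 252, and adding Stride - 1 = 3 still
// stays within UMAX = 255.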
13499 return false;
13500 }
13501 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13502 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13503 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13504 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13505 // 1 <s End.
13506 //
13507 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13508 // End.
13509 return false;
13510 }
13511 return true;
13512 }();
13513
13514 const SCEV *Delta = getMinusSCEV(End, Start);
13515 if (!MayAddOverflow) {
13516 // floor((D + (S - 1)) / S)
13517 // We prefer this formulation if it's legal because it's fewer
13518 // operations.
13519 BECount =
13520 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13521 } else {
13522 BECount = getUDivCeilSCEV(Delta, Stride);
13523 }
13524 }
13525 }
13526
13527 const SCEV *ConstantMaxBECount;
13528 bool MaxOrZero = false;
13529 if (isa<SCEVConstant>(BECount)) {
13530 ConstantMaxBECount = BECount;
13531 } else if (BECountIfBackedgeTaken &&
13532 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13533 // If we know exactly how many times the backedge will be taken if it's
13534 // taken at least once, then the backedge count will either be that or
13535 // zero.
13536 ConstantMaxBECount = BECountIfBackedgeTaken;
13537 MaxOrZero = true;
13538 } else {
13539 ConstantMaxBECount = computeMaxBECountForLT(
13540 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13541 }
13542
13543 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13544 !isa<SCEVCouldNotCompute>(BECount))
13545 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13546
13547 const SCEV *SymbolicMaxBECount =
13548 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13549 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13550 Predicates);
13551}
13552
13553ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13554 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13555 bool ControlsOnlyExit, bool AllowPredicates) {
13556 SmallVector<const SCEVPredicate *, 4> Predicates;
13557 // We handle only IV > Invariant
13558 if (!isLoopInvariant(RHS, L))
13559 return getCouldNotCompute();
13560
13561 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13562 if (!IV && AllowPredicates)
13563 // Try to make this an AddRec using runtime tests, in the first X
13564 // iterations of this loop, where X is the SCEV expression found by the
13565 // algorithm below.
13566 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13567
13568 // Avoid weird loops
13569 if (!IV || IV->getLoop() != L || !IV->isAffine())
13570 return getCouldNotCompute();
13571
13572 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13573 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13574 ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13575
13576 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13577
13578 // Avoid negative or zero stride values
13579 if (!isKnownPositive(Stride))
13580 return getCouldNotCompute();
13581
13582 // Avoid proven overflow cases: this will ensure that the backedge taken count
13583 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13584 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13585 // behaviors like the case of C language.
13586 if (!Stride->isOne() && !NoWrap)
13587 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13588 return getCouldNotCompute();
13589
13590 const SCEV *Start = IV->getStart();
13591 const SCEV *End = RHS;
13592 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13593 // If we know that Start >= RHS in the context of loop, then we know that
13594 // min(RHS, Start) = RHS at this point.
13595 if (isLoopEntryGuardedByCond(
13596 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13597 End = RHS;
13598 else
13599 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13600 }
13601
13602 if (Start->getType()->isPointerTy()) {
13603 Start = getLosslessPtrToIntExpr(Start);
13604 if (isa<SCEVCouldNotCompute>(Start))
13605 return Start;
13606 }
13607 if (End->getType()->isPointerTy()) {
13608 End = getLosslessPtrToIntExpr(End);
13609 if (isa<SCEVCouldNotCompute>(End))
13610 return End;
13611 }
13612
13613 // Compute ((Start - End) + (Stride - 1)) / Stride.
13614 // FIXME: This can overflow. Holding off on fixing this for now;
13615 // howManyGreaterThans will hopefully be gone soon.
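// For example, Start = 10, End = 1, Stride = 3 (an IV stepping by -3):
// (10 - 1 + 2) /u 3 = 3 backedges, taken while the IV is 10, 7 and 4.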
13616 const SCEV *One = getOne(Stride->getType());
13617 const SCEV *BECount = getUDivExpr(
13618 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13619
13620 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13621 : getUnsignedRangeMax(Start);
13622
13623 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13624 : getUnsignedRangeMin(Stride);
13625
13626 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13627 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13628 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13629
13630 // Although End can be a MIN expression we estimate MinEnd considering only
13631 // the case End = RHS. This is safe because in the other case (Start - End)
13632 // is zero, leading to a zero maximum backedge taken count.
13633 APInt MinEnd =
13634 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13635 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13636
13637 const SCEV *ConstantMaxBECount =
13638 isa<SCEVConstant>(BECount)
13639 ? BECount
13640 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13641 getConstant(MinStride));
13642
13643 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13644 ConstantMaxBECount = BECount;
13645 const SCEV *SymbolicMaxBECount =
13646 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13647
13648 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13649 Predicates);
13650}
13651
13652 const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13653 ScalarEvolution &SE) const {
13654 if (Range.isFullSet()) // Infinite loop.
13655 return SE.getCouldNotCompute();
13656
13657 // If the start is a non-zero constant, shift the range to simplify things.
13658 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13659 if (!SC->getValue()->isZero()) {
13660 SmallVector<const SCEV *, 4> Operands(operands());
13661 Operands[0] = SE.getZero(SC->getType());
13662 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13663 FlagAnyWrap);
13664 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13665 return ShiftedAddRec->getNumIterationsInRange(
13666 Range.subtract(SC->getAPInt()), SE);
13667 // This is strange and shouldn't happen.
13668 return SE.getCouldNotCompute();
13669 }
13670
13671 // The only time we can solve this is when we have all constant indices.
13672 // Otherwise, we cannot determine the overflow conditions.
13673 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13674 return SE.getCouldNotCompute();
13675
13676 // Okay at this point we know that all elements of the chrec are constants and
13677 // that the start element is zero.
13678
13679 // First check to see if the range contains zero. If not, the first
13680 // iteration exits.
13681 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13682 if (!Range.contains(APInt(BitWidth, 0)))
13683 return SE.getZero(getType());
13684
13685 if (isAffine()) {
13686 // If this is an affine expression then we have this situation:
13687 // Solve {0,+,A} in Range === Ax in Range
13688
13689 // We know that zero is in the range. If A is positive then we know that
13690 // the upper value of the range must be the first possible exit value.
13691 // If A is negative then the lower of the range is the last possible loop
13692 // value. Also note that we already checked for a full range.
13693 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13694 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13695
13696 // The exit value should be (End+A)/A.
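// For example, Range = [0, 10) and A = 3: End = 9 and ExitVal = (9 + 3) / 3 = 4;
// iteration 4 evaluates to 12 (outside the range) while iteration 3 yields 9
// (still inside).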
13697 APInt ExitVal = (End + A).udiv(A);
13698 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13699
13700 // Evaluate at the exit value. If we really did fall out of the valid
13701 // range, then we computed our trip count, otherwise wrap around or other
13702 // things must have happened.
13703 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13704 if (Range.contains(Val->getValue()))
13705 return SE.getCouldNotCompute(); // Something strange happened
13706
13707 // Ensure that the previous value is in the range.
13708 assert(Range.contains(
13709 EvaluateConstantChrecAtConstant(this,
13710 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13711 "Linear scev computation is off in a bad way!");
13712 return SE.getConstant(ExitValue);
13713 }
13714
13715 if (isQuadratic()) {
13716 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13717 return SE.getConstant(*S);
13718 }
13719
13720 return SE.getCouldNotCompute();
13721}
13722
13723 const SCEVAddRecExpr *
13724 SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13725 assert(getNumOperands() > 1 && "AddRec with zero step?");
13726 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13727 // but in this case we cannot guarantee that the value returned will be an
13728 // AddRec because SCEV does not have a fixed point where it stops
13729 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13730 // may happen if we reach arithmetic depth limit while simplifying. So we
13731 // construct the returned value explicitly.
13732 SmallVector<const SCEV *, 3> Ops;
13733 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13734 // (this + Step) is {A+B,+,B+C,+...,+,N}.
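// For example, the post-increment form of {0,+,1,+,1} is {1,+,2,+,1}: the
// former evaluates to k(k+1)/2 at iteration k and the latter to (k+1)(k+2)/2,
// i.e. the former one iteration later.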
13735 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13736 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13737 // We know that the last operand is not a constant zero (otherwise it would
13738 // have been popped out earlier). This guarantees us that if the result has
13739 // the same last operand, then it will also not be popped out, meaning that
13740 // the returned value will be an AddRec.
13741 const SCEV *Last = getOperand(getNumOperands() - 1);
13742 assert(!Last->isZero() && "Recurrency with zero step?");
13743 Ops.push_back(Last);
13744 return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13745 SCEV::FlagAnyWrap));
13746}
13747
13748 // Return true when S contains at least an undef value.
13749 static bool containsUndefs(const SCEV *S) {
13750 return SCEVExprContains(
13751 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13752}
13753
13754 // Return true when S contains a value that is a nullptr.
13755 static bool containsErasedValue(const SCEV *S) {
13756 return SCEVExprContains(S, [](const SCEV *S) {
13757 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13758 return SU->getValue() == nullptr;
13759 return false;
13760 });
13761}
13762
13763 /// Return the size of an element read or written by Inst.
13764 const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13765 Type *Ty;
13766 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13767 Ty = Store->getValueOperand()->getType();
13768 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13769 Ty = Load->getType();
13770 else
13771 return nullptr;
13772
13773 Type *ETy = getEffectiveSCEVType(PointerType::getUnqual(Ty));
13774 return getSizeOfExpr(ETy, Ty);
13775}
13776
13777//===----------------------------------------------------------------------===//
13778// SCEVCallbackVH Class Implementation
13779//===----------------------------------------------------------------------===//
13780
13781 void ScalarEvolution::SCEVCallbackVH::deleted() {
13782 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13783 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13784 SE->ConstantEvolutionLoopExitValue.erase(PN);
13785 SE->eraseValueFromMap(getValPtr());
13786 // this now dangles!
13787}
13788
13789void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13790 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13791
13792 // Forget all the expressions associated with users of the old value,
13793 // so that future queries will recompute the expressions using the new
13794 // value.
13795 SE->forgetValue(getValPtr());
13796 // this now dangles!
13797}
13798
13799ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13800 : CallbackVH(V), SE(se) {}
13801
13802//===----------------------------------------------------------------------===//
13803// ScalarEvolution Class Implementation
13804//===----------------------------------------------------------------------===//
13805
13806 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13807 AssumptionCache &AC, DominatorTree &DT,
13808 LoopInfo &LI)
13809 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13810 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13811 LoopDispositions(64), BlockDispositions(64) {
13812 // To use guards for proving predicates, we need to scan every instruction in
13813 // relevant basic blocks, and not just terminators. Doing this is a waste of
13814 // time if the IR does not actually contain any calls to
13815 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13816 //
13817 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13818 // to _add_ guards to the module when there weren't any before, and wants
13819 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13820 // efficient in lieu of being smart in that rather obscure case.
13821
13822 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13823 F.getParent(), Intrinsic::experimental_guard);
13824 HasGuards = GuardDecl && !GuardDecl->use_empty();
13825}
13826
13827 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13828 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13829 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13830 ValueExprMap(std::move(Arg.ValueExprMap)),
13831 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13832 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13833 PendingMerges(std::move(Arg.PendingMerges)),
13834 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13835 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13836 PredicatedBackedgeTakenCounts(
13837 std::move(Arg.PredicatedBackedgeTakenCounts)),
13838 BECountUsers(std::move(Arg.BECountUsers)),
13839 ConstantEvolutionLoopExitValue(
13840 std::move(Arg.ConstantEvolutionLoopExitValue)),
13841 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13842 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13843 LoopDispositions(std::move(Arg.LoopDispositions)),
13844 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13845 BlockDispositions(std::move(Arg.BlockDispositions)),
13846 SCEVUsers(std::move(Arg.SCEVUsers)),
13847 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13848 SignedRanges(std::move(Arg.SignedRanges)),
13849 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13850 UniquePreds(std::move(Arg.UniquePreds)),
13851 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13852 LoopUsers(std::move(Arg.LoopUsers)),
13853 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13854 FirstUnknown(Arg.FirstUnknown) {
13855 Arg.FirstUnknown = nullptr;
13856}
13857
13858 ScalarEvolution::~ScalarEvolution() {
13859 // Iterate through all the SCEVUnknown instances and call their
13860 // destructors, so that they release their references to their values.
13861 for (SCEVUnknown *U = FirstUnknown; U;) {
13862 SCEVUnknown *Tmp = U;
13863 U = U->Next;
13864 Tmp->~SCEVUnknown();
13865 }
13866 FirstUnknown = nullptr;
13867
13868 ExprValueMap.clear();
13869 ValueExprMap.clear();
13870 HasRecMap.clear();
13871 BackedgeTakenCounts.clear();
13872 PredicatedBackedgeTakenCounts.clear();
13873
13874 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13875 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13876 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13877 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13878 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13879}
13880
13881 bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
13882 return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
13883 }
13884
13885/// When printing a top-level SCEV for trip counts, it's helpful to include
13886/// a type for constants which are otherwise hard to disambiguate.
13887static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13888 if (isa<SCEVConstant>(S))
13889 OS << *S->getType() << " ";
13890 OS << *S;
13891}
13892
13893 static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13894 const Loop *L) {
13895 // Print all inner loops first
13896 for (Loop *I : *L)
13897 PrintLoopInfo(OS, SE, I);
13898
13899 OS << "Loop ";
13900 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13901 OS << ": ";
13902
13903 SmallVector<BasicBlock *, 8> ExitingBlocks;
13904 L->getExitingBlocks(ExitingBlocks);
13905 if (ExitingBlocks.size() != 1)
13906 OS << "<multiple exits> ";
13907
13908 auto *BTC = SE->getBackedgeTakenCount(L);
13909 if (!isa<SCEVCouldNotCompute>(BTC)) {
13910 OS << "backedge-taken count is ";
13911 PrintSCEVWithTypeHint(OS, BTC);
13912 } else
13913 OS << "Unpredictable backedge-taken count.";
13914 OS << "\n";
13915
13916 if (ExitingBlocks.size() > 1)
13917 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13918 OS << " exit count for " << ExitingBlock->getName() << ": ";
13919 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13920 PrintSCEVWithTypeHint(OS, EC);
13921 if (isa<SCEVCouldNotCompute>(EC)) {
13922 // Retry with predicates.
13923 SmallVector<const SCEVPredicate *, 4> Predicates;
13924 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13925 if (!isa<SCEVCouldNotCompute>(EC)) {
13926 OS << "\n predicated exit count for " << ExitingBlock->getName()
13927 << ": ";
13928 PrintSCEVWithTypeHint(OS, EC);
13929 OS << "\n Predicates:\n";
13930 for (const auto *P : Predicates)
13931 P->print(OS, 4);
13932 }
13933 }
13934 OS << "\n";
13935 }
13936
13937 OS << "Loop ";
13938 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13939 OS << ": ";
13940
13941 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13942 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13943 OS << "constant max backedge-taken count is ";
13944 PrintSCEVWithTypeHint(OS, ConstantBTC);
13946 OS << ", actual taken count either this or zero.";
13947 } else {
13948 OS << "Unpredictable constant max backedge-taken count. ";
13949 }
13950
13951 OS << "\n"
13952 "Loop ";
13953 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13954 OS << ": ";
13955
13956 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13957 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13958 OS << "symbolic max backedge-taken count is ";
13959 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13961 OS << ", actual taken count either this or zero.";
13962 } else {
13963 OS << "Unpredictable symbolic max backedge-taken count. ";
13964 }
13965 OS << "\n";
13966
13967 if (ExitingBlocks.size() > 1)
13968 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13969 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13970 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13971 ScalarEvolution::SymbolicMaximum);
13972 PrintSCEVWithTypeHint(OS, ExitBTC);
13973 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13974 // Retry with predicates.
13975 SmallVector<const SCEVPredicate *, 4> Predicates;
13976 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13977 ScalarEvolution::SymbolicMaximum);
13978 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13979 OS << "\n predicated symbolic max exit count for "
13980 << ExitingBlock->getName() << ": ";
13981 PrintSCEVWithTypeHint(OS, ExitBTC);
13982 OS << "\n Predicates:\n";
13983 for (const auto *P : Predicates)
13984 P->print(OS, 4);
13985 }
13986 }
13987 OS << "\n";
13988 }
13989
13990 SmallVector<const SCEVPredicate *, 4> Preds;
13991 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13992 if (PBT != BTC) {
13993 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13994 OS << "Loop ";
13995 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13996 OS << ": ";
13997 if (!isa<SCEVCouldNotCompute>(PBT)) {
13998 OS << "Predicated backedge-taken count is ";
13999 PrintSCEVWithTypeHint(OS, PBT);
14000 } else
14001 OS << "Unpredictable predicated backedge-taken count.";
14002 OS << "\n";
14003 OS << " Predicates:\n";
14004 for (const auto *P : Preds)
14005 P->print(OS, 4);
14006 }
14007 Preds.clear();
14008
14009 auto *PredConstantMax =
14010 SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
14011 if (PredConstantMax != ConstantBTC) {
14012 assert(!Preds.empty() &&
14013 "different predicated constant max BTC but no predicates");
14014 OS << "Loop ";
14015 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14016 OS << ": ";
14017 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14018 OS << "Predicated constant max backedge-taken count is ";
14019 PrintSCEVWithTypeHint(OS, PredConstantMax);
14020 } else
14021 OS << "Unpredictable predicated constant max backedge-taken count.";
14022 OS << "\n";
14023 OS << " Predicates:\n";
14024 for (const auto *P : Preds)
14025 P->print(OS, 4);
14026 }
14027 Preds.clear();
14028
14029 auto *PredSymbolicMax =
14030 SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
14031 if (SymbolicBTC != PredSymbolicMax) {
14032 assert(!Preds.empty() &&
14033 "Different predicated symbolic max BTC, but no predicates");
14034 OS << "Loop ";
14035 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14036 OS << ": ";
14037 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14038 OS << "Predicated symbolic max backedge-taken count is ";
14039 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14040 } else
14041 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14042 OS << "\n";
14043 OS << " Predicates:\n";
14044 for (const auto *P : Preds)
14045 P->print(OS, 4);
14046 }
14047
14048 if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
14049 OS << "Loop ";
14050 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14051 OS << ": ";
14052 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14053 }
14054}
14055
14056namespace llvm {
14057// Note: these overloaded operators need to be in the llvm namespace for them
14058// to be resolved correctly. If we put them outside the llvm namespace, the
14059//
14060// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14061//
14062 // code below "breaks" and starts printing raw enum values as opposed to the
14063// string values.
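// For example, without these overloads the statement above would print the
// enumerator's integer value (LoopInvariant would come out as 1, for
// instance) instead of the string "Invariant".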
14064raw_ostream &operator<<(raw_ostream &OS,
14065 ScalarEvolution::LoopDisposition LD) {
14066 switch (LD) {
14067 case ScalarEvolution::LoopVariant:
14068 OS << "Variant";
14069 break;
14070 case ScalarEvolution::LoopInvariant:
14071 OS << "Invariant";
14072 break;
14073 case ScalarEvolution::LoopComputable:
14074 OS << "Computable";
14075 break;
14076 }
14077 return OS;
14078}
14079
14080raw_ostream &operator<<(raw_ostream &OS,
14081 ScalarEvolution::BlockDisposition BD) {
14082 switch (BD) {
14083 case ScalarEvolution::DoesNotDominateBlock:
14084 OS << "DoesNotDominate";
14085 break;
14086 case ScalarEvolution::DominatesBlock:
14087 OS << "Dominates";
14088 break;
14089 case ScalarEvolution::ProperlyDominatesBlock:
14090 OS << "ProperlyDominates";
14091 break;
14092 }
14093 return OS;
14094}
14095} // namespace llvm
14096
14097void ScalarEvolution::print(raw_ostream &OS) const {
14098 // ScalarEvolution's implementation of the print method is to print
14099 // out SCEV values of all instructions that are interesting. Doing
14100 // this potentially causes it to create new SCEV objects though,
14101 // which technically conflicts with the const qualifier. This isn't
14102 // observable from outside the class though, so casting away the
14103 // const isn't dangerous.
14104 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14105
14106 if (ClassifyExpressions) {
14107 OS << "Classifying expressions for: ";
14108 F.printAsOperand(OS, /*PrintType=*/false);
14109 OS << "\n";
14110 for (Instruction &I : instructions(F))
14111 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14112 OS << I << '\n';
14113 OS << " --> ";
14114 const SCEV *SV = SE.getSCEV(&I);
14115 SV->print(OS);
14116 if (!isa<SCEVCouldNotCompute>(SV)) {
14117 OS << " U: ";
14118 SE.getUnsignedRange(SV).print(OS);
14119 OS << " S: ";
14120 SE.getSignedRange(SV).print(OS);
14121 }
14122
14123 const Loop *L = LI.getLoopFor(I.getParent());
14124
14125 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14126 if (AtUse != SV) {
14127 OS << " --> ";
14128 AtUse->print(OS);
14129 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14130 OS << " U: ";
14131 SE.getUnsignedRange(AtUse).print(OS);
14132 OS << " S: ";
14133 SE.getSignedRange(AtUse).print(OS);
14134 }
14135 }
14136
14137 if (L) {
14138 OS << "\t\t" "Exits: ";
14139 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14140 if (!SE.isLoopInvariant(ExitValue, L)) {
14141 OS << "<<Unknown>>";
14142 } else {
14143 OS << *ExitValue;
14144 }
14145
14146 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14147 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14148 OS << LS;
14149 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14150 OS << ": " << SE.getLoopDisposition(SV, Iter);
14151 }
14152
14153 for (const auto *InnerL : depth_first(L)) {
14154 if (InnerL == L)
14155 continue;
14156 OS << LS;
14157 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14158 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14159 }
14160
14161 OS << " }";
14162 }
14163
14164 OS << "\n";
14165 }
14166 }
14167
14168 OS << "Determining loop execution counts for: ";
14169 F.printAsOperand(OS, /*PrintType=*/false);
14170 OS << "\n";
14171 for (Loop *I : LI)
14172 PrintLoopInfo(OS, &SE, I);
14173}
14174
14175ScalarEvolution::LoopDisposition
14176ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14177 auto &Values = LoopDispositions[S];
14178 for (auto &V : Values) {
14179 if (V.getPointer() == L)
14180 return V.getInt();
14181 }
14182 Values.emplace_back(L, LoopVariant);
14183 LoopDisposition D = computeLoopDisposition(S, L);
14184 auto &Values2 = LoopDispositions[S];
14185 for (auto &V : llvm::reverse(Values2)) {
14186 if (V.getPointer() == L) {
14187 V.setInt(D);
14188 break;
14189 }
14190 }
14191 return D;
14192}
14193
14194ScalarEvolution::LoopDisposition
14195ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14196 switch (S->getSCEVType()) {
14197 case scConstant:
14198 case scVScale:
14199 return LoopInvariant;
14200 case scAddRecExpr: {
14201 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14202
14203 // If L is the addrec's loop, it's computable.
14204 if (AR->getLoop() == L)
14205 return LoopComputable;
14206
14207 // Add recurrences are never invariant in the function-body (null loop).
14208 if (!L)
14209 return LoopVariant;
14210
14211 // Everything that is not defined at loop entry is variant.
14212 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14213 return LoopVariant;
14214 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14215 " dominate the contained loop's header?");
14216
14217 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14218 if (AR->getLoop()->contains(L))
14219 return LoopInvariant;
14220
14221 // This recurrence is variant w.r.t. L if any of its operands
14222 // are variant.
14223 for (const auto *Op : AR->operands())
14224 if (!isLoopInvariant(Op, L))
14225 return LoopVariant;
14226
14227 // Otherwise it's loop-invariant.
14228 return LoopInvariant;
14229 }
14230 case scTruncate:
14231 case scZeroExtend:
14232 case scSignExtend:
14233 case scPtrToAddr:
14234 case scPtrToInt:
14235 case scAddExpr:
14236 case scMulExpr:
14237 case scUDivExpr:
14238 case scUMaxExpr:
14239 case scSMaxExpr:
14240 case scUMinExpr:
14241 case scSMinExpr:
14242 case scSequentialUMinExpr: {
14243 bool HasVarying = false;
14244 for (const auto *Op : S->operands()) {
14245 LoopDisposition D = getLoopDisposition(Op, L);
14246 if (D == LoopVariant)
14247 return LoopVariant;
14248 if (D == LoopComputable)
14249 HasVarying = true;
14250 }
14251 return HasVarying ? LoopComputable : LoopInvariant;
14252 }
14253 case scUnknown:
14254 // All non-instruction values are loop invariant. All instructions are loop
14255 // invariant if they are not contained in the specified loop.
14256 // Instructions are never considered invariant in the function body
14257 // (null loop) because they are defined within the "loop".
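// For example, a load inside loop %L is LoopVariant with respect to %L, but
// LoopInvariant with respect to a disjoint loop that does not contain it.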
14258 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14259 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14260 return LoopInvariant;
14261 case scCouldNotCompute:
14262 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14263 }
14264 llvm_unreachable("Unknown SCEV kind!");
14265}
14266
14267bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14268 return getLoopDisposition(S, L) == LoopInvariant;
14269}
14270
14272 return getLoopDisposition(S, L) == LoopComputable;
14273}
14274
14275ScalarEvolution::BlockDisposition
14276ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14277 auto &Values = BlockDispositions[S];
14278 for (auto &V : Values) {
14279 if (V.getPointer() == BB)
14280 return V.getInt();
14281 }
14282 Values.emplace_back(BB, DoesNotDominateBlock);
14283 BlockDisposition D = computeBlockDisposition(S, BB);
14284 auto &Values2 = BlockDispositions[S];
14285 for (auto &V : llvm::reverse(Values2)) {
14286 if (V.getPointer() == BB) {
14287 V.setInt(D);
14288 break;
14289 }
14290 }
14291 return D;
14292}
14293
14294ScalarEvolution::BlockDisposition
14295ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14296 switch (S->getSCEVType()) {
14297 case scConstant:
14298 case scVScale:
14299 return ProperlyDominatesBlock;
14300 case scAddRecExpr: {
14301 // This uses a "dominates" query instead of "properly dominates" query
14302 // to test for proper dominance too, because the instruction which
14303 // produces the addrec's value is a PHI, and a PHI effectively properly
14304 // dominates its entire containing block.
14305 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14306 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14307 return DoesNotDominateBlock;
14308
14309 // Fall through into SCEVNAryExpr handling.
14310 [[fallthrough]];
14311 }
14312 case scTruncate:
14313 case scZeroExtend:
14314 case scSignExtend:
14315 case scPtrToAddr:
14316 case scPtrToInt:
14317 case scAddExpr:
14318 case scMulExpr:
14319 case scUDivExpr:
14320 case scUMaxExpr:
14321 case scSMaxExpr:
14322 case scUMinExpr:
14323 case scSMinExpr:
14324 case scSequentialUMinExpr: {
14325 bool Proper = true;
14326 for (const SCEV *NAryOp : S->operands()) {
14327 BlockDisposition D = getBlockDisposition(NAryOp, BB);
14328 if (D == DoesNotDominateBlock)
14329 return DoesNotDominateBlock;
14330 if (D == DominatesBlock)
14331 Proper = false;
14332 }
14333 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14334 }
14335 case scUnknown:
14336 if (Instruction *I =
14337 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14338 if (I->getParent() == BB)
14339 return DominatesBlock;
14340 if (DT.properlyDominates(I->getParent(), BB))
14341 return ProperlyDominatesBlock;
14342 return DoesNotDominateBlock;
14343 }
14344 return ProperlyDominatesBlock;
14345 case scCouldNotCompute:
14346 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14347 }
14348 llvm_unreachable("Unknown SCEV kind!");
14349}
14350
14351bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14352 return getBlockDisposition(S, BB) >= DominatesBlock;
14353}
14354
14355bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14356 return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14357}
14358
14359bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14360 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14361}
14362
14363void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14364 bool Predicated) {
14365 auto &BECounts =
14366 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14367 auto It = BECounts.find(L);
14368 if (It != BECounts.end()) {
14369 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14370 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14371 if (!isa<SCEVConstant>(S)) {
14372 auto UserIt = BECountUsers.find(S);
14373 assert(UserIt != BECountUsers.end());
14374 UserIt->second.erase({L, Predicated});
14375 }
14376 }
14377 }
14378 BECounts.erase(It);
14379 }
14380}
14381
14382void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14383 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14384 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14385
14386 while (!Worklist.empty()) {
14387 const SCEV *Curr = Worklist.pop_back_val();
14388 auto Users = SCEVUsers.find(Curr);
14389 if (Users != SCEVUsers.end())
14390 for (const auto *User : Users->second)
14391 if (ToForget.insert(User).second)
14392 Worklist.push_back(User);
14393 }
14394
14395 for (const auto *S : ToForget)
14396 forgetMemoizedResultsImpl(S);
14397
14398 for (auto I = PredicatedSCEVRewrites.begin();
14399 I != PredicatedSCEVRewrites.end();) {
14400 std::pair<const SCEV *, const Loop *> Entry = I->first;
14401 if (ToForget.count(Entry.first))
14402 PredicatedSCEVRewrites.erase(I++);
14403 else
14404 ++I;
14405 }
14406}
14407
14408void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14409 LoopDispositions.erase(S);
14410 BlockDispositions.erase(S);
14411 UnsignedRanges.erase(S);
14412 SignedRanges.erase(S);
14413 HasRecMap.erase(S);
14414 ConstantMultipleCache.erase(S);
14415
14416 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14417 UnsignedWrapViaInductionTried.erase(AR);
14418 SignedWrapViaInductionTried.erase(AR);
14419 }
14420
14421 auto ExprIt = ExprValueMap.find(S);
14422 if (ExprIt != ExprValueMap.end()) {
14423 for (Value *V : ExprIt->second) {
14424 auto ValueIt = ValueExprMap.find_as(V);
14425 if (ValueIt != ValueExprMap.end())
14426 ValueExprMap.erase(ValueIt);
14427 }
14428 ExprValueMap.erase(ExprIt);
14429 }
14430
14431 auto ScopeIt = ValuesAtScopes.find(S);
14432 if (ScopeIt != ValuesAtScopes.end()) {
14433 for (const auto &Pair : ScopeIt->second)
14434 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14435 llvm::erase(ValuesAtScopesUsers[Pair.second],
14436 std::make_pair(Pair.first, S));
14437 ValuesAtScopes.erase(ScopeIt);
14438 }
14439
14440 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14441 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14442 for (const auto &Pair : ScopeUserIt->second)
14443 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14444 ValuesAtScopesUsers.erase(ScopeUserIt);
14445 }
14446
14447 auto BEUsersIt = BECountUsers.find(S);
14448 if (BEUsersIt != BECountUsers.end()) {
14449 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14450 auto Copy = BEUsersIt->second;
14451 for (const auto &Pair : Copy)
14452 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14453 BECountUsers.erase(BEUsersIt);
14454 }
14455
14456 auto FoldUser = FoldCacheUser.find(S);
14457 if (FoldUser != FoldCacheUser.end())
14458 for (auto &KV : FoldUser->second)
14459 FoldCache.erase(KV);
14460 FoldCacheUser.erase(S);
14461}
14462
14463void
14464ScalarEvolution::getUsedLoops(const SCEV *S,
14465 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14466 struct FindUsedLoops {
14467 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14468 : LoopsUsed(LoopsUsed) {}
14469 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14470 bool follow(const SCEV *S) {
14471 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14472 LoopsUsed.insert(AR->getLoop());
14473 return true;
14474 }
14475
14476 bool isDone() const { return false; }
14477 };
14478
14479 FindUsedLoops F(LoopsUsed);
14480 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14481}
14482
14483void ScalarEvolution::getReachableBlocks(
14484 SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14485 SmallVector<BasicBlock *> Worklist;
14486 Worklist.push_back(&F.getEntryBlock());
14487 while (!Worklist.empty()) {
14488 BasicBlock *BB = Worklist.pop_back_val();
14489 if (!Reachable.insert(BB).second)
14490 continue;
14491
14492 Value *Cond;
14493 BasicBlock *TrueBB, *FalseBB;
14494 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14495 m_BasicBlock(FalseBB)))) {
14496 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14497 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14498 continue;
14499 }
14500
14501 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14502 const SCEV *L = getSCEV(Cmp->getOperand(0));
14503 const SCEV *R = getSCEV(Cmp->getOperand(1));
14504 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14505 Worklist.push_back(TrueBB);
14506 continue;
14507 }
14508 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14509 R)) {
14510 Worklist.push_back(FalseBB);
14511 continue;
14512 }
14513 }
14514 }
14515
14516 append_range(Worklist, successors(BB));
14517 }
14518}
14519
14520void ScalarEvolution::verify() const {
14521 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14522 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14523
14524 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14525
14526 // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14527 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14528 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14529
14530 const SCEV *visitConstant(const SCEVConstant *Constant) {
14531 return SE.getConstant(Constant->getAPInt());
14532 }
14533
14534 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14535 return SE.getUnknown(Expr->getValue());
14536 }
14537
14538 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14539 return SE.getCouldNotCompute();
14540 }
14541 };
14542
14543 SCEVMapper SCM(SE2);
14544 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14545 SE2.getReachableBlocks(ReachableBlocks, F);
14546
14547 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14548 if (containsUndefs(Old) || containsUndefs(New)) {
14549 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14550 // not propagate undef aggressively). This means we can (and do) fail
14551 // verification in cases where a transform makes a value go from "undef"
14552 // to "undef+1" (say). The transform is fine, since in both cases the
14553 // result is "undef", but SCEV thinks the value increased by 1.
14554 return nullptr;
14555 }
14556
14557 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14558 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14559 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14560 return nullptr;
14561
14562 return Delta;
14563 };
14564
14565 while (!LoopStack.empty()) {
14566 auto *L = LoopStack.pop_back_val();
14567 llvm::append_range(LoopStack, *L);
14568
14569 // Only verify BECounts in reachable loops. For an unreachable loop,
14570 // any BECount is legal.
14571 if (!ReachableBlocks.contains(L->getHeader()))
14572 continue;
14573
14574 // Only verify cached BECounts. Computing new BECounts may change the
14575 // results of subsequent SCEV uses.
14576 auto It = BackedgeTakenCounts.find(L);
14577 if (It == BackedgeTakenCounts.end())
14578 continue;
14579
14580 auto *CurBECount =
14581 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14582 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14583
14584 if (CurBECount == SE2.getCouldNotCompute() ||
14585 NewBECount == SE2.getCouldNotCompute()) {
14586 // NB! This situation is legal, but is very suspicious -- whatever pass
14587 // changed the loop to make a trip count go from could not compute to
14588 // computable or vice-versa *should have* invalidated SCEV. However, we
14589 // choose not to assert here (for now) since we don't want false
14590 // positives.
14591 continue;
14592 }
14593
14594 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14595 SE.getTypeSizeInBits(NewBECount->getType()))
14596 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14597 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14598 SE.getTypeSizeInBits(NewBECount->getType()))
14599 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14600
14601 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14602 if (Delta && !Delta->isZero()) {
14603 dbgs() << "Trip Count for " << *L << " Changed!\n";
14604 dbgs() << "Old: " << *CurBECount << "\n";
14605 dbgs() << "New: " << *NewBECount << "\n";
14606 dbgs() << "Delta: " << *Delta << "\n";
14607 std::abort();
14608 }
14609 }
14610
14611 // Collect all valid loops currently in LoopInfo.
14612 SmallPtrSet<Loop *, 32> ValidLoops;
14613 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14614 while (!Worklist.empty()) {
14615 Loop *L = Worklist.pop_back_val();
14616 if (ValidLoops.insert(L).second)
14617 Worklist.append(L->begin(), L->end());
14618 }
14619 for (const auto &KV : ValueExprMap) {
14620#ifndef NDEBUG
14621 // Check for SCEV expressions referencing invalid/deleted loops.
14622 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14623 assert(ValidLoops.contains(AR->getLoop()) &&
14624 "AddRec references invalid loop");
14625 }
14626#endif
14627
14628 // Check that the value is also part of the reverse map.
14629 auto It = ExprValueMap.find(KV.second);
14630 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14631 dbgs() << "Value " << *KV.first
14632 << " is in ValueExprMap but not in ExprValueMap\n";
14633 std::abort();
14634 }
14635
14636 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14637 if (!ReachableBlocks.contains(I->getParent()))
14638 continue;
14639 const SCEV *OldSCEV = SCM.visit(KV.second);
14640 const SCEV *NewSCEV = SE2.getSCEV(I);
14641 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14642 if (Delta && !Delta->isZero()) {
14643 dbgs() << "SCEV for value " << *I << " changed!\n"
14644 << "Old: " << *OldSCEV << "\n"
14645 << "New: " << *NewSCEV << "\n"
14646 << "Delta: " << *Delta << "\n";
14647 std::abort();
14648 }
14649 }
14650 }
14651
14652 for (const auto &KV : ExprValueMap) {
14653 for (Value *V : KV.second) {
14654 const SCEV *S = ValueExprMap.lookup(V);
14655 if (!S) {
14656 dbgs() << "Value " << *V
14657 << " is in ExprValueMap but not in ValueExprMap\n";
14658 std::abort();
14659 }
14660 if (S != KV.first) {
14661 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14662 << *KV.first << "\n";
14663 std::abort();
14664 }
14665 }
14666 }
14667
14668 // Verify integrity of SCEV users.
14669 for (const auto &S : UniqueSCEVs) {
14670 for (const auto *Op : S.operands()) {
14671 // We do not store dependencies of constants.
14672 if (isa<SCEVConstant>(Op))
14673 continue;
14674 auto It = SCEVUsers.find(Op);
14675 if (It != SCEVUsers.end() && It->second.count(&S))
14676 continue;
14677 dbgs() << "Use of operand " << *Op << " by user " << S
14678 << " is not being tracked!\n";
14679 std::abort();
14680 }
14681 }
14682
14683 // Verify integrity of ValuesAtScopes users.
14684 for (const auto &ValueAndVec : ValuesAtScopes) {
14685 const SCEV *Value = ValueAndVec.first;
14686 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14687 const Loop *L = LoopAndValueAtScope.first;
14688 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14689 if (!isa<SCEVConstant>(ValueAtScope)) {
14690 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14691 if (It != ValuesAtScopesUsers.end() &&
14692 is_contained(It->second, std::make_pair(L, Value)))
14693 continue;
14694 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14695 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14696 std::abort();
14697 }
14698 }
14699 }
14700
14701 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14702 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14703 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14704 const Loop *L = LoopAndValue.first;
14705 const SCEV *Value = LoopAndValue.second;
14707 auto It = ValuesAtScopes.find(Value);
14708 if (It != ValuesAtScopes.end() &&
14709 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14710 continue;
14711 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14712 << *ValueAtScope << " missing in ValuesAtScopes\n";
14713 std::abort();
14714 }
14715 }
14716
14717 // Verify integrity of BECountUsers.
14718 auto VerifyBECountUsers = [&](bool Predicated) {
14719 auto &BECounts =
14720 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14721 for (const auto &LoopAndBEInfo : BECounts) {
14722 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14723 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14724 if (!isa<SCEVConstant>(S)) {
14725 auto UserIt = BECountUsers.find(S);
14726 if (UserIt != BECountUsers.end() &&
14727 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14728 continue;
14729 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14730 << " missing from BECountUsers\n";
14731 std::abort();
14732 }
14733 }
14734 }
14735 }
14736 };
14737 VerifyBECountUsers(/* Predicated */ false);
14738 VerifyBECountUsers(/* Predicated */ true);
14739
14740 // Verify integrity of the loop disposition cache.
14741 for (auto &[S, Values] : LoopDispositions) {
14742 for (auto [Loop, CachedDisposition] : Values) {
14743 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14744 if (CachedDisposition != RecomputedDisposition) {
14745 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14746 << " is incorrect: cached " << CachedDisposition << ", actual "
14747 << RecomputedDisposition << "\n";
14748 std::abort();
14749 }
14750 }
14751 }
14752
14753 // Verify integrity of the block disposition cache.
14754 for (auto &[S, Values] : BlockDispositions) {
14755 for (auto [BB, CachedDisposition] : Values) {
14756 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14757 if (CachedDisposition != RecomputedDisposition) {
14758 dbgs() << "Cached disposition of " << *S << " for block %"
14759 << BB->getName() << " is incorrect: cached " << CachedDisposition
14760 << ", actual " << RecomputedDisposition << "\n";
14761 std::abort();
14762 }
14763 }
14764 }
14765
14766 // Verify FoldCache/FoldCacheUser caches.
14767 for (auto [FoldID, Expr] : FoldCache) {
14768 auto I = FoldCacheUser.find(Expr);
14769 if (I == FoldCacheUser.end()) {
14770 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14771 << "!\n";
14772 std::abort();
14773 }
14774 if (!is_contained(I->second, FoldID)) {
14775 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14776 std::abort();
14777 }
14778 }
14779 for (auto [Expr, IDs] : FoldCacheUser) {
14780 for (auto &FoldID : IDs) {
14781 const SCEV *S = FoldCache.lookup(FoldID);
14782 if (!S) {
14783 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14784 << "!\n";
14785 std::abort();
14786 }
14787 if (S != Expr) {
14788 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14789 << " != " << *Expr << "!\n";
14790 std::abort();
14791 }
14792 }
14793 }
14794
14795 // Verify that ConstantMultipleCache computations are correct. We check that
14796 // cached multiples and recomputed multiples are multiples of each other to
14797 // verify correctness. It is possible that a recomputed multiple is different
14798 // from the cached multiple due to strengthened no wrap flags or changes in
14799 // KnownBits computations.
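// For example, a cached multiple of 4 and a recomputed multiple of 8 are
// accepted because one divides the other, whereas 4 and 6 would trigger the
// abort below.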
14800 for (auto [S, Multiple] : ConstantMultipleCache) {
14801 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14802 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14803 Multiple.urem(RecomputedMultiple) != 0 &&
14804 RecomputedMultiple.urem(Multiple) != 0)) {
14805 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14806 << *S << " : Computed " << RecomputedMultiple
14807 << " but cache contains " << Multiple << "!\n";
14808 std::abort();
14809 }
14810 }
14811}
14812
14813bool ScalarEvolution::invalidate(
14814 Function &F, const PreservedAnalyses &PA,
14815 FunctionAnalysisManager::Invalidator &Inv) {
14816 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14817 // of its dependencies is invalidated.
14818 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14819 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14820 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14821 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14822 Inv.invalidate<LoopAnalysis>(F, PA);
14823}
14824
14825AnalysisKey ScalarEvolutionAnalysis::Key;
14826
14827ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14828 FunctionAnalysisManager &AM) {
14829 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14830 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14831 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14832 auto &LI = AM.getResult<LoopAnalysis>(F);
14833 return ScalarEvolution(F, TLI, AC, DT, LI);
14834}
14835
14841
14844 // For compatibility with opt's -analyze feature under legacy pass manager
14845 // which was not ported to NPM. This keeps tests using
14846 // update_analyze_test_checks.py working.
14847 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14848 << F.getName() << "':\n";
14850 return PreservedAnalyses::all();
14851}
14852
14854 "Scalar Evolution Analysis", false, true)
14860 "Scalar Evolution Analysis", false, true)
14861
14863
14865
14867 SE.reset(new ScalarEvolution(
14869 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14871 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14872 return false;
14873}
14874
14876
14878 SE->print(OS);
14879}
14880
14882 if (!VerifySCEV)
14883 return;
14884
14885 SE->verify();
14886}
14887
14895
14896const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14897 const SCEV *RHS) {
14898 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14899}
14900
14901const SCEVPredicate *
14902ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14903 const SCEV *LHS, const SCEV *RHS) {
14904 FoldingSetNodeID ID;
14905 assert(LHS->getType() == RHS->getType() &&
14906 "Type mismatch between LHS and RHS");
14907 // Unique this node based on the arguments
14908 ID.AddInteger(SCEVPredicate::P_Compare);
14909 ID.AddInteger(Pred);
14910 ID.AddPointer(LHS);
14911 ID.AddPointer(RHS);
14912 void *IP = nullptr;
14913 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14914 return S;
14915 SCEVComparePredicate *Eq = new (SCEVAllocator)
14916 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14917 UniquePreds.InsertNode(Eq, IP);
14918 return Eq;
14919}
14920
14921const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14922 const SCEVAddRecExpr *AR,
14923 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14924 FoldingSetNodeID ID;
14925 // Unique this node based on the arguments
14926 ID.AddInteger(SCEVPredicate::P_Wrap);
14927 ID.AddPointer(AR);
14928 ID.AddInteger(AddedFlags);
14929 void *IP = nullptr;
14930 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14931 return S;
14932 auto *OF = new (SCEVAllocator)
14933 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14934 UniquePreds.InsertNode(OF, IP);
14935 return OF;
14936}
14937
14938namespace {
14939
14940class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14941public:
14942
14943 /// Rewrites \p S in the context of a loop L and the SCEV predication
14944 /// infrastructure.
14945 ///
14946 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14947 /// equivalences present in \p Pred.
14948 ///
14949 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14950 /// \p NewPreds such that the result will be an AddRecExpr.
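  /// For example, given a compare predicate '%n == (zext i8 %m to i32)',
  /// occurrences of the unknown %n in \p S are replaced by the zext
  /// expression, which may allow \p S to be folded into an AddRecExpr.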
14951 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14952 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14953 const SCEVPredicate *Pred) {
14954 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14955 return Rewriter.visit(S);
14956 }
14957
14958 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14959 if (Pred) {
14960 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14961 for (const auto *Pred : U->getPredicates())
14962 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14963 if (IPred->getLHS() == Expr &&
14964 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14965 return IPred->getRHS();
14966 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14967 if (IPred->getLHS() == Expr &&
14968 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14969 return IPred->getRHS();
14970 }
14971 }
14972 return convertToAddRecWithPreds(Expr);
14973 }
14974
14975 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14976 const SCEV *Operand = visit(Expr->getOperand());
14977 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14978 if (AR && AR->getLoop() == L && AR->isAffine()) {
14979 // This couldn't be folded because the operand didn't have the nuw
14980 // flag. Add the nusw flag as an assumption that we could make.
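      // For example, (zext i32 {%a,+,1}<%L> to i64) can be rewritten to
      // {(zext i32 %a to i64),+,(sext i32 1 to i64)}<%L> once the <nusw>
      // assumption has been recorded.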
14981 const SCEV *Step = AR->getStepRecurrence(SE);
14982 Type *Ty = Expr->getType();
14983 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14984 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14985 SE.getSignExtendExpr(Step, Ty), L,
14986 AR->getNoWrapFlags());
14987 }
14988 return SE.getZeroExtendExpr(Operand, Expr->getType());
14989 }
14990
14991 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14992 const SCEV *Operand = visit(Expr->getOperand());
14993 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14994 if (AR && AR->getLoop() == L && AR->isAffine()) {
14995 // This couldn't be folded because the operand didn't have the nsw
14996 // flag. Add the nssw flag as an assumption that we could make.
14997 const SCEV *Step = AR->getStepRecurrence(SE);
14998 Type *Ty = Expr->getType();
14999 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15000 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15001 SE.getSignExtendExpr(Step, Ty), L,
15002 AR->getNoWrapFlags());
15003 }
15004 return SE.getSignExtendExpr(Operand, Expr->getType());
15005 }
15006
15007private:
15008 explicit SCEVPredicateRewriter(
15009 const Loop *L, ScalarEvolution &SE,
15010 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15011 const SCEVPredicate *Pred)
15012 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15013
15014 bool addOverflowAssumption(const SCEVPredicate *P) {
15015 if (!NewPreds) {
15016 // Check if we've already made this assumption.
15017 return Pred && Pred->implies(P, SE);
15018 }
15019 NewPreds->push_back(P);
15020 return true;
15021 }
15022
15023 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15024 SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
15025 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15026 return addOverflowAssumption(A);
15027 }
15028
15029 // If \p Expr represents a PHINode, we try to see if it can be represented
15030 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15031 // to add this predicate as a runtime overflow check, we return the AddRec.
15032 // If \p Expr does not meet these conditions (is not a PHI node, or we
15033 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15034 // return \p Expr.
15035 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15036 if (!isa<PHINode>(Expr->getValue()))
15037 return Expr;
15038 std::optional<
15039 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15040 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15041 if (!PredicatedRewrite)
15042 return Expr;
15043 for (const auto *P : PredicatedRewrite->second){
15044 // Wrap predicates from outer loops are not supported.
15045 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15046 if (L != WP->getExpr()->getLoop())
15047 return Expr;
15048 }
15049 if (!addOverflowAssumption(P))
15050 return Expr;
15051 }
15052 return PredicatedRewrite->first;
15053 }
15054
15055 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15056 const SCEVPredicate *Pred;
15057 const Loop *L;
15058};
15059
15060} // end anonymous namespace
15061
15062const SCEV *
15064 const SCEVPredicate &Preds) {
15065 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15066}
15067
15069 const SCEV *S, const Loop *L,
15072 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15073 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15074
15075 if (!AddRec)
15076 return nullptr;
15077
15078 // Check if any of the transformed predicates is known to be false. In that
15079 // case, it doesn't make sense to convert to a predicated AddRec, as the
15080 // versioned loop will never execute.
15081 for (const SCEVPredicate *Pred : TransformPreds) {
15082 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15083 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15084 continue;
15085
15086 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15087 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15088 if (isa<SCEVCouldNotCompute>(ExitCount))
15089 continue;
15090
15091 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15092 if (!Step->isOne())
15093 continue;
15094
15095 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15096 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15097 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15098 return nullptr;
15099 }
15100
15101 // Since the transformation was successful, we can now transfer the SCEV
15102 // predicates.
15103 Preds.append(TransformPreds.begin(), TransformPreds.end());
15104
15105 return AddRec;
15106}
15107
15108/// SCEV predicates
15112
15114 const ICmpInst::Predicate Pred,
15115 const SCEV *LHS, const SCEV *RHS)
15116 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15117 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15118 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15119}
15120
15122 ScalarEvolution &SE) const {
15123 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15124
15125 if (!Op)
15126 return false;
15127
15128 if (Pred != ICmpInst::ICMP_EQ)
15129 return false;
15130
15131 return Op->LHS == LHS && Op->RHS == RHS;
15132}
15133
15134bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15135
15137 if (Pred == ICmpInst::ICMP_EQ)
15138 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15139 else
15140 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15141 << *RHS << "\n";
15142
15143}
15144
15146 const SCEVAddRecExpr *AR,
15147 IncrementWrapFlags Flags)
15148 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15149
15150const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15151
15153 ScalarEvolution &SE) const {
15154 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15155 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15156 return false;
15157
15158 if (Op->AR == AR)
15159 return true;
15160
15161 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15162 Flags != SCEVWrapPredicate::IncrementNUSW)
15163 return false;
15164
15165 const SCEV *Start = AR->getStart();
15166 const SCEV *OpStart = Op->AR->getStart();
15167 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15168 return false;
15169
15170 // Reject pointers to different address spaces.
15171 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15172 return false;
15173
15174 const SCEV *Step = AR->getStepRecurrence(SE);
15175 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15176 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15177 return false;
15178
15179 // If both steps are positive, this wrap predicate implies N if N's start
15180 // and step are ULE/SLE (for NUSW/NSSW) compared to this predicate's.
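  // For example, a <nusw> predicate on {8,+,4}<%L> implies one on {4,+,2}<%L>:
  // both steps are known positive and the start and step of the implied
  // predicate are unsigned-<= those of this one.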
15181 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15182 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15183 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15184
15185 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15186 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15187 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15188 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15189 : SE.getNoopOrSignExtend(Start, WiderTy);
15190 ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15191 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15192 SE.isKnownPredicate(Pred, OpStart, Start);
15193}
15194
15196 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15197 IncrementWrapFlags IFlags = Flags;
15198
15199 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15200 IFlags = clearFlags(IFlags, IncrementNSSW);
15201
15202 return IFlags == IncrementAnyWrap;
15203}
15204
15205void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15206 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15207 if (SCEVWrapPredicate::IncrementNUSW & getFlags())
15208 OS << "<nusw>";
15209 if (SCEVWrapPredicate::IncrementNSSW & getFlags())
15210 OS << "<nssw>";
15211 OS << "\n";
15212}
15213
15216 ScalarEvolution &SE) {
15217 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15218 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15219
15220 // We can safely transfer the NSW flag as NSSW.
15221 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15222 ImpliedFlags = IncrementNSSW;
15223
15224 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15225 // If the increment is positive, the SCEV NUW flag will also imply the
15226 // WrapPredicate NUSW flag.
15227 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15228 if (Step->getValue()->getValue().isNonNegative())
15229 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15230 }
15231
15232 return ImpliedFlags;
15233}
15234
15235/// Union predicates don't get cached, so create a dummy set ID for it.
15237 ScalarEvolution &SE)
15238 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15239 for (const auto *P : Preds)
15240 add(P, SE);
15241}
15242
15244 return all_of(Preds,
15245 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15246}
15247
15249 ScalarEvolution &SE) const {
15250 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15251 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15252 return this->implies(I, SE);
15253 });
15254
15255 return any_of(Preds,
15256 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15257}
15258
15260 for (const auto *Pred : Preds)
15261 Pred->print(OS, Depth);
15262}
15263
15264void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15265 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15266 for (const auto *Pred : Set->Preds)
15267 add(Pred, SE);
15268 return;
15269 }
15270
15271 // Implication checks are quadratic in the number of predicates. Stop doing
15272 // them if there are many predicates, as they should be too expensive to use
15273 // anyway at that point.
15274 bool CheckImplies = Preds.size() < 16;
15275
15276 // Only add predicate if it is not already implied by this union predicate.
15277 if (CheckImplies && implies(N, SE))
15278 return;
15279
15280 // Build a new vector containing the current predicates, except the ones that
15281 // are implied by the new predicate N.
15283 for (auto *P : Preds) {
15284 if (CheckImplies && N->implies(P, SE))
15285 continue;
15286 PrunedPreds.push_back(P);
15287 }
15288 Preds = std::move(PrunedPreds);
15289 Preds.push_back(N);
15290}
15291
15293 Loop &L)
15294 : SE(SE), L(L) {
15296 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15297}
15298
15301 for (const auto *Op : Ops)
15302 // We do not expect that forgetting cached data for SCEVConstants will ever
15303 // open any prospects for sharpening or introduce any correctness issues,
15304 // so we don't bother storing their dependencies.
15305 if (!isa<SCEVConstant>(Op))
15306 SCEVUsers[Op].insert(User);
15307}
15308
15310 const SCEV *Expr = SE.getSCEV(V);
15311 return getPredicatedSCEV(Expr);
15312}
15313
15315 RewriteEntry &Entry = RewriteMap[Expr];
15316
15317 // If we already have an entry and the version matches, return it.
15318 if (Entry.second && Generation == Entry.first)
15319 return Entry.second;
15320
15321 // We found an entry but it's stale. Rewrite the stale entry
15322 // according to the current predicate.
15323 if (Entry.second)
15324 Expr = Entry.second;
15325
15326 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15327 Entry = {Generation, NewSCEV};
15328
15329 return NewSCEV;
15330}
15331
15333 if (!BackedgeCount) {
15335 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15336 for (const auto *P : Preds)
15337 addPredicate(*P);
15338 }
15339 return BackedgeCount;
15340}
15341
15343 if (!SymbolicMaxBackedgeCount) {
15345 SymbolicMaxBackedgeCount =
15346 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15347 for (const auto *P : Preds)
15348 addPredicate(*P);
15349 }
15350 return SymbolicMaxBackedgeCount;
15351}
15352
15354 if (!SmallConstantMaxTripCount) {
15356 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15357 for (const auto *P : Preds)
15358 addPredicate(*P);
15359 }
15360 return *SmallConstantMaxTripCount;
15361}
15362
15364 if (Preds->implies(&Pred, SE))
15365 return;
15366
15367 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15368 NewPreds.push_back(&Pred);
15369 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15370 updateGeneration();
15371}
15372
15374 return *Preds;
15375}
15376
15377void PredicatedScalarEvolution::updateGeneration() {
15378 // If the generation number wrapped, recompute everything.
15379 if (++Generation == 0) {
15380 for (auto &II : RewriteMap) {
15381 const SCEV *Rewritten = II.second.second;
15382 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15383 }
15384 }
15385}
15386
15389 const SCEV *Expr = getSCEV(V);
15390 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15391
15392 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15393
15394 // Clear the statically implied flags.
15395 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15396 addPredicate(*SE.getWrapPredicate(AR, Flags));
15397
15398 auto II = FlagsMap.insert({V, Flags});
15399 if (!II.second)
15400 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15401}
15402
15405 const SCEV *Expr = getSCEV(V);
15406 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15407
15409 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15410
15411 auto II = FlagsMap.find(V);
15412
15413 if (II != FlagsMap.end())
15414 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15415
15417}
15418
15420 const SCEV *Expr = this->getSCEV(V);
15422 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15423
15424 if (!New)
15425 return nullptr;
15426
15427 for (const auto *P : NewPreds)
15428 addPredicate(*P);
15429
15430 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15431 return New;
15432}
15433
15436 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15437 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15438 SE)),
15439 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15440 for (auto I : Init.FlagsMap)
15441 FlagsMap.insert(I);
15442}
15443
15445 // For each block.
15446 for (auto *BB : L.getBlocks())
15447 for (auto &I : *BB) {
15448 if (!SE.isSCEVable(I.getType()))
15449 continue;
15450
15451 auto *Expr = SE.getSCEV(&I);
15452 auto II = RewriteMap.find(Expr);
15453
15454 if (II == RewriteMap.end())
15455 continue;
15456
15457 // Don't print things that are not interesting.
15458 if (II->second.second == Expr)
15459 continue;
15460
15461 OS.indent(Depth) << "[PSE]" << I << ":\n";
15462 OS.indent(Depth + 2) << *Expr << "\n";
15463 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15464 }
15465}
15466
15469 BasicBlock *Header = L->getHeader();
15470 BasicBlock *Pred = L->getLoopPredecessor();
15471 LoopGuards Guards(SE);
15472 if (!Pred)
15473 return Guards;
15475 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15476 return Guards;
15477}
15478
15479void ScalarEvolution::LoopGuards::collectFromPHI(
15483 unsigned Depth) {
15484 if (!SE.isSCEVable(Phi.getType()))
15485 return;
15486
15487 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15488 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15489 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15490 if (!VisitedBlocks.insert(InBlock).second)
15491 return {nullptr, scCouldNotCompute};
15492
15493 // Avoid analyzing unreachable blocks so that we don't get trapped
15494 // traversing cycles with ill-formed dominance or infinite cycles.
15495 if (!SE.DT.isReachableFromEntry(InBlock))
15496 return {nullptr, scCouldNotCompute};
15497
15498 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15499 if (Inserted)
15500 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15501 Depth + 1);
15502 auto &RewriteMap = G->second.RewriteMap;
15503 if (RewriteMap.empty())
15504 return {nullptr, scCouldNotCompute};
15505 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15506 if (S == RewriteMap.end())
15507 return {nullptr, scCouldNotCompute};
15508 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15509 if (!SM)
15510 return {nullptr, scCouldNotCompute};
15511 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15512 return {C0, SM->getSCEVType()};
15513 return {nullptr, scCouldNotCompute};
15514 };
15515 auto MergeMinMaxConst = [](MinMaxPattern P1,
15516 MinMaxPattern P2) -> MinMaxPattern {
15517 auto [C1, T1] = P1;
15518 auto [C2, T2] = P2;
15519 if (!C1 || !C2 || T1 != T2)
15520 return {nullptr, scCouldNotCompute};
15521 switch (T1) {
15522 case scUMaxExpr:
15523 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15524 case scSMaxExpr:
15525 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15526 case scUMinExpr:
15527 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15528 case scSMinExpr:
15529 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15530 default:
15531 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15532 }
15533 };
15534 auto P = GetMinMaxConst(0);
15535 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15536 if (!P.first)
15537 break;
15538 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15539 }
15540 if (P.first) {
15541 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15543 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15544 Guards.RewriteMap.insert({LHS, RHS});
15545 }
15546}
15547
15548 // Return a new SCEV that modifies \p Expr to the closest value that is
15549 // divisible by \p Divisor and less than or equal to \p Expr. For now, only
15550 // constant \p Expr is handled.
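// For example, Expr = 30 and Divisor = 8 yields 24.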
15551static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15552 const APInt &DivisorVal,
15553 ScalarEvolution &SE) {
15554 const APInt *ExprVal;
15555 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15556 DivisorVal.isNonPositive())
15557 return Expr;
15558 APInt Rem = ExprVal->urem(DivisorVal);
15559 // return the SCEV: Expr - Expr % Divisor
15560 return SE.getConstant(*ExprVal - Rem);
15561}
15562
15563 // Return a new SCEV that modifies \p Expr to the closest value that is
15564 // divisible by \p Divisor and greater than or equal to \p Expr. For now,
15565 // only constant \p Expr is handled.
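// For example, Expr = 30 and Divisor = 8 yields 32.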
15566static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15567 const APInt &DivisorVal,
15568 ScalarEvolution &SE) {
15569 const APInt *ExprVal;
15570 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15571 DivisorVal.isNonPositive())
15572 return Expr;
15573 APInt Rem = ExprVal->urem(DivisorVal);
15574 if (Rem.isZero())
15575 return Expr;
15576 // return the SCEV: Expr + Divisor - Expr % Divisor
15577 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15578}
15579
15581 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15584 // If we have LHS == 0, check if LHS is computing a property of some unknown
15585 // SCEV %v which we can rewrite %v to express explicitly.
15587 return false;
15588 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15589 // explicitly express that.
15590 const SCEVUnknown *URemLHS = nullptr;
15591 const SCEV *URemRHS = nullptr;
15592 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15593 return false;
15594
15595 const SCEV *Multiple =
15596 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15597 DivInfo[URemLHS] = Multiple;
15598 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15599 Multiples[URemLHS] = C->getAPInt();
15600 return true;
15601}
15602
15603// Check if the condition is a divisibility guard (A % B == 0).
15604static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15605 ScalarEvolution &SE) {
15606 const SCEV *X, *Y;
15607 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15608}
15609
15610// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15611// recursively. This is done by aligning up/down the constant value to the
15612// Divisor.
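// For example, with Divisor = 8, umin(30, %x) becomes umin(24, %x): the
// constant operand is aligned down for min expressions and up for max
// expressions.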
15613static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15614 APInt Divisor,
15615 ScalarEvolution &SE) {
15616 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15617 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15618 // the non-constant operand and in \p LHS the constant operand.
15619 auto IsMinMaxSCEVWithNonNegativeConstant =
15620 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15621 const SCEV *&RHS) {
15622 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15623 if (MinMax->getNumOperands() != 2)
15624 return false;
15625 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15626 if (C->getAPInt().isNegative())
15627 return false;
15628 SCTy = MinMax->getSCEVType();
15629 LHS = MinMax->getOperand(0);
15630 RHS = MinMax->getOperand(1);
15631 return true;
15632 }
15633 }
15634 return false;
15635 };
15636
15637 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15638 SCEVTypes SCTy;
15639 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15640 MinMaxRHS))
15641 return MinMaxExpr;
15642 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15643 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15644 auto *DivisibleExpr =
15645 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15646 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15648 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15649 return SE.getMinMaxExpr(SCTy, Ops);
15650}
15651
15652void ScalarEvolution::LoopGuards::collectFromBlock(
15653 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15654 const BasicBlock *Block, const BasicBlock *Pred,
15655 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15656
15658
15659 SmallVector<const SCEV *> ExprsToRewrite;
15660 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15661 const SCEV *RHS,
15662 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15663 const LoopGuards &DivGuards) {
15664 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15665 // replacement SCEV which isn't directly implied by the structure of that
15666 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15667 // legal. See the scoping rules for flags in the header to understand why.
15668
15669 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15670 // create this form when combining two checks of the form (X u< C2 + C1) and
15671 // (X >=u C1).
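    // For example, the condition (-5 + %x) u< 10 encodes 5 u<= %x u< 15, so
    // %x is rewritten to (5 umax (%x umin 14)).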
15672 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15673 &ExprsToRewrite]() {
15674 const SCEVConstant *C1;
15675 const SCEVUnknown *LHSUnknown;
15676 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15677 if (!match(LHS,
15678 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15679 !C2)
15680 return false;
15681
15682 auto ExactRegion =
15683 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15684 .sub(C1->getAPInt());
15685
15686 // Bail out, unless we have a non-wrapping, monotonic range.
15687 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15688 return false;
15689 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15690 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15691 I->second = SE.getUMaxExpr(
15692 SE.getConstant(ExactRegion.getUnsignedMin()),
15693 SE.getUMinExpr(RewrittenLHS,
15694 SE.getConstant(ExactRegion.getUnsignedMax())));
15695 ExprsToRewrite.push_back(LHSUnknown);
15696 return true;
15697 };
15698 if (MatchRangeCheckIdiom())
15699 return;
15700
15701 // Do not apply information for constants or if RHS contains an AddRec.
15702    if (isa<SCEVConstant>(RHS) || SE.containsAddRecurrence(RHS))
15703      return;
15704
15705 // If RHS is SCEVUnknown, make sure the information is applied to it.
15706    if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15707      std::swap(LHS, RHS);
15708      Predicate = CmpInst::getSwappedPredicate(Predicate);
15709    }
15710
15711 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15712 // and \p FromRewritten are the same (i.e. there has been no rewrite
15713 // registered for \p From), then puts this value in the list of rewritten
15714 // expressions.
15715 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15716 const SCEV *To) {
15717 if (From == FromRewritten)
15718 ExprsToRewrite.push_back(From);
15719 RewriteMap[From] = To;
15720 };
15721
15722 // Checks whether \p S has already been rewritten. In that case returns the
15723 // existing rewrite because we want to chain further rewrites onto the
15724 // already rewritten value. Otherwise returns \p S.
15725 auto GetMaybeRewritten = [&](const SCEV *S) {
15726 return RewriteMap.lookup_or(S, S);
15727 };
15728
15729 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15730 // Apply divisibility information when computing the constant multiple.
15731 const APInt &DividesBy =
15732 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
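    // For example, if divisibility guards established that LHS is a multiple
    // of 4, DividesBy is (a multiple of) 4 and the RHS adjustments below can
    // snap the bound to a multiple of it.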
15733
15734 // Collect rewrites for LHS and its transitive operands based on the
15735 // condition.
15736 // For min/max expressions, also apply the guard to its operands:
15737 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15738 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15739 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15740 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15741
15742 // We cannot express strict predicates in SCEV, so instead we replace them
15743 // with non-strict ones against plus or minus one of RHS depending on the
15744 // predicate.
15745 const SCEV *One = SE.getOne(RHS->getType());
15746 switch (Predicate) {
15747 case CmpInst::ICMP_ULT:
15748 if (RHS->getType()->isPointerTy())
15749 return;
15750 RHS = SE.getUMaxExpr(RHS, One);
15751 [[fallthrough]];
15752 case CmpInst::ICMP_SLT: {
15753 RHS = SE.getMinusSCEV(RHS, One);
15754 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15755 break;
15756 }
15757 case CmpInst::ICMP_UGT:
15758 case CmpInst::ICMP_SGT:
15759 RHS = SE.getAddExpr(RHS, One);
15760 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15761 break;
15762 case CmpInst::ICMP_ULE:
15763 case CmpInst::ICMP_SLE:
15764 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15765 break;
15766 case CmpInst::ICMP_UGE:
15767 case CmpInst::ICMP_SGE:
15768 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15769 break;
15770 default:
15771 break;
15772 }
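    // For example, a guard X u< 16 on an X known to be a multiple of 4 turns
    // into RHS = umax(16, 1) - 1 = 15, rounded down to 12, so X is later
    // clamped with umin(X, 12).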
15773
15774    SmallVector<const SCEV *, 16> Worklist(1, LHS);
15775    SmallPtrSet<const SCEV *, 16> Visited;
15776
15777 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15778 append_range(Worklist, S->operands());
15779 };
15780
15781 while (!Worklist.empty()) {
15782 const SCEV *From = Worklist.pop_back_val();
15783 if (isa<SCEVConstant>(From))
15784 continue;
15785 if (!Visited.insert(From).second)
15786 continue;
15787 const SCEV *FromRewritten = GetMaybeRewritten(From);
15788 const SCEV *To = nullptr;
15789
15790 switch (Predicate) {
15791 case CmpInst::ICMP_ULT:
15792 case CmpInst::ICMP_ULE:
15793 To = SE.getUMinExpr(FromRewritten, RHS);
15794 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15795 EnqueueOperands(UMax);
15796 break;
15797 case CmpInst::ICMP_SLT:
15798 case CmpInst::ICMP_SLE:
15799 To = SE.getSMinExpr(FromRewritten, RHS);
15800 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15801 EnqueueOperands(SMax);
15802 break;
15803 case CmpInst::ICMP_UGT:
15804 case CmpInst::ICMP_UGE:
15805 To = SE.getUMaxExpr(FromRewritten, RHS);
15806 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15807 EnqueueOperands(UMin);
15808 break;
15809 case CmpInst::ICMP_SGT:
15810 case CmpInst::ICMP_SGE:
15811 To = SE.getSMaxExpr(FromRewritten, RHS);
15812 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15813 EnqueueOperands(SMin);
15814 break;
15815 case CmpInst::ICMP_EQ:
15816        if (isa<SCEVConstant>(RHS))
15817          To = RHS;
15818 break;
15819 case CmpInst::ICMP_NE:
15820 if (match(RHS, m_scev_Zero())) {
15821 const SCEV *OneAlignedUp =
15822 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15823 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15824 } else {
15825 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15826 // but creating the subtraction eagerly is expensive. Track the
15827 // inequalities in a separate map, and materialize the rewrite lazily
15828 // when encountering a suitable subtraction while re-writing.
15829 if (LHS->getType()->isPointerTy()) {
15830            LHS = SE.getLosslessPtrToIntExpr(LHS);
15831            RHS = SE.getLosslessPtrToIntExpr(RHS);
15832            if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
15833              break;
15834 }
15835 const SCEVConstant *C;
15836 const SCEV *A, *B;
15837          if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15838              match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15839            RHS = A;
15840 LHS = B;
15841 }
15842 if (LHS > RHS)
15843 std::swap(LHS, RHS);
15844 Guards.NotEqual.insert({LHS, RHS});
15845 continue;
15846 }
15847 break;
15848 default:
15849 break;
15850 }
15851
15852 if (To)
15853 AddRewrite(From, FromRewritten, To);
15854 }
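    // For example, with no divisibility information, the guard
    // umin(A, B) u>= 16 adds the rewrites umin(A, B) -> umax(umin(A, B), 16),
    // A -> umax(A, 16) and B -> umax(B, 16).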
15855 };
15856
15857  SmallVector<std::pair<Value *, bool>> Terms;
15858  // First, collect information from assumptions dominating the loop.
15859 for (auto &AssumeVH : SE.AC.assumptions()) {
15860 if (!AssumeVH)
15861 continue;
15862 auto *AssumeI = cast<CallInst>(AssumeVH);
15863 if (!SE.DT.dominates(AssumeI, Block))
15864 continue;
15865 Terms.emplace_back(AssumeI->getOperand(0), true);
15866 }
15867
15868 // Second, collect information from llvm.experimental.guards dominating the loop.
15869 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15870 SE.F.getParent(), Intrinsic::experimental_guard);
15871 if (GuardDecl)
15872 for (const auto *GU : GuardDecl->users())
15873 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15874 if (Guard->getFunction() == Block->getParent() &&
15875 SE.DT.dominates(Guard, Block))
15876 Terms.emplace_back(Guard->getArgOperand(0), true);
15877
15878 // Third, collect conditions from dominating branches. Starting at the loop
15879  // predecessor, climb up the predecessor chain as long as we can find
15880  // predecessors that have a unique successor leading to the original
15881  // header.
15882 // TODO: share this logic with isLoopEntryGuardedByCond.
15883 unsigned NumCollectedConditions = 0;
15885 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15886 for (; Pair.first;
15887 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15888 VisitedBlocks.insert(Pair.second);
15889 const BranchInst *LoopEntryPredicate =
15890 dyn_cast<BranchInst>(Pair.first->getTerminator());
15891 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15892 continue;
15893
15894 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15895 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15896 NumCollectedConditions++;
15897
15898    // If we are recursively collecting guards, stop after 2
15899 // conditions to limit compile-time impact for now.
15900 if (Depth > 0 && NumCollectedConditions == 2)
15901 break;
15902 }
15903  // Finally, if we stopped climbing the predecessor chain because
15904  // there wasn't a unique predecessor to continue with, try to collect
15905  // conditions for the PHI nodes in the block by recursively following all
15906  // of their incoming blocks, merging the conditions found there into a new
15907  // condition for the Phi.
15908 if (Pair.second->hasNPredecessorsOrMore(2) &&
15909      Depth < MaxLoopGuardCollectionDepth) {
15910    SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15911 for (auto &Phi : Pair.second->phis())
15912 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15913 }
15914
15915 // Now apply the information from the collected conditions to
15916 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15917  // earliest condition is processed first, except guards with divisibility
15918 // information, which are moved to the back. This ensures the SCEVs with the
15919 // shortest dependency chains are constructed first.
15920  SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15921      GuardsToProcess;
15922 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15923 SmallVector<Value *, 8> Worklist;
15924 SmallPtrSet<Value *, 8> Visited;
15925 Worklist.push_back(Term);
15926 while (!Worklist.empty()) {
15927 Value *Cond = Worklist.pop_back_val();
15928 if (!Visited.insert(Cond).second)
15929 continue;
15930
15931 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15932 auto Predicate =
15933 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15934 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15935 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15936 // If LHS is a constant, apply information to the other expression.
15937 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15938 // can improve results.
15939 if (isa<SCEVConstant>(LHS)) {
15940 std::swap(LHS, RHS);
15941          Predicate = CmpInst::getSwappedPredicate(Predicate);
15942        }
15943 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15944 continue;
15945 }
15946
15947 Value *L, *R;
15948 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15949 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15950 Worklist.push_back(L);
15951 Worklist.push_back(R);
15952 }
15953 }
15954 }
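  // For example, entering the loop via a branch on (%a u< %n) && (%n != 0)
  // queues both compares with their original predicates; when the loop is
  // only reached on the false edge of a branch, the inverse predicates are
  // used and the condition is decomposed through the logical-or form instead.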
15955
15956 // Process divisibility guards in reverse order to populate DivGuards early.
15957 DenseMap<const SCEV *, APInt> Multiples;
15958 LoopGuards DivGuards(SE);
15959 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15960 if (!isDivisibilityGuard(LHS, RHS, SE))
15961 continue;
15962 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15963 Multiples, SE);
15964 }
15965
15966 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15967 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15968
15969 // Apply divisibility information last. This ensures it is applied to the
15970 // outermost expression after other rewrites for the given value.
15971 for (const auto &[K, Divisor] : Multiples) {
15972 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15973 Guards.RewriteMap[K] =
15974        SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15975                                         Guards.rewrite(K), Divisor, SE),
15976 DivisorSCEV),
15977 DivisorSCEV);
15978 ExprsToRewrite.push_back(K);
15979 }
15980
15981 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15982 // the replacement expressions are contained in the ranges of the replaced
15983 // expressions.
15984 Guards.PreserveNUW = true;
15985 Guards.PreserveNSW = true;
15986 for (const SCEV *Expr : ExprsToRewrite) {
15987 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15988 Guards.PreserveNUW &=
15989 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15990 Guards.PreserveNSW &=
15991 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15992 }
15993
15994  // Now that all rewrite information is collected, rewrite the collected
15995 // expressions with the information in the map. This applies information to
15996 // sub-expressions.
15997 if (ExprsToRewrite.size() > 1) {
15998 for (const SCEV *Expr : ExprsToRewrite) {
15999 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16000 Guards.RewriteMap.erase(Expr);
16001 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16002 }
16003 }
16004}
16005
16006const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
16007  /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16008 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16009 /// replacement is loop invariant in the loop of the AddRec.
16010 class SCEVLoopGuardRewriter
16011 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16012    const DenseMap<const SCEV *, const SCEV *> &Map;
16013    const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
16014
16015    SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
16016
16017 public:
16018 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16019 const ScalarEvolution::LoopGuards &Guards)
16020 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16021 NotEqual(Guards.NotEqual) {
16022 if (Guards.PreserveNUW)
16023 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16024 if (Guards.PreserveNSW)
16025 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16026 }
16027
16028 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16029
16030 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16031 return Map.lookup_or(Expr, Expr);
16032 }
16033
16034 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16035 if (const SCEV *S = Map.lookup(Expr))
16036 return S;
16037
16038      // If we didn't find the exact ZExt expr in the map, check if there's
16039 // an entry for a smaller ZExt we can use instead.
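      // For example, a lookup of (zext i8 %x to i64) can be served by a map
      // entry for (zext i8 %x to i32) by zero-extending that entry to i64.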
16040 Type *Ty = Expr->getType();
16041 const SCEV *Op = Expr->getOperand(0);
16042 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16043 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16044 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16045 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16046 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16047 if (const SCEV *S = Map.lookup(NarrowExt))
16048 return SE.getZeroExtendExpr(S, Ty);
16049 Bitwidth = Bitwidth / 2;
16050 }
16051
16052      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16053          Expr);
16054 }
16055
16056 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16057 if (const SCEV *S = Map.lookup(Expr))
16058 return S;
16059      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16060          Expr);
16061 }
16062
16063 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16064 if (const SCEV *S = Map.lookup(Expr))
16065 return S;
16066      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16067    }
16068
16069 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16070 if (const SCEV *S = Map.lookup(Expr))
16071 return S;
16072      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16073    }
16074
16075 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16076 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16077 // return UMax(S, 1).
16078 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16079 const SCEV *LHS, *RHS;
16080 if (MatchBinarySub(S, LHS, RHS)) {
16081 if (LHS > RHS)
16082 std::swap(LHS, RHS);
16083 if (NotEqual.contains({LHS, RHS})) {
16084 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16085 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16086 return SE.getUMaxExpr(OneAlignedUp, S);
16087 }
16088 }
16089 return nullptr;
16090 };
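    // For example, if the collected guards recorded A != B and (A - B) is
    // known to be a multiple of 4, the subtraction (A - B) is rewritten to
    // umax(4, A - B).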
16091
16092 // Check if Expr itself is a subtraction pattern with guard info.
16093 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16094 return Rewritten;
16095
16096 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16097 // (Const + A + B). There may be guard info for A + B, and if so, apply
16098 // it.
16099 // TODO: Could more generally apply guards to Add sub-expressions.
16100 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16101 Expr->getNumOperands() == 3) {
16102 const SCEV *Add =
16103 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16104 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16105 return SE.getAddExpr(
16106 Expr->getOperand(0), Rewritten,
16107 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16108 if (const SCEV *S = Map.lookup(Add))
16109 return SE.getAddExpr(Expr->getOperand(0), S);
16110 }
16111    SmallVector<const SCEV *, 2> Operands;
16112    bool Changed = false;
16113 for (const auto *Op : Expr->operands()) {
16114 Operands.push_back(
16115          SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16116      Changed |= Op != Operands.back();
16117 }
16118 // We are only replacing operands with equivalent values, so transfer the
16119 // flags from the original expression.
16120 return !Changed ? Expr
16121 : SE.getAddExpr(Operands,
16122                                         ScalarEvolution::maskFlags(
16123                                             Expr->getNoWrapFlags(), FlagMask));
16124 }
16125
16126 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16127    SmallVector<const SCEV *, 2> Operands;
16128    bool Changed = false;
16129 for (const auto *Op : Expr->operands()) {
16130 Operands.push_back(
16131          SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16132      Changed |= Op != Operands.back();
16133 }
16134 // We are only replacing operands with equivalent values, so transfer the
16135 // flags from the original expression.
16136 return !Changed ? Expr
16137 : SE.getMulExpr(Operands,
16138                                         ScalarEvolution::maskFlags(
16139                                             Expr->getNoWrapFlags(), FlagMask));
16140 }
16141 };
16142
16143 if (RewriteMap.empty() && NotEqual.empty())
16144 return Expr;
16145
16146 SCEVLoopGuardRewriter Rewriter(SE, *this);
16147 return Rewriter.visit(Expr);
16148}
16149
16150const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16151 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16152}
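// For example, if %n is only known to be non-zero from a dominating guard
// (%n u> 0), applyLoopGuards(%n, L) returns umax(%n, 1).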
16153
16154const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16155                                             const LoopGuards &Guards) {
16156 return Guards.rewrite(Expr);
16157}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1982
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1023
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1549
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1400
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1521
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1804
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1202
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1183
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1677
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1497
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1112
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1167
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1656
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1770
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:828
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1285
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1151
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:996
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:874
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1131
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1222
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:470
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:493
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1284
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:771
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:283
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:164
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
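A minimal sketch of the add-recurrence queries above; the helper name, the iteration number 7, and the getConstant(Type *, uint64_t) convenience overload are assumptions of the example:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static void inspectAddRec(const SCEVAddRecExpr *AR, ScalarEvolution &SE) {
  if (!AR->isAffine())
    return;                                          // only handle {Start,+,Step}
  const SCEV *Step = AR->getStepRecurrence(SE);      // loop-invariant stride
  const SCEV *It = SE.getConstant(AR->getType(), 7); // iteration number 7
  const SCEV *ValAt7 = AR->evaluateAtIteration(It, SE);
  const SCEVAddRecExpr *PostInc = AR->getPostIncExpr(SE); // value one iteration later
  (void)Step; (void)ValAt7; (void)PostInc;
}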
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified SCEV is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
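A short sketch of the basic SCEV queries listed above (the helper name is illustrative):

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static void describe(const SCEV *S) {
  if (S->isZero() || S->isOne())
    return;                          // trivially constant
  SCEVTypes Kind = S->getSCEVType(); // discriminates the concrete subclass
  Type *Ty = S->getType();           // integer or pointer type of the expression
  for (const SCEV *Op : S->operands())
    Op->dump();                      // debugging aid only
  (void)Kind; (void)Ty;
}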
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
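A sketch of the usual trip-count query sequence, assuming SE and a Loop *L (the helper name is illustrative):

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void queryTripCount(ScalarEvolution &SE, const Loop *L) {
  const SCEV *BTC = SE.getBackedgeTakenCount(L);
  if (isa<SCEVCouldNotCompute>(BTC))
    return;                                        // loop has no analyzable count
  // Exact trip count = backedge-taken count + 1, evaluated in a type wide
  // enough not to overflow.
  const SCEV *Trip = SE.getTripCountFromExitCount(BTC);
  unsigned Small = SE.getSmallConstantTripCount(L); // 0 if unknown or not small
  (void)Trip; (void)Small;
}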
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Returns true if the operation BinOp between LHS and RHS provably does not have signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
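A sketch of building a canonical induction variable with getAddRecExpr, assuming SE, a Loop *L, and an integer Type *Ty; the helper name is illustrative:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static const SCEV *makeCanonicalIV(ScalarEvolution &SE, const Loop *L, Type *Ty) {
  const SCEV *Start = SE.getZero(Ty);
  const SCEV *Step = SE.getOne(Ty);
  // {0,+,1}<nuw><nsw> on L; setFlags is the convenience combiner listed below.
  return SE.getAddRecExpr(Start, Step, L,
                          ScalarEvolution::setFlags(SCEV::FlagNUW, SCEV::FlagNSW));
}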
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
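The three dispositions above are what getLoopDisposition (listed earlier) reports; a sketch, with an illustrative helper name:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/ErrorHandling.h"
using namespace llvm;

static bool hoistableOutOf(ScalarEvolution &SE, const SCEV *S, const Loop *L) {
  switch (SE.getLoopDisposition(S, L)) {
  case ScalarEvolution::LoopInvariant:  return true;  // value never changes in L
  case ScalarEvolution::LoopComputable: return false; // varies predictably, e.g. an addrec on L
  case ScalarEvolution::LoopVariant:    return false; // no useful relationship known
  }
  llvm_unreachable("covered switch");
}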
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
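A sketch of the range queries above; the min/max forms are shorthands for querying the corresponding ConstantRange:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void showRanges(ScalarEvolution &SE, const SCEV *S) {
  ConstantRange UR = SE.getUnsignedRange(S); // conservative unsigned value range
  APInt UMin = SE.getUnsignedRangeMin(S);    // smallest value in that range
  APInt UMax = SE.getUnsignedRangeMax(S);    // largest value in that range
  APInt SMax = SE.getSignedRangeMax(S);      // signed counterpart
  (void)UR; (void)UMin; (void)UMax; (void)SMax;
}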
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
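A sketch contrasting the one-shot applyLoopGuards with the reusable LoopGuards collect/rewrite form listed earlier, assuming SE, a Loop *L, and an expression Expr from that loop:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void refineWithGuards(ScalarEvolution &SE, const Loop *L, const SCEV *Expr) {
  // One-shot: collect the guards for L and rewrite Expr in a single call.
  const SCEV *Once = SE.applyLoopGuards(Expr, L);
  // Reusable: collect once, then rewrite many expressions for the same loop.
  ScalarEvolution::LoopGuards Guards = ScalarEvolution::LoopGuards::collect(L, SE);
  const SCEV *Again = Guards.rewrite(Expr);
  (void)Once; (void)Again;
}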
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
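A sketch of the canonical expression builders, assuming SE and two SCEVs A and B of the same integer type; the two-operand getAddExpr and the getConstant(Type *, uint64_t) forms used here are convenience overloads assumed by the example:

#include "llvm/Analysis/ScalarEvolution.h"
using namespace llvm;

static void buildExprs(ScalarEvolution &SE, const SCEV *A, const SCEV *B) {
  const SCEV *Sum = SE.getAddExpr(A, B);    // canonical A + B; may fold to a constant
  const SCEV *Diff = SE.getMinusSCEV(A, B); // A - B, built as A + (-1 * B)
  const SCEV *Half = SE.getUDivExpr(Sum, SE.getConstant(A->getType(), 2));
  (void)Diff; (void)Half;
}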
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
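A sketch of the predicate queries, assuming SE and two SCEVs A and B of the same integer type, and assuming ICmpInst predicates convert implicitly to CmpPredicate at these call sites:

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/IR/Instructions.h"
#include <optional>
using namespace llvm;

static void comparisons(ScalarEvolution &SE, const SCEV *A, const SCEV *B) {
  if (SE.isKnownPredicate(ICmpInst::ICMP_ULT, A, B)) {
    // A < B (unsigned) holds wherever both expressions are defined.
  }
  // Tri-state form: true, false, or std::nullopt when neither can be proven.
  std::optional<bool> SLT = SE.evaluatePredicate(ICmpInst::ICMP_SLT, A, B);
  (void)SLT;
}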
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:723
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:754
TypeSize getSizeInBits() const
Definition DataLayout.h:734
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:259
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2257
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2262
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2267
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2823
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2272
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
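A small sketch of the APIntOps helpers referenced above (the values are chosen arbitrarily):

#include "llvm/ADT/APInt.h"
using namespace llvm;

static void apintHelpers() {
  APInt A(32, 12), B(32, 30);
  const APInt &SMin = APIntOps::smin(A, B);        // 12 when compared as signed
  const APInt &UMax = APIntOps::umax(A, B);        // 30 when compared as unsigned
  APInt G = APIntOps::GreatestCommonDivisor(A, B); // gcd(12, 30) == 6
  (void)SMin; (void)UMax; (void)G;
}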
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
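The m_* entries above come from the IR pattern-matching DSL; a sketch, assuming a Value *V and the llvm::PatternMatch namespace in scope:

#include "llvm/IR/PatternMatch.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static bool isShlByConstant(Value *V) {
  Value *X;
  const APInt *C;
  // Matches "X << C" with a constant (or splat-constant) shift amount,
  // binding both operands on success.
  return match(V, m_Shl(m_Value(X), m_APInt(C)));
}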
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
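The m_scev_* / m_SCEV entries are the SCEV counterpart of that DSL; a sketch, assuming the matchers from llvm/Analysis/ScalarEvolutionPatternMatch.h and their namespace are in scope:

#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
using namespace llvm;
using namespace llvm::SCEVPatternMatch;

static bool isAffineWithConstantStep(const SCEV *S) {
  const APInt *StepC;
  // {Start,+,StepC} on any loop, where the step is an SCEV constant;
  // the start operand is matched but not bound here.
  return match(S, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(StepC)));
}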
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2106
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2198
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2101
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2190
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1744
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:361
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2002
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2078
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1915
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2009
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1945
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
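A sketch of SCEVExprContains from the listing above (containsAddRecurrence, listed earlier, is the canned form of this particular query):

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

static bool hasAddRec(const SCEV *Root) {
  // Walks the whole expression tree and stops at the first node that
  // satisfies the predicate.
  return SCEVExprContains(Root,
                          [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
}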
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:314
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:199
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
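A sketch of the KnownBits helpers above; with fully known inputs every result bit is known as well:

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

static void knownBitsDemo() {
  KnownBits L = KnownBits::makeConstant(APInt(8, 5));
  KnownBits R = KnownBits::makeConstant(APInt(8, 1));
  KnownBits Shifted = KnownBits::shl(L, R); // fully known: 5 << 1 == 10
  APInt Hi = Shifted.getMaxValue();         // 10
  APInt Lo = Shifted.getMinValue();         // 10
  (void)Hi; (void)Lo;
}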
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.