ScalarEvolution.cpp
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
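// As a small illustration of the representation: the induction variable of a
// loop that starts at zero and is bumped by four each iteration is the
// recurrence {0,+,4}<%loop>, whereas a value the analysis cannot see through
// (say, a load) is simply wrapped in a SCEVUnknown.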
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
153static cl::opt<unsigned>
154    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
160static cl::opt<bool, true> VerifySCEVOpt(
161    "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
163static cl::opt<bool> VerifySCEVStrict(
164    "verify-scev-strict", cl::Hidden,
165    cl::desc("Enable stricter verification when -verify-scev is passed"));
166
167static cl::opt<bool> VerifyIR(
168    "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
172static cl::opt<unsigned> MulOpsInlineThreshold(
173    "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
177static cl::opt<unsigned> AddOpsInlineThreshold(
178    "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
182static cl::opt<unsigned> MaxSCEVCompareDepth(
183    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
187static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
188    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
192static cl::opt<unsigned> MaxValueCompareDepth(
193    "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
197static cl::opt<unsigned>
198    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
202static cl::opt<unsigned> MaxConstantEvolvingDepth(
203    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
206static cl::opt<unsigned>
207    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
211static cl::opt<unsigned>
212    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
216static cl::opt<unsigned>
217    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
221static cl::opt<unsigned> RangeIterThreshold(
222    "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
226static cl::opt<unsigned> MaxLoopGuardCollectionDepth(
227    "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
235static cl::opt<bool> UseExpensiveRangeSharpening(
236    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
241static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
242    "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
252static cl::opt<bool> UseContextForNoWrapFlagInference(
253    "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
261//===----------------------------------------------------------------------===//
262// Implementation of the SCEV class.
263//
264
265#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
266LLVM_DUMP_METHOD void SCEV::dump() const {
267  print(dbgs());
268 dbgs() << '\n';
269}
270#endif
271
272void SCEV::print(raw_ostream &OS) const {
273 switch (getSCEVType()) {
274 case scConstant:
275 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
276 return;
277 case scVScale:
278 OS << "vscale";
279 return;
280 case scPtrToInt: {
281 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
282 const SCEV *Op = PtrToInt->getOperand();
283 OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
284 << *PtrToInt->getType() << ")";
285 return;
286 }
287 case scTruncate: {
288 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
289 const SCEV *Op = Trunc->getOperand();
290 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
291 << *Trunc->getType() << ")";
292 return;
293 }
294 case scZeroExtend: {
295    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
296    const SCEV *Op = ZExt->getOperand();
297 OS << "(zext " << *Op->getType() << " " << *Op << " to "
298 << *ZExt->getType() << ")";
299 return;
300 }
301 case scSignExtend: {
302    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
303    const SCEV *Op = SExt->getOperand();
304 OS << "(sext " << *Op->getType() << " " << *Op << " to "
305 << *SExt->getType() << ")";
306 return;
307 }
308 case scAddRecExpr: {
309 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
310 OS << "{" << *AR->getOperand(0);
311 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
312 OS << ",+," << *AR->getOperand(i);
313 OS << "}<";
314 if (AR->hasNoUnsignedWrap())
315 OS << "nuw><";
316 if (AR->hasNoSignedWrap())
317 OS << "nsw><";
318 if (AR->hasNoSelfWrap() &&
319        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
320      OS << "nw><";
321 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
322 OS << ">";
323 return;
324 }
325 case scAddExpr:
326 case scMulExpr:
327 case scUMaxExpr:
328 case scSMaxExpr:
329 case scUMinExpr:
330 case scSMinExpr:
331  case scSequentialUMinExpr: {
332    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
333 const char *OpStr = nullptr;
334 switch (NAry->getSCEVType()) {
335 case scAddExpr: OpStr = " + "; break;
336 case scMulExpr: OpStr = " * "; break;
337 case scUMaxExpr: OpStr = " umax "; break;
338 case scSMaxExpr: OpStr = " smax "; break;
339 case scUMinExpr:
340 OpStr = " umin ";
341 break;
342 case scSMinExpr:
343 OpStr = " smin ";
344 break;
345    case scSequentialUMinExpr:
346      OpStr = " umin_seq ";
347 break;
348 default:
349 llvm_unreachable("There are no other nary expression types.");
350 }
351 OS << "("
353 << ")";
354 switch (NAry->getSCEVType()) {
355 case scAddExpr:
356 case scMulExpr:
357 if (NAry->hasNoUnsignedWrap())
358 OS << "<nuw>";
359 if (NAry->hasNoSignedWrap())
360 OS << "<nsw>";
361 break;
362 default:
363 // Nothing to print for other nary expressions.
364 break;
365 }
366 return;
367 }
368 case scUDivExpr: {
369 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
370 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
371 return;
372 }
373 case scUnknown:
374 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
375 return;
376  case scCouldNotCompute:
377    OS << "***COULDNOTCOMPUTE***";
378 return;
379 }
380 llvm_unreachable("Unknown SCEV kind!");
381}
382
383Type *SCEV::getType() const {
384  switch (getSCEVType()) {
385 case scConstant:
386 return cast<SCEVConstant>(this)->getType();
387 case scVScale:
388 return cast<SCEVVScale>(this)->getType();
389 case scPtrToInt:
390 case scTruncate:
391 case scZeroExtend:
392 case scSignExtend:
393 return cast<SCEVCastExpr>(this)->getType();
394 case scAddRecExpr:
395 return cast<SCEVAddRecExpr>(this)->getType();
396 case scMulExpr:
397 return cast<SCEVMulExpr>(this)->getType();
398 case scUMaxExpr:
399 case scSMaxExpr:
400 case scUMinExpr:
401 case scSMinExpr:
402 return cast<SCEVMinMaxExpr>(this)->getType();
403  case scSequentialUMinExpr:
404    return cast<SCEVSequentialMinMaxExpr>(this)->getType();
405 case scAddExpr:
406 return cast<SCEVAddExpr>(this)->getType();
407 case scUDivExpr:
408 return cast<SCEVUDivExpr>(this)->getType();
409 case scUnknown:
410 return cast<SCEVUnknown>(this)->getType();
411  case scCouldNotCompute:
412    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
413 }
414 llvm_unreachable("Unknown SCEV kind!");
415}
416
417ArrayRef<const SCEV *> SCEV::operands() const {
418  switch (getSCEVType()) {
419 case scConstant:
420 case scVScale:
421 case scUnknown:
422 return {};
423 case scPtrToInt:
424 case scTruncate:
425 case scZeroExtend:
426 case scSignExtend:
427 return cast<SCEVCastExpr>(this)->operands();
428 case scAddRecExpr:
429 case scAddExpr:
430 case scMulExpr:
431 case scUMaxExpr:
432 case scSMaxExpr:
433 case scUMinExpr:
434 case scSMinExpr:
435  case scSequentialUMinExpr:
436    return cast<SCEVNAryExpr>(this)->operands();
437 case scUDivExpr:
438 return cast<SCEVUDivExpr>(this)->operands();
439  case scCouldNotCompute:
440    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
441 }
442 llvm_unreachable("Unknown SCEV kind!");
443}
444
445bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
446
447bool SCEV::isOne() const { return match(this, m_scev_One()); }
448
449bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
450
451bool SCEV::isNonConstantNegative() const {
452  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
453  if (!Mul) return false;
454
455 // If there is a constant factor, it will be first.
456 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
457 if (!SC) return false;
458
459 // Return true if the value is negative, this matches things like (-42 * V).
460 return SC->getAPInt().isNegative();
461}
462
463SCEVCouldNotCompute::SCEVCouldNotCompute() :
464  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
465
466bool SCEVCouldNotCompute::classof(const SCEV *S) {
467  return S->getSCEVType() == scCouldNotCompute;
468}
469
470const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
471  FoldingSetNodeID ID;
472  ID.AddInteger(scConstant);
473 ID.AddPointer(V);
474 void *IP = nullptr;
475 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
476 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
477 UniqueSCEVs.InsertNode(S, IP);
478 return S;
479}
480
481const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
482  return getConstant(ConstantInt::get(getContext(), Val));
483}
484
485const SCEV *
486ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
487  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
488  // TODO: Avoid implicit trunc?
489 // See https://github.com/llvm/llvm-project/issues/112510.
490 return getConstant(
491 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
492}
493
494const SCEV *ScalarEvolution::getVScale(Type *Ty) {
495  FoldingSetNodeID ID;
496  ID.AddInteger(scVScale);
497 ID.AddPointer(Ty);
498 void *IP = nullptr;
499 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
500 return S;
501 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
502 UniqueSCEVs.InsertNode(S, IP);
503 return S;
504}
505
506const SCEV *ScalarEvolution::getElementCount(Type *Ty, ElementCount EC,
507                                             SCEV::NoWrapFlags Flags) {
508 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
509 if (EC.isScalable())
510 Res = getMulExpr(Res, getVScale(Ty), Flags);
511 return Res;
512}
513
514SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
515                           const SCEV *op, Type *ty)
516 : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
517
518SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
519 Type *ITy)
520 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
521 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
522 "Must be a non-bit-width-changing pointer-to-integer cast!");
523}
524
525SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
526                                           SCEVTypes SCEVTy, const SCEV *op,
527 Type *ty)
528 : SCEVCastExpr(ID, SCEVTy, op, ty) {}
529
530SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
531 Type *ty)
532    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
533  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
534 "Cannot truncate non-integer value!");
535}
536
537SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
538 const SCEV *op, Type *ty)
539    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
540  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
541 "Cannot zero extend non-integer value!");
542}
543
544SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
545 const SCEV *op, Type *ty)
546    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
547  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
548 "Cannot sign extend non-integer value!");
549}
550
551void SCEVUnknown::deleted() {
552  // Clear this SCEVUnknown from various maps.
553 SE->forgetMemoizedResults(this);
554
555 // Remove this SCEVUnknown from the uniquing map.
556 SE->UniqueSCEVs.RemoveNode(this);
557
558 // Release the value.
559 setValPtr(nullptr);
560}
561
562void SCEVUnknown::allUsesReplacedWith(Value *New) {
563 // Clear this SCEVUnknown from various maps.
564 SE->forgetMemoizedResults(this);
565
566 // Remove this SCEVUnknown from the uniquing map.
567 SE->UniqueSCEVs.RemoveNode(this);
568
569 // Replace the value pointer in case someone is still using this SCEVUnknown.
570 setValPtr(New);
571}
572
573//===----------------------------------------------------------------------===//
574// SCEV Utilities
575//===----------------------------------------------------------------------===//
576
577/// Compare the two values \p LV and \p RV in terms of their "complexity" where
578/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
579/// operands in SCEV expressions.
580static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
581 Value *RV, unsigned Depth) {
582  if (Depth > MaxValueCompareDepth)
583    return 0;
584
585 // Order pointer values after integer values. This helps SCEVExpander form
586 // GEPs.
587 bool LIsPointer = LV->getType()->isPointerTy(),
588 RIsPointer = RV->getType()->isPointerTy();
589 if (LIsPointer != RIsPointer)
590 return (int)LIsPointer - (int)RIsPointer;
591
592 // Compare getValueID values.
593 unsigned LID = LV->getValueID(), RID = RV->getValueID();
594 if (LID != RID)
595 return (int)LID - (int)RID;
596
597 // Sort arguments by their position.
598 if (const auto *LA = dyn_cast<Argument>(LV)) {
599 const auto *RA = cast<Argument>(RV);
600 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
601 return (int)LArgNo - (int)RArgNo;
602 }
603
604 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
605 const auto *RGV = cast<GlobalValue>(RV);
606
607 if (auto L = LGV->getLinkage() - RGV->getLinkage())
608 return L;
609
610 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
611 auto LT = GV->getLinkage();
612 return !(GlobalValue::isPrivateLinkage(LT) ||
613               GlobalValue::isInternalLinkage(LT));
614    };
615
616 // Use the names to distinguish the two values, but only if the
617 // names are semantically important.
618 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
619 return LGV->getName().compare(RGV->getName());
620 }
621
622 // For instructions, compare their loop depth, and their operand count. This
623 // is pretty loose.
624 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
625 const auto *RInst = cast<Instruction>(RV);
626
627 // Compare loop depths.
628 const BasicBlock *LParent = LInst->getParent(),
629 *RParent = RInst->getParent();
630 if (LParent != RParent) {
631 unsigned LDepth = LI->getLoopDepth(LParent),
632 RDepth = LI->getLoopDepth(RParent);
633 if (LDepth != RDepth)
634 return (int)LDepth - (int)RDepth;
635 }
636
637 // Compare the number of operands.
638 unsigned LNumOps = LInst->getNumOperands(),
639 RNumOps = RInst->getNumOperands();
640 if (LNumOps != RNumOps)
641 return (int)LNumOps - (int)RNumOps;
642
643 for (unsigned Idx : seq(LNumOps)) {
644 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
645 RInst->getOperand(Idx), Depth + 1);
646 if (Result != 0)
647 return Result;
648 }
649 }
650
651 return 0;
652}
653
654// Return negative, zero, or positive, if LHS is less than, equal to, or greater
655// than RHS, respectively. A three-way result allows recursive comparisons to be
656// more efficient.
657// If the max analysis depth was reached, return std::nullopt, assuming we do
658// not know if they are equivalent for sure.
659static std::optional<int>
660CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
661 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
662 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
663 if (LHS == RHS)
664 return 0;
665
666 // Primarily, sort the SCEVs by their getSCEVType().
667 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
668 if (LType != RType)
669 return (int)LType - (int)RType;
670
671  if (Depth > MaxSCEVCompareDepth)
672    return std::nullopt;
673
674 // Aside from the getSCEVType() ordering, the particular ordering
675 // isn't very important except that it's beneficial to be consistent,
676 // so that (a + b) and (b + a) don't end up as different expressions.
677 switch (LType) {
678 case scUnknown: {
679 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
680 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
681
682 int X =
683 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
684 return X;
685 }
686
687 case scConstant: {
688    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
689    const SCEVConstant *RC = cast<SCEVConstant>(RHS);
690
691 // Compare constant values.
692 const APInt &LA = LC->getAPInt();
693 const APInt &RA = RC->getAPInt();
694 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
695 if (LBitWidth != RBitWidth)
696 return (int)LBitWidth - (int)RBitWidth;
697 return LA.ult(RA) ? -1 : 1;
698 }
699
700 case scVScale: {
701 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
702 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
703 return LTy->getBitWidth() - RTy->getBitWidth();
704 }
705
706 case scAddRecExpr: {
707    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
708    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
709
710 // There is always a dominance between two recs that are used by one SCEV,
711 // so we can safely sort recs by loop header dominance. We require such
712 // order in getAddExpr.
713 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
714 if (LLoop != RLoop) {
715 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
716 assert(LHead != RHead && "Two loops share the same header?");
717 if (DT.dominates(LHead, RHead))
718 return 1;
719 assert(DT.dominates(RHead, LHead) &&
720 "No dominance between recurrences used by one SCEV?");
721 return -1;
722 }
723
724 [[fallthrough]];
725 }
726
727 case scTruncate:
728 case scZeroExtend:
729 case scSignExtend:
730 case scPtrToInt:
731 case scAddExpr:
732 case scMulExpr:
733 case scUDivExpr:
734 case scSMaxExpr:
735 case scUMaxExpr:
736 case scSMinExpr:
737 case scUMinExpr:
738  case scSequentialUMinExpr: {
739    ArrayRef<const SCEV *> LOps = LHS->operands();
740 ArrayRef<const SCEV *> ROps = RHS->operands();
741
742 // Lexicographically compare n-ary-like expressions.
743 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
744 if (LNumOps != RNumOps)
745 return (int)LNumOps - (int)RNumOps;
746
747 for (unsigned i = 0; i != LNumOps; ++i) {
748 auto X = CompareSCEVComplexity(LI, LOps[i], ROps[i], DT, Depth + 1);
749 if (X != 0)
750 return X;
751 }
752 return 0;
753 }
754
755  case scCouldNotCompute:
756    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
757 }
758 llvm_unreachable("Unknown SCEV kind!");
759}
760
761/// Given a list of SCEV objects, order them by their complexity, and group
762/// objects of the same complexity together by value. When this routine is
763/// finished, we know that any duplicates in the vector are consecutive and that
764/// complexity is monotonically increasing.
765///
766/// Note that we take special precautions to ensure that we get deterministic
767/// results from this routine. In other words, we don't want the results of
768/// this to depend on where the addresses of various SCEV objects happened to
769/// land in memory.
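/// As an illustrative sketch, an operand list such as (%a + 5 + %b + %a) is
/// reordered so that the two occurrences of %a end up adjacent, which lets
/// callers such as getAddExpr fold duplicates in a single linear scan.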
770static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
771                              LoopInfo *LI, DominatorTree &DT) {
772 if (Ops.size() < 2) return; // Noop
773
774 // Whether LHS has provably less complexity than RHS.
775 auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
776 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
777 return Complexity && *Complexity < 0;
778 };
779 if (Ops.size() == 2) {
780 // This is the common case, which also happens to be trivially simple.
781 // Special case it.
782 const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
783 if (IsLessComplex(RHS, LHS))
784 std::swap(LHS, RHS);
785 return;
786 }
787
788 // Do the rough sort by complexity.
789 llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
790 return IsLessComplex(LHS, RHS);
791 });
792
793 // Now that we are sorted by complexity, group elements of the same
794 // complexity. Note that this is, at worst, N^2, but the vector is likely to
795 // be extremely short in practice. Note that we take this approach because we
796 // do not want to depend on the addresses of the objects we are grouping.
797 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
798 const SCEV *S = Ops[i];
799 unsigned Complexity = S->getSCEVType();
800
801 // If there are any objects of the same complexity and same value as this
802 // one, group them.
803 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
804 if (Ops[j] == S) { // Found a duplicate.
805 // Move it to immediately after i'th element.
806 std::swap(Ops[i+1], Ops[j]);
807 ++i; // no need to rescan it.
808 if (i == e-2) return; // Done!
809 }
810 }
811 }
812}
813
814/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
815/// least HugeExprThreshold nodes).
816static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
817  return any_of(Ops, [](const SCEV *S) {
818    return S->getExpressionSize() >= HugeExprThreshold;
819  });
820}
821
822/// Performs a number of common optimizations on the passed \p Ops. If the
823/// whole expression reduces down to a single operand, it will be returned.
824///
825/// The following optimizations are performed:
826/// * Fold constants using the \p Fold function.
827/// * Remove identity constants satisfying \p IsIdentity.
828/// * If a constant satisfies \p IsAbsorber, return it.
829/// * Sort operands by complexity.
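/// For example (illustrative only), given add operands (3, %x, 5) the two
/// constants fold to 8 via \p Fold; 8 is neither the identity 0 nor an
/// absorber, so it is re-inserted, the list is sorted to (8, %x), and nullptr
/// is returned because more than one operand remains.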
830template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
831static const SCEV *
832constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
833                        SmallVectorImpl<const SCEV *> &Ops, FoldT Fold,
834                        IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
835 const SCEVConstant *Folded = nullptr;
836 for (unsigned Idx = 0; Idx < Ops.size();) {
837 const SCEV *Op = Ops[Idx];
838 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
839 if (!Folded)
840 Folded = C;
841 else
842 Folded = cast<SCEVConstant>(
843 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
844 Ops.erase(Ops.begin() + Idx);
845 continue;
846 }
847 ++Idx;
848 }
849
850 if (Ops.empty()) {
851 assert(Folded && "Must have folded value");
852 return Folded;
853 }
854
855 if (Folded && IsAbsorber(Folded->getAPInt()))
856 return Folded;
857
858 GroupByComplexity(Ops, &LI, DT);
859 if (Folded && !IsIdentity(Folded->getAPInt()))
860 Ops.insert(Ops.begin(), Folded);
861
862 return Ops.size() == 1 ? Ops[0] : nullptr;
863}
864
865//===----------------------------------------------------------------------===//
866// Simple SCEV method implementations
867//===----------------------------------------------------------------------===//
868
869/// Compute BC(It, K). The result has width W. Assume, K > 0.
870static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
871 ScalarEvolution &SE,
872 Type *ResultTy) {
873 // Handle the simplest case efficiently.
874 if (K == 1)
875 return SE.getTruncateOrZeroExtend(It, ResultTy);
876
877 // We are using the following formula for BC(It, K):
878 //
879 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
880 //
881 // Suppose, W is the bitwidth of the return value. We must be prepared for
882 // overflow. Hence, we must assure that the result of our computation is
883 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
884 // safe in modular arithmetic.
885 //
886 // However, this code doesn't use exactly that formula; the formula it uses
887 // is something like the following, where T is the number of factors of 2 in
888 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
889 // exponentiation:
890 //
891 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
892 //
893 // This formula is trivially equivalent to the previous formula. However,
894 // this formula can be implemented much more efficiently. The trick is that
895 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
896 // arithmetic. To do exact division in modular arithmetic, all we have
897 // to do is multiply by the inverse. Therefore, this step can be done at
898 // width W.
899 //
900 // The next issue is how to safely do the division by 2^T. The way this
901 // is done is by doing the multiplication step at a width of at least W + T
902 // bits. This way, the bottom W+T bits of the product are accurate. Then,
903 // when we perform the division by 2^T (which is equivalent to a right shift
904 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
905 // truncated out after the division by 2^T.
906 //
907 // In comparison to just directly using the first formula, this technique
908 // is much more efficient; using the first formula requires W * K bits,
909// but this formula requires less than W + K bits. Also, the first formula requires
910 // a division step, whereas this formula only requires multiplies and shifts.
911 //
912 // It doesn't matter whether the subtraction step is done in the calculation
913 // width or the input iteration count's width; if the subtraction overflows,
914 // the result must be zero anyway. We prefer here to do it in the width of
915 // the induction variable because it helps a lot for certain cases; CodeGen
916 // isn't smart enough to ignore the overflow, which leads to much less
917 // efficient code if the width of the subtraction is wider than the native
918 // register width.
919 //
920 // (It's possible to not widen at all by pulling out factors of 2 before
921 // the multiplication; for example, K=2 can be calculated as
922 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
923 // extra arithmetic, so it's not an obvious win, and it gets
924 // much more complicated for K > 3.)
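  // As a worked instance of the above: for K = 4, K! = 24 = 2^3 * 3, so T = 3
  // and the odd factor is 3. The product It*(It-1)*(It-2)*(It-3) is formed at
  // width W + 3, divided by 2^3, truncated back to W bits, and multiplied by
  // the multiplicative inverse of 3 modulo 2^W.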
925
926 // Protection from insane SCEVs; this bound is conservative,
927 // but it probably doesn't matter.
928 if (K > 1000)
929 return SE.getCouldNotCompute();
930
931 unsigned W = SE.getTypeSizeInBits(ResultTy);
932
933 // Calculate K! / 2^T and T; we divide out the factors of two before
934 // multiplying for calculating K! / 2^T to avoid overflow.
935 // Other overflow doesn't matter because we only care about the bottom
936 // W bits of the result.
937 APInt OddFactorial(W, 1);
938 unsigned T = 1;
939 for (unsigned i = 3; i <= K; ++i) {
940 unsigned TwoFactors = countr_zero(i);
941 T += TwoFactors;
942 OddFactorial *= (i >> TwoFactors);
943 }
944
945 // We need at least W + T bits for the multiplication step
946 unsigned CalculationBits = W + T;
947
948 // Calculate 2^T, at width T+W.
949 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
950
951 // Calculate the multiplicative inverse of K! / 2^T;
952 // this multiplication factor will perform the exact division by
953 // K! / 2^T.
954 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
955
956 // Calculate the product, at width T+W
957 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
958 CalculationBits);
959 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
960 for (unsigned i = 1; i != K; ++i) {
961 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
962 Dividend = SE.getMulExpr(Dividend,
963 SE.getTruncateOrZeroExtend(S, CalculationTy));
964 }
965
966 // Divide by 2^T
967 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
968
969 // Truncate the result, and divide by K! / 2^T.
970
971 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
972 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
973}
974
975/// Return the value of this chain of recurrences at the specified iteration
976/// number. We can evaluate this recurrence by multiplying each element in the
977/// chain by the binomial coefficient corresponding to it. In other words, we
978/// can evaluate {A,+,B,+,C,+,D} as:
979///
980/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
981///
982/// where BC(It, k) stands for binomial coefficient.
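/// For example, {0,+,2,+,2} evaluated at iteration n is
/// 0*BC(n,0) + 2*BC(n,1) + 2*BC(n,2) = 2n + n*(n-1) = n^2 + n, matching the
/// sequence 0, 2, 6, 12, ... produced by the recurrence.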
983const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
984                                                ScalarEvolution &SE) const {
985 return evaluateAtIteration(operands(), It, SE);
986}
987
988const SCEV *
989SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
990                                    const SCEV *It, ScalarEvolution &SE) {
991 assert(Operands.size() > 0);
992 const SCEV *Result = Operands[0];
993 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
994 // The computation is correct in the face of overflow provided that the
995 // multiplication is performed _after_ the evaluation of the binomial
996 // coefficient.
997 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
998 if (isa<SCEVCouldNotCompute>(Coeff))
999 return Coeff;
1000
1001 Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1002 }
1003 return Result;
1004}
1005
1006//===----------------------------------------------------------------------===//
1007// SCEV Expression folder implementations
1008//===----------------------------------------------------------------------===//
1009
1011 unsigned Depth) {
1012 assert(Depth <= 1 &&
1013 "getLosslessPtrToIntExpr() should self-recurse at most once.");
1014
1015 // We could be called with an integer-typed operands during SCEV rewrites.
1016 // Since the operand is an integer already, just perform zext/trunc/self cast.
1017 if (!Op->getType()->isPointerTy())
1018 return Op;
1019
1020 // What would be an ID for such a SCEV cast expression?
1021  FoldingSetNodeID ID;
1022  ID.AddInteger(scPtrToInt);
1023 ID.AddPointer(Op);
1024
1025 void *IP = nullptr;
1026
1027 // Is there already an expression for such a cast?
1028 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1029 return S;
1030
1031 // It isn't legal for optimizations to construct new ptrtoint expressions
1032 // for non-integral pointers.
1033 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1034 return getCouldNotCompute();
1035
1036 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1037
1038 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1039 // is sufficiently wide to represent all possible pointer values.
1040 // We could theoretically teach SCEV to truncate wider pointers, but
1041 // that isn't implemented for now.
1042  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1043      getDataLayout().getTypeSizeInBits(IntPtrTy))
1044 return getCouldNotCompute();
1045
1046 // If not, is this expression something we can't reduce any further?
1047 if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1048 // Perform some basic constant folding. If the operand of the ptr2int cast
1049 // is a null pointer, don't create a ptr2int SCEV expression (that will be
1050 // left as-is), but produce a zero constant.
1051 // NOTE: We could handle a more general case, but lack motivational cases.
1052 if (isa<ConstantPointerNull>(U->getValue()))
1053 return getZero(IntPtrTy);
1054
1055 // Create an explicit cast node.
1056 // We can reuse the existing insert position since if we get here,
1057 // we won't have made any changes which would invalidate it.
1058 SCEV *S = new (SCEVAllocator)
1059 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1060 UniqueSCEVs.InsertNode(S, IP);
1061 registerUser(S, Op);
1062 return S;
1063 }
1064
1065 assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1066 "non-SCEVUnknown's.");
1067
1068 // Otherwise, we've got some expression that is more complex than just a
1069 // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1070 // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1071 // only, and the expressions must otherwise be integer-typed.
1072 // So sink the cast down to the SCEVUnknown's.
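  // For instance, a lossless ptrtoint of the pointer-typed sum (%base + 16)
  // becomes ((ptrtoint %base) + 16): the cast lands only on the SCEVUnknown
  // %base (an illustrative name), and the addition is redone on integers.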
1073
1074 /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1075 /// which computes a pointer-typed value, and rewrites the whole expression
1076 /// tree so that *all* the computations are done on integers, and the only
1077 /// pointer-typed operands in the expression are SCEVUnknown.
1078 class SCEVPtrToIntSinkingRewriter
1079 : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1080    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1081
1082 public:
1083 SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1084
1085 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1086 SCEVPtrToIntSinkingRewriter Rewriter(SE);
1087 return Rewriter.visit(Scev);
1088 }
1089
1090 const SCEV *visit(const SCEV *S) {
1091 Type *STy = S->getType();
1092 // If the expression is not pointer-typed, just keep it as-is.
1093 if (!STy->isPointerTy())
1094 return S;
1095 // Else, recursively sink the cast down into it.
1096 return Base::visit(S);
1097 }
1098
1099 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1100      SmallVector<const SCEV *, 2> Operands;
1101      bool Changed = false;
1102 for (const auto *Op : Expr->operands()) {
1103 Operands.push_back(visit(Op));
1104 Changed |= Op != Operands.back();
1105 }
1106 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1107 }
1108
1109 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1110      SmallVector<const SCEV *, 2> Operands;
1111      bool Changed = false;
1112 for (const auto *Op : Expr->operands()) {
1113 Operands.push_back(visit(Op));
1114 Changed |= Op != Operands.back();
1115 }
1116 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1117 }
1118
1119 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1120 assert(Expr->getType()->isPointerTy() &&
1121 "Should only reach pointer-typed SCEVUnknown's.");
1122 return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1123 }
1124 };
1125
1126 // And actually perform the cast sinking.
1127 const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1128 assert(IntOp->getType()->isIntegerTy() &&
1129 "We must have succeeded in sinking the cast, "
1130 "and ending up with an integer-typed expression!");
1131 return IntOp;
1132}
1133
1134const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1135  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1136
1137 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1138 if (isa<SCEVCouldNotCompute>(IntOp))
1139 return IntOp;
1140
1141 return getTruncateOrZeroExtend(IntOp, Ty);
1142}
1143
1144const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1145                                             unsigned Depth) {
1146 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1147 "This is not a truncating conversion!");
1148 assert(isSCEVable(Ty) &&
1149 "This is not a conversion to a SCEVable type!");
1150 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1151 Ty = getEffectiveSCEVType(Ty);
1152
1153  FoldingSetNodeID ID;
1154  ID.AddInteger(scTruncate);
1155 ID.AddPointer(Op);
1156 ID.AddPointer(Ty);
1157 void *IP = nullptr;
1158 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1159
1160 // Fold if the operand is constant.
1161 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1162 return getConstant(
1163 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1164
1165 // trunc(trunc(x)) --> trunc(x)
1166  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1167    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1168
1169 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1170  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1171    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1172
1173 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1174  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1175    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1176
1177 if (Depth > MaxCastDepth) {
1178 SCEV *S =
1179 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1180 UniqueSCEVs.InsertNode(S, IP);
1181 registerUser(S, Op);
1182 return S;
1183 }
1184
1185 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1186 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1187 // if after transforming we have at most one truncate, not counting truncates
1188 // that replace other casts.
1189  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1190    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1191    SmallVector<const SCEV *, 4> Operands;
1192    unsigned numTruncs = 0;
1193 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1194 ++i) {
1195 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1196 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1197          isa<SCEVTruncateExpr>(S))
1198        numTruncs++;
1199 Operands.push_back(S);
1200 }
1201 if (numTruncs < 2) {
1202 if (isa<SCEVAddExpr>(Op))
1203 return getAddExpr(Operands);
1204 if (isa<SCEVMulExpr>(Op))
1205 return getMulExpr(Operands);
1206 llvm_unreachable("Unexpected SCEV type for Op.");
1207 }
1208 // Although we checked in the beginning that ID is not in the cache, it is
1209 // possible that during recursion and different modification ID was inserted
1210 // into the cache. So if we find it, just return it.
1211 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1212 return S;
1213 }
1214
1215 // If the input value is a chrec scev, truncate the chrec's operands.
1216 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1217    SmallVector<const SCEV *, 4> Operands;
1218    for (const SCEV *Op : AddRec->operands())
1219 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1220 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1221 }
1222
1223 // Return zero if truncating to known zeros.
1224 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1225 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1226 return getZero(Ty);
1227
1228 // The cast wasn't folded; create an explicit cast node. We can reuse
1229 // the existing insert position since if we get here, we won't have
1230 // made any changes which would invalidate it.
1231 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1232 Op, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 registerUser(S, Op);
1235 return S;
1236}
1237
1238// Get the limit of a recurrence such that incrementing by Step cannot cause
1239// signed overflow as long as the value of the recurrence within the
1240// loop does not exceed this limit before incrementing.
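// For example, for an 8-bit recurrence whose step is known to lie in [1, 2],
// the limit is SIGNED_MAX - 2 = 125 with predicate "slt": while the value of
// the recurrence is below 125, adding the step cannot sign-overflow.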
1241static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1242 ICmpInst::Predicate *Pred,
1243 ScalarEvolution *SE) {
1244 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1245 if (SE->isKnownPositive(Step)) {
1246 *Pred = ICmpInst::ICMP_SLT;
1247    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1248                           SE->getSignedRangeMax(Step));
1249 }
1250 if (SE->isKnownNegative(Step)) {
1251 *Pred = ICmpInst::ICMP_SGT;
1252    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1253                           SE->getSignedRangeMin(Step));
1254 }
1255 return nullptr;
1256}
1257
1258// Get the limit of a recurrence such that incrementing by Step cannot cause
1259// unsigned overflow as long as the value of the recurrence within the loop does
1260// not exceed this limit before incrementing.
1261static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1262                                                   ICmpInst::Predicate *Pred,
1263 ScalarEvolution *SE) {
1264 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1265 *Pred = ICmpInst::ICMP_ULT;
1266
1267  return SE->getConstant(APInt::getMinValue(BitWidth) -
1268                         SE->getUnsignedRangeMax(Step));
1269}
1270
1271namespace {
1272
1273struct ExtendOpTraitsBase {
1274 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1275 unsigned);
1276};
1277
1278// Used to make code generic over signed and unsigned overflow.
1279template <typename ExtendOp> struct ExtendOpTraits {
1280 // Members present:
1281 //
1282 // static const SCEV::NoWrapFlags WrapType;
1283 //
1284 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1285 //
1286 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1287 // ICmpInst::Predicate *Pred,
1288 // ScalarEvolution *SE);
1289};
1290
1291template <>
1292struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1293 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1294
1295 static const GetExtendExprTy GetExtendExpr;
1296
1297 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1298 ICmpInst::Predicate *Pred,
1299 ScalarEvolution *SE) {
1300 return getSignedOverflowLimitForStep(Step, Pred, SE);
1301 }
1302};
1303
1304const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1305    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1306
1307template <>
1308struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1309 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1310
1311 static const GetExtendExprTy GetExtendExpr;
1312
1313 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1314 ICmpInst::Predicate *Pred,
1315 ScalarEvolution *SE) {
1316 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1317 }
1318};
1319
1320const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1321    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1322
1323} // end anonymous namespace
1324
1325// The recurrence AR has been shown to have no signed/unsigned wrap or something
1326// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1327// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1328// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1329// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1330// expression "Step + sext/zext(PreIncAR)" is congruent with
1331// "sext/zext(PostIncAR)"
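// For example, with PreStart = 0 and Step = 4: if {0,+,4} is already known to
// be <nuw> and the backedge is taken at least once, then zext({4,+,4}), i.e.
// zext(4 + {0,+,4}), can be normalized to 4 + zext({0,+,4}), pushing the
// extension onto the narrower pre-increment recurrence.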
1332template <typename ExtendOpTy>
1333static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1334 ScalarEvolution *SE, unsigned Depth) {
1335 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1336 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1337
1338 const Loop *L = AR->getLoop();
1339 const SCEV *Start = AR->getStart();
1340 const SCEV *Step = AR->getStepRecurrence(*SE);
1341
1342 // Check for a simple looking step prior to loop entry.
1343 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1344 if (!SA)
1345 return nullptr;
1346
1347 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1348 // subtraction is expensive. For this purpose, perform a quick and dirty
1349 // difference, by checking for Step in the operand list. Note, that
1350 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1351  SmallVector<const SCEV *, 4> DiffOps(SA->operands());
1352  for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1353 if (*It == Step) {
1354 DiffOps.erase(It);
1355 break;
1356 }
1357
1358 if (DiffOps.size() == SA->getNumOperands())
1359 return nullptr;
1360
1361 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1362 // `Step`:
1363
1364 // 1. NSW/NUW flags on the step increment.
1365 auto PreStartFlags =
1366      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1367  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1368  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1369      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1370
1371 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1372 // "S+X does not sign/unsign-overflow".
1373 //
1374
1375 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1376 if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1377 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1378 return PreStart;
1379
1380 // 2. Direct overflow check on the step operation's expression.
1381 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1382 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1383 const SCEV *OperandExtendedStart =
1384 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1385 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1386 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1387 if (PreAR && AR->getNoWrapFlags(WrapType)) {
1388 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1389 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1390 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1391 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1392 }
1393 return PreStart;
1394 }
1395
1396 // 3. Loop precondition.
1397  ICmpInst::Predicate Pred;
1398  const SCEV *OverflowLimit =
1399 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1400
1401 if (OverflowLimit &&
1402 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1403 return PreStart;
1404
1405 return nullptr;
1406}
1407
1408// Get the normalized zero or sign extended expression for this AddRec's Start.
1409template <typename ExtendOpTy>
1410static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1411 ScalarEvolution *SE,
1412 unsigned Depth) {
1413 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1414
1415 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1416 if (!PreStart)
1417 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1418
1419 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1420 Depth),
1421 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1422}
1423
1424// Try to prove away overflow by looking at "nearby" add recurrences. A
1425// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1426// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1427//
1428// Formally:
1429//
1430// {S,+,X} == {S-T,+,X} + T
1431// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1432//
1433// If ({S-T,+,X} + T) does not overflow ... (1)
1434//
1435// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1436//
1437// If {S-T,+,X} does not overflow ... (2)
1438//
1439// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1440// == {Ext(S-T)+Ext(T),+,Ext(X)}
1441//
1442// If (S-T)+T does not overflow ... (3)
1443//
1444// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1445// == {Ext(S),+,Ext(X)} == LHS
1446//
1447// Thus, if (1), (2) and (3) are true for some T, then
1448// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1449//
1450// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1451// does not overflow" restricted to the 0th iteration. Therefore we only need
1452// to check for (1) and (2).
1453//
1454// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1455// is `Delta` (defined below).
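// As a concrete instance: with Start = 1, Step = 4 and T = 1, PreStart is 0.
// If the recurrence {0,+,4} already exists and is known <nuw> (condition (2)),
// and {0,+,4} can be shown to stay below the overflow limit for an increment
// of 1 (condition (1)), then {1,+,4} == {0,+,4} + 1 is <nuw> as well.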
1456template <typename ExtendOpTy>
1457bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1458 const SCEV *Step,
1459 const Loop *L) {
1460 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1461
1462 // We restrict `Start` to a constant to prevent SCEV from spending too much
1463 // time here. It is correct (but more expensive) to continue with a
1464 // non-constant `Start` and do a general SCEV subtraction to compute
1465 // `PreStart` below.
1466 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1467 if (!StartC)
1468 return false;
1469
1470 APInt StartAI = StartC->getAPInt();
1471
1472 for (unsigned Delta : {-2, -1, 1, 2}) {
1473 const SCEV *PreStart = getConstant(StartAI - Delta);
1474
1475 FoldingSetNodeID ID;
1476 ID.AddInteger(scAddRecExpr);
1477 ID.AddPointer(PreStart);
1478 ID.AddPointer(Step);
1479 ID.AddPointer(L);
1480 void *IP = nullptr;
1481 const auto *PreAR =
1482 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1483
1484 // Give up if we don't already have the add recurrence we need because
1485 // actually constructing an add recurrence is relatively expensive.
1486 if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1487 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1488      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1489      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1490 DeltaS, &Pred, this);
1491 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1492 return true;
1493 }
1494 }
1495
1496 return false;
1497}
1498
1499// Finds an integer D for an expression (C + x + y + ...) such that the top
1500// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1501// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1502// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1503// the (C + x + y + ...) expression is \p WholeAddExpr.
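// For example, if C is 42 (0b101010) and every other operand is known to be a
// multiple of 8 (three trailing zeros), then D is 2, the low three bits of C:
// (C - D + x + y + ...) keeps those three zero bits, so adding D back merely
// fills in zero bits and cannot wrap, signed or unsigned.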
1504static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1505                                            const SCEVConstant *ConstantTerm,
1506 const SCEVAddExpr *WholeAddExpr) {
1507 const APInt &C = ConstantTerm->getAPInt();
1508 const unsigned BitWidth = C.getBitWidth();
1509 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1510 uint32_t TZ = BitWidth;
1511 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1512 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1513 if (TZ) {
1514 // Set D to be as many least significant bits of C as possible while still
1515 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1516 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1517 }
1518 return APInt(BitWidth, 0);
1519}
1520
1521// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1522// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1523// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1524// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1525static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1526                                            const APInt &ConstantStart,
1527 const SCEV *Step) {
1528 const unsigned BitWidth = ConstantStart.getBitWidth();
1529 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1530 if (TZ)
1531 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1532 : ConstantStart;
1533 return APInt(BitWidth, 0);
1534}
1535
1536static void insertFoldCacheEntry(
1537    const ScalarEvolution::FoldID &ID, const SCEV *S,
1538    DenseMap<ScalarEvolution::FoldID, const SCEV *> &FoldCache,
1539    DenseMap<const SCEV *, SmallVector<ScalarEvolution::FoldID, 2>>
1540        &FoldCacheUser) {
1541 auto I = FoldCache.insert({ID, S});
1542 if (!I.second) {
1543 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1544 // entry.
1545 auto &UserIDs = FoldCacheUser[I.first->second];
1546 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
1547 for (unsigned I = 0; I != UserIDs.size(); ++I)
1548 if (UserIDs[I] == ID) {
1549 std::swap(UserIDs[I], UserIDs.back());
1550 break;
1551 }
1552 UserIDs.pop_back();
1553 I.first->second = S;
1554 }
1555 FoldCacheUser[S].push_back(ID);
1556}
1557
1558const SCEV *
1559ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1560  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1561 "This is not an extending conversion!");
1562 assert(isSCEVable(Ty) &&
1563 "This is not a conversion to a SCEVable type!");
1564 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1565 Ty = getEffectiveSCEVType(Ty);
1566
1567 FoldID ID(scZeroExtend, Op, Ty);
1568 if (const SCEV *S = FoldCache.lookup(ID))
1569 return S;
1570
1571 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1572  if (!isa<SCEVZeroExtendExpr>(S))
1573    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1574 return S;
1575}
1576
1577const SCEV *ScalarEvolution::getZeroExtendExprImpl(const SCEV *Op, Type *Ty,
1578                                                   unsigned Depth) {
1579 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1580 "This is not an extending conversion!");
1581 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1582 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1583
1584 // Fold if the operand is constant.
1585 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1586 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1587
1588 // zext(zext(x)) --> zext(x)
1589  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1590    return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1591
1592 // Before doing any expensive analysis, check to see if we've already
1593 // computed a SCEV for this Op and Ty.
1594  FoldingSetNodeID ID;
1595  ID.AddInteger(scZeroExtend);
1596 ID.AddPointer(Op);
1597 ID.AddPointer(Ty);
1598 void *IP = nullptr;
1599 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1600 if (Depth > MaxCastDepth) {
1601 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1602 Op, Ty);
1603 UniqueSCEVs.InsertNode(S, IP);
1604 registerUser(S, Op);
1605 return S;
1606 }
1607
1608 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1609  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1610    // It's possible the bits taken off by the truncate were all zero bits. If
1611 // so, we should be able to simplify this further.
1612 const SCEV *X = ST->getOperand();
1613    ConstantRange CR = getUnsignedRange(X);
1614    unsigned TruncBits = getTypeSizeInBits(ST->getType());
1615 unsigned NewBits = getTypeSizeInBits(Ty);
1616 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1617 CR.zextOrTrunc(NewBits)))
1618 return getTruncateOrZeroExtend(X, Ty, Depth);
1619 }
1620
1621 // If the input value is a chrec scev, and we can prove that the value
1622 // did not overflow the old, smaller, value, we can zero extend all of the
1623 // operands (often constants). This allows analysis of something like
1624 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1625  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1626    if (AR->isAffine()) {
1627 const SCEV *Start = AR->getStart();
1628 const SCEV *Step = AR->getStepRecurrence(*this);
1629 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1630 const Loop *L = AR->getLoop();
1631
1632 // If we have special knowledge that this addrec won't overflow,
1633 // we don't need to do any further analysis.
1634 if (AR->hasNoUnsignedWrap()) {
1635 Start =
1636            getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1637        Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1638 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1639 }
1640
1641 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1642 // Note that this serves two purposes: It filters out loops that are
1643 // simply not analyzable, and it covers the case where this code is
1644 // being called from within backedge-taken count analysis, such that
1645 // attempting to ask for the backedge-taken count would likely result
1646 // in infinite recursion. In the later case, the analysis code will
1647 // cope with a conservative value, and it will take care to purge
1648 // that value once it has finished.
1649 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1650 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1651 // Manually compute the final value for AR, checking for overflow.
1652
1653 // Check whether the backedge-taken count can be losslessly casted to
1654 // the addrec's type. The count is always unsigned.
1655 const SCEV *CastedMaxBECount =
1656 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1657 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1658 CastedMaxBECount, MaxBECount->getType(), Depth);
1659 if (MaxBECount == RecastedMaxBECount) {
1660 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1661 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1662 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1663 SCEV::FlagAnyWrap, Depth + 1);
1664 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1665 SCEV::FlagAnyWrap,
1666 Depth + 1),
1667 WideTy, Depth + 1);
1668 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1669 const SCEV *WideMaxBECount =
1670 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1671 const SCEV *OperandExtendedAdd =
1672 getAddExpr(WideStart,
1673 getMulExpr(WideMaxBECount,
1674 getZeroExtendExpr(Step, WideTy, Depth + 1),
1675 SCEV::FlagAnyWrap, Depth + 1),
1676 SCEV::FlagAnyWrap, Depth + 1);
1677 if (ZAdd == OperandExtendedAdd) {
1678 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1679 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1680 // Return the expression with the addrec on the outside.
1681 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1682 Depth + 1);
1683 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1684 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1685 }
1686 // Similar to above, only this time treat the step value as signed.
1687 // This covers loops that count down.
1688 OperandExtendedAdd =
1689 getAddExpr(WideStart,
1690 getMulExpr(WideMaxBECount,
1691 getSignExtendExpr(Step, WideTy, Depth + 1),
1692 SCEV::FlagAnyWrap, Depth + 1),
1693 SCEV::FlagAnyWrap, Depth + 1);
1694 if (ZAdd == OperandExtendedAdd) {
1695 // Cache knowledge of AR NW, which is propagated to this AddRec.
1696 // Negative step causes unsigned wrap, but it still can't self-wrap.
1697 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1698 // Return the expression with the addrec on the outside.
1699 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1700 Depth + 1);
1701 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1702 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1703 }
1704 }
1705 }
1706
1707 // Normally, in the cases we can prove no-overflow via a
1708 // backedge guarding condition, we can also compute a backedge
1709 // taken count for the loop. The exceptions are assumptions and
1710 // guards present in the loop -- SCEV is not great at exploiting
1711 // these to compute max backedge taken counts, but can still use
1712 // these to prove lack of overflow. Use this fact to avoid
1713 // doing extra work that may not pay off.
1714 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1715 !AC.assumptions().empty()) {
1716
1717 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1718 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1719 if (AR->hasNoUnsignedWrap()) {
1720 // Same as nuw case above - duplicated here to avoid a compile time
1721 // issue. It's not clear that the order of checks matters, but
1722 // it's one of two possible causes for a change which was
1723 // reverted. Be conservative for the moment.
1724 Start =
1725 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1726 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1727 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1728 }
1729
1730 // For a negative step, we can extend the operands iff doing so only
1731 // traverses values in the range zext([0,UINT_MAX]).
1732 if (isKnownNegative(Step)) {
1733 const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1734 getSignedRangeMin(Step));
1735 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1736 isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1737 // Cache knowledge of AR NW, which is propagated to this
1738 // AddRec. Negative step causes unsigned wrap, but it
1739 // still can't self-wrap.
1740 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1741 // Return the expression with the addrec on the outside.
1742 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1743 Depth + 1);
1744 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1745 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1746 }
1747 }
1748 }
1749
1750 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1751 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1752 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1753 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1754 const APInt &C = SC->getAPInt();
1755 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1756 if (D != 0) {
1757 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1758 const SCEV *SResidual =
1759 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1760 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1761 return getAddExpr(SZExtD, SZExtR,
1762 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1763 Depth + 1);
1764 }
1765 }
1766
1767 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1768 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1769 Start =
1770 getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1771 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1772 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1773 }
1774 }
1775
1776 // zext(A % B) --> zext(A) % zext(B)
1777 {
1778 const SCEV *LHS;
1779 const SCEV *RHS;
1780 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1781 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1782 getZeroExtendExpr(RHS, Ty, Depth + 1));
1783 }
1784
1785 // zext(A / B) --> zext(A) / zext(B).
1786 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1787 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1788 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1789
1790 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1791 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1792 if (SA->hasNoUnsignedWrap()) {
1793 // If the addition does not unsign overflow then we can, by definition,
1794 // commute the zero extension with the addition operation.
1795 SmallVector<const SCEV *, 4> Ops;
1796 for (const auto *Op : SA->operands())
1797 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1798 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1799 }
1800
1801 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1802 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1803 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1804 //
1805 // Often address arithmetic contains expressions like
1806 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1807 // This transformation is useful while proving that such expressions are
1808 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1809 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1810 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1811 if (D != 0) {
1812 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1813 const SCEV *SResidual =
1814 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1815 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1816 return getAddExpr(SZExtD, SZExtR,
1817 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1818 Depth + 1);
1819 }
1820 }
1821 }
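// Worked instance of the split above (illustrative; it assumes
// extractConstantWithoutWrapping selects D = 1 for this input): for
// zext(5 + 4 * X), C = 5 and D = 1 gives C - D = 4, which shares the two
// trailing zero bits of the 4 * X term. The expression becomes
// zext(1) + zext(4 + 4 * X) = 1 + zext(4 + 4 * X), so it can be compared
// directly against another expression already in that form.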
1822
1823 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1824 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1825 if (SM->hasNoUnsignedWrap()) {
1826 // If the multiply does not unsign overflow then we can, by definition,
1827 // commute the zero extension with the multiply operation.
1828 SmallVector<const SCEV *, 4> Ops;
1829 for (const auto *Op : SM->operands())
1830 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1831 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1832 }
1833
1834 // zext(2^K * (trunc X to iN)) to iM ->
1835 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1836 //
1837 // Proof:
1838 //
1839 // zext(2^K * (trunc X to iN)) to iM
1840 // = zext((trunc X to iN) << K) to iM
1841 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1842 // (because shl removes the top K bits)
1843 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1844 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1845 //
1846 const APInt *C;
1847 const SCEV *TruncRHS;
1848 if (match(SM,
1849 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1850 C->isPowerOf2()) {
1851 int NewTruncBits =
1852 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1853 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1854 return getMulExpr(
1855 getZeroExtendExpr(SM->getOperand(0), Ty),
1856 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1857 SCEV::FlagNUW, Depth + 1);
1858 }
1859 }
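// Concrete instance of the proof above (illustrative only): with K = 2,
// N = 8 and M = 32, zext(4 * (trunc X to i8)) to i32 becomes
// (4 * (zext (trunc X to i6) to i32))<nuw>, because shifting the 6-bit
// truncate left by two bits cannot carry out of the original 8-bit type.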
1860
1861 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1862 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1863 if (isa<SCEVUMinExpr>(Op) || isa<SCEVUMaxExpr>(Op)) {
1864 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
1865 SmallVector<const SCEV *, 4> Operands;
1866 for (auto *Operand : MinMax->operands())
1867 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1868 if (isa<SCEVUMinExpr>(MinMax))
1869 return getUMinExpr(Operands);
1870 return getUMaxExpr(Operands);
1871 }
1872
1873 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1874 if (auto *MinMax = dyn_cast<SCEVSequentialMinMaxExpr>(Op)) {
1875 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1876 SmallVector<const SCEV *, 4> Operands;
1877 for (auto *Operand : MinMax->operands())
1878 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1879 return getUMinExpr(Operands, /*Sequential*/ true);
1880 }
1881
1882 // The cast wasn't folded; create an explicit cast node.
1883 // Recompute the insert position, as it may have been invalidated.
1884 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1885 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1886 Op, Ty);
1887 UniqueSCEVs.InsertNode(S, IP);
1888 registerUser(S, Op);
1889 return S;
1890}
1891
1892const SCEV *
1893 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1894 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1895 "This is not an extending conversion!");
1896 assert(isSCEVable(Ty) &&
1897 "This is not a conversion to a SCEVable type!");
1898 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1899 Ty = getEffectiveSCEVType(Ty);
1900
1901 FoldID ID(scSignExtend, Op, Ty);
1902 if (const SCEV *S = FoldCache.lookup(ID))
1903 return S;
1904
1905 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
1906 if (!isa<SCEVSignExtendExpr>(S))
1907 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1908 return S;
1909}
1910
1911 const SCEV *ScalarEvolution::getSignExtendExprImpl(const SCEV *Op, Type *Ty,
1912 unsigned Depth) {
1913 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1914 "This is not an extending conversion!");
1915 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1916 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1917 Ty = getEffectiveSCEVType(Ty);
1918
1919 // Fold if the operand is constant.
1920 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1921 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
1922
1923 // sext(sext(x)) --> sext(x)
1924 if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1925 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1926
1927 // sext(zext(x)) --> zext(x)
1928 if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1929 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1930
1931 // Before doing any expensive analysis, check to see if we've already
1932 // computed a SCEV for this Op and Ty.
1933 FoldingSetNodeID ID;
1934 ID.AddInteger(scSignExtend);
1935 ID.AddPointer(Op);
1936 ID.AddPointer(Ty);
1937 void *IP = nullptr;
1938 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1939 // Limit recursion depth.
1940 if (Depth > MaxCastDepth) {
1941 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1942 Op, Ty);
1943 UniqueSCEVs.InsertNode(S, IP);
1944 registerUser(S, Op);
1945 return S;
1946 }
1947
1948 // sext(trunc(x)) --> sext(x) or x or trunc(x)
1949 if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1950 // It's possible the bits taken off by the truncate were all sign bits. If
1951 // so, we should be able to simplify this further.
1952 const SCEV *X = ST->getOperand();
1953 ConstantRange CR = getSignedRange(X);
1954 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1955 unsigned NewBits = getTypeSizeInBits(Ty);
1956 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1957 CR.sextOrTrunc(NewBits)))
1958 return getTruncateOrSignExtend(X, Ty, Depth);
1959 }
1960
1961 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1962 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1963 if (SA->hasNoSignedWrap()) {
1964 // If the addition does not sign overflow then we can, by definition,
1965 // commute the sign extension with the addition operation.
1966 SmallVector<const SCEV *, 4> Ops;
1967 for (const auto *Op : SA->operands())
1968 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1969 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1970 }
1971
1972 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1973 // if D + (C - D + x + y + ...) could be proven to not signed wrap
1974 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1975 //
1976 // For instance, this will bring two seemingly different expressions:
1977 // 1 + sext(5 + 20 * %x + 24 * %y) and
1978 // sext(6 + 20 * %x + 24 * %y)
1979 // to the same form:
1980 // 2 + sext(4 + 20 * %x + 24 * %y)
1981 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1982 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1983 if (D != 0) {
1984 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1985 const SCEV *SResidual =
1986 getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1987 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1988 return getAddExpr(SSExtD, SSExtR,
1989 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1990 Depth + 1);
1991 }
1992 }
1993 }
1994 // If the input value is a chrec scev, and we can prove that the value
1995 // did not overflow the old, smaller, value, we can sign extend all of the
1996 // operands (often constants). This allows analysis of something like
1997 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1998 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1999 if (AR->isAffine()) {
2000 const SCEV *Start = AR->getStart();
2001 const SCEV *Step = AR->getStepRecurrence(*this);
2002 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2003 const Loop *L = AR->getLoop();
2004
2005 // If we have special knowledge that this addrec won't overflow,
2006 // we don't need to do any further analysis.
2007 if (AR->hasNoSignedWrap()) {
2008 Start =
2009 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2010 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2011 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2012 }
2013
2014 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2015 // Note that this serves two purposes: It filters out loops that are
2016 // simply not analyzable, and it covers the case where this code is
2017 // being called from within backedge-taken count analysis, such that
2018 // attempting to ask for the backedge-taken count would likely result
2019 // in infinite recursion. In the latter case, the analysis code will
2020 // cope with a conservative value, and it will take care to purge
2021 // that value once it has finished.
2022 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2023 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2024 // Manually compute the final value for AR, checking for
2025 // overflow.
2026
2027 // Check whether the backedge-taken count can be losslessly cast to
2028 // the addrec's type. The count is always unsigned.
2029 const SCEV *CastedMaxBECount =
2030 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2031 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2032 CastedMaxBECount, MaxBECount->getType(), Depth);
2033 if (MaxBECount == RecastedMaxBECount) {
2034 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2035 // Check whether Start+Step*MaxBECount has no signed overflow.
2036 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2037 SCEV::FlagAnyWrap, Depth + 1);
2038 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2039 SCEV::FlagAnyWrap,
2040 Depth + 1),
2041 WideTy, Depth + 1);
2042 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2043 const SCEV *WideMaxBECount =
2044 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2045 const SCEV *OperandExtendedAdd =
2046 getAddExpr(WideStart,
2047 getMulExpr(WideMaxBECount,
2048 getSignExtendExpr(Step, WideTy, Depth + 1),
2049 SCEV::FlagAnyWrap, Depth + 1),
2050 SCEV::FlagAnyWrap, Depth + 1);
2051 if (SAdd == OperandExtendedAdd) {
2052 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2053 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2054 // Return the expression with the addrec on the outside.
2055 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2056 Depth + 1);
2057 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2058 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2059 }
2060 // Similar to above, only this time treat the step value as unsigned.
2061 // This covers loops that count up with an unsigned step.
2062 OperandExtendedAdd =
2063 getAddExpr(WideStart,
2064 getMulExpr(WideMaxBECount,
2065 getZeroExtendExpr(Step, WideTy, Depth + 1),
2066 SCEV::FlagAnyWrap, Depth + 1),
2067 SCEV::FlagAnyWrap, Depth + 1);
2068 if (SAdd == OperandExtendedAdd) {
2069 // If AR wraps around then
2070 //
2071 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2072 // => SAdd != OperandExtendedAdd
2073 //
2074 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2075 // (SAdd == OperandExtendedAdd => AR is NW)
2076
2077 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2078
2079 // Return the expression with the addrec on the outside.
2080 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2081 Depth + 1);
2082 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2083 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2084 }
2085 }
2086 }
2087
2088 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2089 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2090 if (AR->hasNoSignedWrap()) {
2091 // Same as nsw case above - duplicated here to avoid a compile time
2092 // issue. It's not clear that the order of checks matters, but
2093 // it's one of two possible causes for a change which was
2094 // reverted. Be conservative for the moment.
2095 Start =
2096 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2097 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2098 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2099 }
2100
2101 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2102 // if D + (C - D + Step * n) could be proven to not signed wrap
2103 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2104 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2105 const APInt &C = SC->getAPInt();
2106 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2107 if (D != 0) {
2108 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2109 const SCEV *SResidual =
2110 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2111 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2112 return getAddExpr(SSExtD, SSExtR,
2113 (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2114 Depth + 1);
2115 }
2116 }
2117
2118 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2119 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2120 Start =
2121 getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2122 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2123 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2124 }
2125 }
2126
2127 // If the input value is provably non-negative and we could not simplify
2128 // away the sext, build a zext instead.
2129 if (isKnownNonNegative(Op))
2130 return getZeroExtendExpr(Op, Ty, Depth + 1);
2131
2132 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2133 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2134 if (isa<SCEVSMinExpr>(Op) || isa<SCEVSMaxExpr>(Op)) {
2135 auto *MinMax = cast<SCEVMinMaxExpr>(Op);
2136 SmallVector<const SCEV *, 4> Operands;
2137 for (auto *Operand : MinMax->operands())
2138 Operands.push_back(getSignExtendExpr(Operand, Ty));
2139 if (isa<SCEVSMinExpr>(MinMax))
2140 return getSMinExpr(Operands);
2141 return getSMaxExpr(Operands);
2142 }
2143
2144 // The cast wasn't folded; create an explicit cast node.
2145 // Recompute the insert position, as it may have been invalidated.
2146 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2147 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2148 Op, Ty);
2149 UniqueSCEVs.InsertNode(S, IP);
2150 registerUser(S, { Op });
2151 return S;
2152}
2153
2154 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2155 Type *Ty) {
2156 switch (Kind) {
2157 case scTruncate:
2158 return getTruncateExpr(Op, Ty);
2159 case scZeroExtend:
2160 return getZeroExtendExpr(Op, Ty);
2161 case scSignExtend:
2162 return getSignExtendExpr(Op, Ty);
2163 case scPtrToInt:
2164 return getPtrToIntExpr(Op, Ty);
2165 default:
2166 llvm_unreachable("Not a SCEV cast expression!");
2167 }
2168}
2169
2170/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2171/// unspecified bits out to the given type.
2172 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2173 Type *Ty) {
2174 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2175 "This is not an extending conversion!");
2176 assert(isSCEVable(Ty) &&
2177 "This is not a conversion to a SCEVable type!");
2178 Ty = getEffectiveSCEVType(Ty);
2179
2180 // Sign-extend negative constants.
2181 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2182 if (SC->getAPInt().isNegative())
2183 return getSignExtendExpr(Op, Ty);
2184
2185 // Peel off a truncate cast.
2186 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2187 const SCEV *NewOp = T->getOperand();
2188 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2189 return getAnyExtendExpr(NewOp, Ty);
2190 return getTruncateOrNoop(NewOp, Ty);
2191 }
2192
2193 // Next try a zext cast. If the cast is folded, use it.
2194 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2195 if (!isa<SCEVZeroExtendExpr>(ZExt))
2196 return ZExt;
2197
2198 // Next try a sext cast. If the cast is folded, use it.
2199 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2200 if (!isa<SCEVSignExtendExpr>(SExt))
2201 return SExt;
2202
2203 // Force the cast to be folded into the operands of an addrec.
2204 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2205 SmallVector<const SCEV *, 4> Ops;
2206 for (const SCEV *Op : AR->operands())
2207 Ops.push_back(getAnyExtendExpr(Op, Ty));
2208 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2209 }
2210
2211 // If the expression is obviously signed, use the sext cast value.
2212 if (isa<SCEVSMaxExpr>(Op))
2213 return SExt;
2214
2215 // Absent any other information, use the zext cast value.
2216 return ZExt;
2217}
2218
2219/// Process the given Ops list, which is a list of operands to be added under
2220/// the given scale, update the given map. This is a helper function for
2221/// getAddRecExpr. As an example of what it does, given a sequence of operands
2222/// that would form an add expression like this:
2223///
2224/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2225///
2226/// where A and B are constants, update the map with these values:
2227///
2228/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2229///
2230/// and add 13 + A*B*29 to AccumulatedConstant.
2231/// This will allow getAddRecExpr to produce this:
2232///
2233/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2234///
2235/// This form often exposes folding opportunities that are hidden in
2236/// the original operand list.
2237///
2238/// Return true iff it appears that any interesting folding opportunities
2239/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2240/// the common case where no interesting opportunities are present, and
2241/// is also used as a check to avoid infinite recursion.
2242static bool
2243 CollectAddOperandsWithScales(SmallDenseMap<const SCEV *, APInt, 16> &M,
2244 SmallVectorImpl<const SCEV *> &NewOps,
2245 APInt &AccumulatedConstant,
2246 ArrayRef<const SCEV *> Ops, const APInt &Scale,
2247 ScalarEvolution &SE) {
2248 bool Interesting = false;
2249
2250 // Iterate over the add operands. They are sorted, with constants first.
2251 unsigned i = 0;
2252 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2253 ++i;
2254 // Pull a buried constant out to the outside.
2255 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2256 Interesting = true;
2257 AccumulatedConstant += Scale * C->getAPInt();
2258 }
2259
2260 // Next comes everything else. We're especially interested in multiplies
2261 // here, but they're in the middle, so just visit the rest with one loop.
2262 for (; i != Ops.size(); ++i) {
2263 const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2264 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2265 APInt NewScale =
2266 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2267 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2268 // A multiplication of a constant with another add; recurse.
2269 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2270 Interesting |=
2271 CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2272 Add->operands(), NewScale, SE);
2273 } else {
2274 // A multiplication of a constant with some other value. Update
2275 // the map.
2276 SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2277 const SCEV *Key = SE.getMulExpr(MulOps);
2278 auto Pair = M.insert({Key, NewScale});
2279 if (Pair.second) {
2280 NewOps.push_back(Pair.first->first);
2281 } else {
2282 Pair.first->second += NewScale;
2283 // The map already had an entry for this value, which may indicate
2284 // a folding opportunity.
2285 Interesting = true;
2286 }
2287 }
2288 } else {
2289 // An ordinary operand. Update the map.
2290 std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2291 M.insert({Ops[i], Scale});
2292 if (Pair.second) {
2293 NewOps.push_back(Pair.first->first);
2294 } else {
2295 Pair.first->second += Scale;
2296 // The map already had an entry for this value, which may indicate
2297 // a folding opportunity.
2298 Interesting = true;
2299 }
2300 }
2301 }
2302
2303 return Interesting;
2304}
2305
2306 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2307 const SCEV *LHS, const SCEV *RHS,
2308 const Instruction *CtxI) {
2309 const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2310 SCEV::NoWrapFlags, unsigned);
2311 switch (BinOp) {
2312 default:
2313 llvm_unreachable("Unsupported binary op");
2314 case Instruction::Add:
2315 Operation = &ScalarEvolution::getAddExpr;
2316 break;
2317 case Instruction::Sub:
2318 Operation = &ScalarEvolution::getMinusSCEV;
2319 break;
2320 case Instruction::Mul:
2321 Operation = &ScalarEvolution::getMulExpr;
2322 break;
2323 }
2324
2325 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2326 Signed ? &ScalarEvolution::getSignExtendExpr
2327 : &ScalarEvolution::getZeroExtendExpr;
2328
2329 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2330 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2331 auto *WideTy =
2332 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2333
2334 const SCEV *A = (this->*Extension)(
2335 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2336 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2337 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2338 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2339 if (A == B)
2340 return true;
2341 // Can we use context to prove the fact we need?
2342 if (!CtxI)
2343 return false;
2344 // TODO: Support mul.
2345 if (BinOp == Instruction::Mul)
2346 return false;
2347 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2348 // TODO: Lift this limitation.
2349 if (!RHSC)
2350 return false;
2351 APInt C = RHSC->getAPInt();
2352 unsigned NumBits = C.getBitWidth();
2353 bool IsSub = (BinOp == Instruction::Sub);
2354 bool IsNegativeConst = (Signed && C.isNegative());
2355 // Compute the direction and magnitude by which we need to check overflow.
2356 bool OverflowDown = IsSub ^ IsNegativeConst;
2357 APInt Magnitude = C;
2358 if (IsNegativeConst) {
2359 if (C == APInt::getSignedMinValue(NumBits))
2360 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2361 // want to deal with that.
2362 return false;
2363 Magnitude = -C;
2364 }
2365
2367 if (OverflowDown) {
2368 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2369 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2370 : APInt::getMinValue(NumBits);
2371 APInt Limit = Min + Magnitude;
2372 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2373 } else {
2374 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2375 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2376 : APInt::getMaxValue(NumBits);
2377 APInt Limit = Max - Magnitude;
2378 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2379 }
2380}
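// Worked example of the limit computation above (illustrative i8 values):
// for an unsigned subtraction LHS - 3, OverflowDown is true, Min is 0 and
// Limit is 3, so the query becomes isKnownPredicateAt(ULE, 3, LHS, CtxI).
// For an unsigned addition LHS + 7, OverflowDown is false, Max is 255 and
// Limit is 248, so the query becomes isKnownPredicateAt(ULE, LHS, 248, CtxI).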
2381
2382std::optional<SCEV::NoWrapFlags>
2383 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2384 const OverflowingBinaryOperator *OBO) {
2385 // It cannot be done any better.
2386 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2387 return std::nullopt;
2388
2389 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2390
2391 if (OBO->hasNoUnsignedWrap())
2392 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2393 if (OBO->hasNoSignedWrap())
2394 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2395
2396 bool Deduced = false;
2397
2398 if (OBO->getOpcode() != Instruction::Add &&
2399 OBO->getOpcode() != Instruction::Sub &&
2400 OBO->getOpcode() != Instruction::Mul)
2401 return std::nullopt;
2402
2403 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2404 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2405
2406 const Instruction *CtxI =
2407 UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2408 if (!OBO->hasNoUnsignedWrap() &&
2409 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2410 /* Signed */ false, LHS, RHS, CtxI)) {
2411 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2412 Deduced = true;
2413 }
2414
2415 if (!OBO->hasNoSignedWrap() &&
2416 willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2417 /* Signed */ true, LHS, RHS, CtxI)) {
2418 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2419 Deduced = true;
2420 }
2421
2422 if (Deduced)
2423 return Flags;
2424 return std::nullopt;
2425}
2426
2427// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2428// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2429// can't-overflow flags for the operation if possible.
2430 static SCEV::NoWrapFlags
2431 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2432 const ArrayRef<const SCEV *> Ops,
2433 SCEV::NoWrapFlags Flags) {
2434 using namespace std::placeholders;
2435
2436 using OBO = OverflowingBinaryOperator;
2437
2438 bool CanAnalyze =
2439 Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2440 (void)CanAnalyze;
2441 assert(CanAnalyze && "don't call from other places!");
2442
2443 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2444 SCEV::NoWrapFlags SignOrUnsignWrap =
2445 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2446
2447 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2448 auto IsKnownNonNegative = [&](const SCEV *S) {
2449 return SE->isKnownNonNegative(S);
2450 };
2451
2452 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2453 Flags =
2454 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2455
2456 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2457
2458 if (SignOrUnsignWrap != SignOrUnsignMask &&
2459 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2460 isa<SCEVConstant>(Ops[0])) {
2461
2462 auto Opcode = [&] {
2463 switch (Type) {
2464 case scAddExpr:
2465 return Instruction::Add;
2466 case scMulExpr:
2467 return Instruction::Mul;
2468 default:
2469 llvm_unreachable("Unexpected SCEV op.");
2470 }
2471 }();
2472
2473 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2474
2475 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2476 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2477 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2478 Opcode, C, OBO::NoSignedWrap);
2479 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2480 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2481 }
2482
2483 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2484 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2485 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2486 Opcode, C, OBO::NoUnsignedWrap);
2487 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2488 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2489 }
2490 }
2491
2492 // <0,+,nonnegative><nw> is also nuw
2493 // TODO: Add corresponding nsw case
2494 if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2495 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2496 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2497 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2498
2499 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2500 if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2501 Ops.size() == 2) {
2502 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2503 if (UDiv->getOperand(1) == Ops[1])
2504 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2505 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2506 if (UDiv->getOperand(1) == Ops[0])
2507 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2508 }
2509
2510 return Flags;
2511}
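// Worked example of the nsw-to-nuw inference above (illustrative values):
// if an add is <nsw> and every operand is known non-negative, say two i8
// values A, B in [0, 60], the signed sum is at most 120, which fits both the
// signed and the unsigned interpretation, so the unsigned form cannot wrap
// either and FlagNUW can be set alongside FlagNSW.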
2512
2513 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2514 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2515}
2516
2517/// Get a canonical add expression, or something simpler if possible.
2518 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2519 SCEV::NoWrapFlags OrigFlags,
2520 unsigned Depth) {
2521 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2522 "only nuw or nsw allowed");
2523 assert(!Ops.empty() && "Cannot get empty add!");
2524 if (Ops.size() == 1) return Ops[0];
2525#ifndef NDEBUG
2526 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2527 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2528 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2529 "SCEVAddExpr operand types don't match!");
2530 unsigned NumPtrs = count_if(
2531 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2532 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2533#endif
2534
2535 const SCEV *Folded = constantFoldAndGroupOps(
2536 *this, LI, DT, Ops,
2537 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2538 [](const APInt &C) { return C.isZero(); }, // identity
2539 [](const APInt &C) { return false; }); // absorber
2540 if (Folded)
2541 return Folded;
2542
2543 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2544
2545 // Delay expensive flag strengthening until necessary.
2546 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
2547 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2548 };
2549
2550 // Limit recursion call depth.
2551 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2552 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2553
2554 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2555 // Don't strengthen flags if we have no new information.
2556 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2557 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2558 Add->setNoWrapFlags(ComputeFlags(Ops));
2559 return S;
2560 }
2561
2562 // Okay, check to see if the same value occurs in the operand list more than
2563 // once. If so, merge them together into a multiply expression. Since we
2564 // sorted the list, these values are required to be adjacent.
2565 Type *Ty = Ops[0]->getType();
2566 bool FoundMatch = false;
2567 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2568 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2569 // Scan ahead to count how many equal operands there are.
2570 unsigned Count = 2;
2571 while (i+Count != e && Ops[i+Count] == Ops[i])
2572 ++Count;
2573 // Merge the values into a multiply.
2574 const SCEV *Scale = getConstant(Ty, Count);
2575 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2576 if (Ops.size() == Count)
2577 return Mul;
2578 Ops[i] = Mul;
2579 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2580 --i; e -= Count - 1;
2581 FoundMatch = true;
2582 }
2583 if (FoundMatch)
2584 return getAddExpr(Ops, OrigFlags, Depth + 1);
2585
2586 // Check for truncates. If all the operands are truncated from the same
2587 // type, see if factoring out the truncate would permit the result to be
2588 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2589 // if the contents of the resulting outer trunc fold to something simple.
2590 auto FindTruncSrcType = [&]() -> Type * {
2591 // We're ultimately looking to fold an addrec of truncs and muls of only
2592 // constants and truncs, so if we find any other types of SCEV
2593 // as operands of the addrec then we bail and return nullptr here.
2594 // Otherwise, we return the type of the operand of a trunc that we find.
2595 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2596 return T->getOperand()->getType();
2597 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2598 const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2599 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2600 return T->getOperand()->getType();
2601 }
2602 return nullptr;
2603 };
2604 if (auto *SrcType = FindTruncSrcType()) {
2605 SmallVector<const SCEV *, 8> LargeOps;
2606 bool Ok = true;
2607 // Check all the operands to see if they can be represented in the
2608 // source type of the truncate.
2609 for (const SCEV *Op : Ops) {
2610 if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2611 if (T->getOperand()->getType() != SrcType) {
2612 Ok = false;
2613 break;
2614 }
2615 LargeOps.push_back(T->getOperand());
2616 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2617 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2618 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2619 SmallVector<const SCEV *, 8> LargeMulOps;
2620 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2621 if (const SCEVTruncateExpr *T =
2622 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2623 if (T->getOperand()->getType() != SrcType) {
2624 Ok = false;
2625 break;
2626 }
2627 LargeMulOps.push_back(T->getOperand());
2628 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2629 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2630 } else {
2631 Ok = false;
2632 break;
2633 }
2634 }
2635 if (Ok)
2636 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2637 } else {
2638 Ok = false;
2639 break;
2640 }
2641 }
2642 if (Ok) {
2643 // Evaluate the expression in the larger type.
2644 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2645 // If it folds to something simple, use it. Otherwise, don't.
2646 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2647 return getTruncateExpr(Fold, Ty);
2648 }
2649 }
2650
2651 if (Ops.size() == 2) {
2652 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2653 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2654 // C1).
2655 const SCEV *A = Ops[0];
2656 const SCEV *B = Ops[1];
2657 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2658 auto *C = dyn_cast<SCEVConstant>(A);
2659 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2660 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2661 auto C2 = C->getAPInt();
2662 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2663
2664 APInt ConstAdd = C1 + C2;
2665 auto AddFlags = AddExpr->getNoWrapFlags();
2666 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2667 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2668 ConstAdd.ule(C1)) {
2669 PreservedFlags =
2670 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2671 }
2672
2673 // Adding a constant with the same sign and small magnitude is NSW, if the
2674 // original AddExpr was NSW.
2675 if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2676 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2677 ConstAdd.abs().ule(C1.abs())) {
2678 PreservedFlags =
2679 ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2680 }
2681
2682 if (PreservedFlags != SCEV::FlagAnyWrap) {
2683 SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2684 NewOps[0] = getConstant(ConstAdd);
2685 return getAddExpr(NewOps, PreservedFlags);
2686 }
2687 }
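// Worked example (illustrative i8 values): for Ops = {-3, (X + 10)<nuw>},
// ConstAdd is 7, which is ule 10, so <nuw> is preserved and the result is
// (X + 7)<nuw>. For Ops = {5, (X + 10)<nsw>}, ConstAdd is 15; it has the
// same sign as 10 but larger magnitude, so <nsw> is not preserved and the
// generic paths below are used instead.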
2688
2689 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2690 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2691 const SCEVAddExpr *InnerAdd;
2692 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2693 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2694 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2695 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2696 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2697 SCEV::FlagAnyWrap),
2698 SCEV::FlagNUW)) {
2699 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2700 }
2701 }
2702 }
2703
2704 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2705 const SCEV *Y;
2706 if (Ops.size() == 2 &&
2707 match(Ops[0],
2708 m_scev_Mul(m_scev_AllOnes(),
2709 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2710 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
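// Numeric check of the canonicalization (illustrative): with X = 13 and
// Y = 5, (-1 * (X urem Y)) + X = -3 + 13 = 10, and Y * (X /u Y) = 5 * 2 = 10,
// so both sides agree.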
2711
2712 // Skip past any other cast SCEVs.
2713 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2714 ++Idx;
2715
2716 // If there are add operands they would be next.
2717 if (Idx < Ops.size()) {
2718 bool DeletedAdd = false;
2719 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2720 // common NUW flag for expression after inlining. Other flags cannot be
2721 // preserved, because they may depend on the original order of operations.
2722 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2723 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2724 if (Ops.size() > AddOpsInlineThreshold ||
2725 Add->getNumOperands() > AddOpsInlineThreshold)
2726 break;
2727 // If we have an add, expand the add operands onto the end of the operands
2728 // list.
2729 Ops.erase(Ops.begin()+Idx);
2730 append_range(Ops, Add->operands());
2731 DeletedAdd = true;
2732 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2733 }
2734
2735 // If we deleted at least one add, we added operands to the end of the list,
2736 // and they are not necessarily sorted. Recurse to resort and resimplify
2737 // any operands we just acquired.
2738 if (DeletedAdd)
2739 return getAddExpr(Ops, CommonFlags, Depth + 1);
2740 }
2741
2742 // Skip over the add expression until we get to a multiply.
2743 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2744 ++Idx;
2745
2746 // Check to see if there are any folding opportunities present with
2747 // operands multiplied by constant values.
2748 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2749 uint64_t BitWidth = getTypeSizeInBits(Ty);
2750 SmallDenseMap<const SCEV *, APInt, 16> M;
2751 SmallVector<const SCEV *, 8> NewOps;
2752 APInt AccumulatedConstant(BitWidth, 0);
2753 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2754 Ops, APInt(BitWidth, 1), *this)) {
2755 struct APIntCompare {
2756 bool operator()(const APInt &LHS, const APInt &RHS) const {
2757 return LHS.ult(RHS);
2758 }
2759 };
2760
2761 // Some interesting folding opportunity is present, so it's worthwhile to
2762 // re-generate the operands list. Group the operands by constant scale,
2763 // to avoid multiplying by the same constant scale multiple times.
2764 std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2765 for (const SCEV *NewOp : NewOps)
2766 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2767 // Re-generate the operands list.
2768 Ops.clear();
2769 if (AccumulatedConstant != 0)
2770 Ops.push_back(getConstant(AccumulatedConstant));
2771 for (auto &MulOp : MulOpLists) {
2772 if (MulOp.first == 1) {
2773 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2774 } else if (MulOp.first != 0) {
2775 Ops.push_back(getMulExpr(
2776 getConstant(MulOp.first),
2777 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2778 SCEV::FlagAnyWrap, Depth + 1));
2779 }
2780 }
2781 if (Ops.empty())
2782 return getZero(Ty);
2783 if (Ops.size() == 1)
2784 return Ops[0];
2785 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2786 }
2787 }
2788
2789 // If we are adding something to a multiply expression, make sure the
2790 // something is not already an operand of the multiply. If so, merge it into
2791 // the multiply.
2792 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2793 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2794 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2795 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2796 if (isa<SCEVConstant>(MulOpSCEV))
2797 continue;
2798 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2799 if (MulOpSCEV == Ops[AddOp]) {
2800 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2801 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2802 if (Mul->getNumOperands() != 2) {
2803 // If the multiply has more than two operands, we must get the
2804 // Y*Z term.
2805 SmallVector<const SCEV *, 4> MulOps(
2806 Mul->operands().take_front(MulOp));
2807 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2808 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2809 }
2810 SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2811 const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2812 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2813 SCEV::FlagAnyWrap, Depth + 1);
2814 if (Ops.size() == 2) return OuterMul;
2815 if (AddOp < Idx) {
2816 Ops.erase(Ops.begin()+AddOp);
2817 Ops.erase(Ops.begin()+Idx-1);
2818 } else {
2819 Ops.erase(Ops.begin()+Idx);
2820 Ops.erase(Ops.begin()+AddOp-1);
2821 }
2822 Ops.push_back(OuterMul);
2823 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2824 }
2825
2826 // Check this multiply against other multiplies being added together.
2827 for (unsigned OtherMulIdx = Idx+1;
2828 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2829 ++OtherMulIdx) {
2830 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2831 // If MulOp occurs in OtherMul, we can fold the two multiplies
2832 // together.
2833 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2834 OMulOp != e; ++OMulOp)
2835 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2836 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2837 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2838 if (Mul->getNumOperands() != 2) {
2839 SmallVector<const SCEV *, 4> MulOps(
2840 Mul->operands().take_front(MulOp));
2841 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2842 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2843 }
2844 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2845 if (OtherMul->getNumOperands() != 2) {
2846 SmallVector<const SCEV *, 4> MulOps(
2847 OtherMul->operands().take_front(OMulOp));
2848 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2849 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2850 }
2851 SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2852 const SCEV *InnerMulSum =
2853 getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2854 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2855 SCEV::FlagAnyWrap, Depth + 1);
2856 if (Ops.size() == 2) return OuterMul;
2857 Ops.erase(Ops.begin()+Idx);
2858 Ops.erase(Ops.begin()+OtherMulIdx-1);
2859 Ops.push_back(OuterMul);
2860 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2861 }
2862 }
2863 }
2864 }
2865
2866 // If there are any add recurrences in the operands list, see if any other
2867 // added values are loop invariant. If so, we can fold them into the
2868 // recurrence.
2869 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2870 ++Idx;
2871
2872 // Scan over all recurrences, trying to fold loop invariants into them.
2873 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2874 // Scan all of the other operands to this add and add them to the vector if
2875 // they are loop invariant w.r.t. the recurrence.
2876 SmallVector<const SCEV *, 8> LIOps;
2877 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2878 const Loop *AddRecLoop = AddRec->getLoop();
2879 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2880 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2881 LIOps.push_back(Ops[i]);
2882 Ops.erase(Ops.begin()+i);
2883 --i; --e;
2884 }
2885
2886 // If we found some loop invariants, fold them into the recurrence.
2887 if (!LIOps.empty()) {
2888 // Compute nowrap flags for the addition of the loop-invariant ops and
2889 // the addrec. Temporarily push it as an operand for that purpose. These
2890 // flags are valid in the scope of the addrec only.
2891 LIOps.push_back(AddRec);
2892 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2893 LIOps.pop_back();
2894
2895 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2896 LIOps.push_back(AddRec->getStart());
2897
2898 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2899
2900 // It is not in general safe to propagate flags valid on an add within
2901 // the addrec scope to one outside it. We must prove that the inner
2902 // scope is guaranteed to execute if the outer one does to be able to
2903 // safely propagate. We know the program is undefined if poison is
2904 // produced on the inner scoped addrec. We also know that *for this use*
2905 // the outer scoped add can't overflow (because of the flags we just
2906 // computed for the inner scoped add) without the program being undefined.
2907 // Proving that entry to the outer scope necessitates entry to the inner
2908 // scope, thus proves the program undefined if the flags would be violated
2909 // in the outer scope.
2910 SCEV::NoWrapFlags AddFlags = Flags;
2911 if (AddFlags != SCEV::FlagAnyWrap) {
2912 auto *DefI = getDefiningScopeBound(LIOps);
2913 auto *ReachI = &*AddRecLoop->getHeader()->begin();
2914 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2915 AddFlags = SCEV::FlagAnyWrap;
2916 }
2917 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2918
2919 // Build the new addrec. Propagate the NUW and NSW flags if both the
2920 // outer add and the inner addrec are guaranteed to have no overflow.
2921 // Always propagate NW.
2922 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2923 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2924
2925 // If all of the other operands were loop invariant, we are done.
2926 if (Ops.size() == 1) return NewRec;
2927
2928 // Otherwise, add the folded AddRec by the non-invariant parts.
2929 for (unsigned i = 0;; ++i)
2930 if (Ops[i] == AddRec) {
2931 Ops[i] = NewRec;
2932 break;
2933 }
2934 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2935 }
2936
2937 // Okay, if there weren't any loop invariants to be folded, check to see if
2938 // there are multiple AddRec's with the same loop induction variable being
2939 // added together. If so, we can fold them.
2940 for (unsigned OtherIdx = Idx+1;
2941 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2942 ++OtherIdx) {
2943 // We expect the AddRecExpr's to be sorted in reverse dominance order,
2944 // so that the 1st found AddRecExpr is dominated by all others.
2945 assert(DT.dominates(
2946 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2947 AddRec->getLoop()->getHeader()) &&
2948 "AddRecExprs are not sorted in reverse dominance order?");
2949 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2950 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2951 SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2952 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2953 ++OtherIdx) {
2954 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2955 if (OtherAddRec->getLoop() == AddRecLoop) {
2956 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2957 i != e; ++i) {
2958 if (i >= AddRecOps.size()) {
2959 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
2960 break;
2961 }
2962 SmallVector<const SCEV *, 2> TwoOps = {
2963 AddRecOps[i], OtherAddRec->getOperand(i)};
2964 AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2965 }
2966 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2967 }
2968 }
2969 // Step size has changed, so we cannot guarantee no self-wraparound.
2970 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974
2975 // Otherwise couldn't fold anything into this recurrence. Move onto the
2976 // next one.
2977 }
2978
2979 // Okay, it looks like we really DO need an add expr. Check to see if we
2980 // already have one, otherwise create a new one.
2981 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2982}
2983
2984const SCEV *
2985ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2986 SCEV::NoWrapFlags Flags) {
2987 FoldingSetNodeID ID;
2988 ID.AddInteger(scAddExpr);
2989 for (const SCEV *Op : Ops)
2990 ID.AddPointer(Op);
2991 void *IP = nullptr;
2992 SCEVAddExpr *S =
2993 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2994 if (!S) {
2995 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2996 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2997 S = new (SCEVAllocator)
2998 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2999 UniqueSCEVs.InsertNode(S, IP);
3000 registerUser(S, Ops);
3001 }
3002 S->setNoWrapFlags(Flags);
3003 return S;
3004}
3005
3006const SCEV *
3007ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
3008 const Loop *L, SCEV::NoWrapFlags Flags) {
3009 FoldingSetNodeID ID;
3010 ID.AddInteger(scAddRecExpr);
3011 for (const SCEV *Op : Ops)
3012 ID.AddPointer(Op);
3013 ID.AddPointer(L);
3014 void *IP = nullptr;
3015 SCEVAddRecExpr *S =
3016 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3017 if (!S) {
3018 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3019 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3020 S = new (SCEVAllocator)
3021 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3022 UniqueSCEVs.InsertNode(S, IP);
3023 LoopUsers[L].push_back(S);
3024 registerUser(S, Ops);
3025 }
3026 setNoWrapFlags(S, Flags);
3027 return S;
3028}
3029
3030const SCEV *
3031ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3032 SCEV::NoWrapFlags Flags) {
3033 FoldingSetNodeID ID;
3034 ID.AddInteger(scMulExpr);
3035 for (const SCEV *Op : Ops)
3036 ID.AddPointer(Op);
3037 void *IP = nullptr;
3038 SCEVMulExpr *S =
3039 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3040 if (!S) {
3041 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3042 std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3043 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3044 O, Ops.size());
3045 UniqueSCEVs.InsertNode(S, IP);
3046 registerUser(S, Ops);
3047 }
3048 S->setNoWrapFlags(Flags);
3049 return S;
3050}
3051
3052static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3053 uint64_t k = i*j;
3054 if (j > 1 && k / j != i) Overflow = true;
3055 return k;
3056}
3057
3058/// Compute the result of "n choose k", the binomial coefficient. If an
3059/// intermediate computation overflows, Overflow will be set and the return will
3060/// be garbage. Overflow is not cleared on absence of overflow.
3061static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3062 // We use the multiplicative formula:
3063 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3064 // At each iteration, we take the n-th term of the numerator and divide by the
3065 // (k-n)th term of the denominator. This division will always produce an
3066 // integral result, and helps reduce the chance of overflow in the
3067 // intermediate computations. However, we can still overflow even when the
3068 // final result would fit.
3069
3070 if (n == 0 || n == k) return 1;
3071 if (k > n) return 0;
3072
3073 if (k > n/2)
3074 k = n-k;
3075
3076 uint64_t r = 1;
3077 for (uint64_t i = 1; i <= k; ++i) {
3078 r = umul_ov(r, n-(i-1), Overflow);
3079 r /= i;
3080 }
3081 return r;
3082}
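// Worked evaluation (illustrative): Choose(6, 2) runs two iterations:
// i = 1: r = 1 * 6 = 6, then r /= 1 -> 6; i = 2: r = 6 * 5 = 30, then
// r /= 2 -> 15. The division is exact at every step because any i
// consecutive integers contain a multiple of i.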
3083
3084/// Determine if any of the operands in this SCEV are a constant or if
3085/// any of the add or multiply expressions in this SCEV contain a constant.
3086static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3087 struct FindConstantInAddMulChain {
3088 bool FoundConstant = false;
3089
3090 bool follow(const SCEV *S) {
3091 FoundConstant |= isa<SCEVConstant>(S);
3092 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3093 }
3094
3095 bool isDone() const {
3096 return FoundConstant;
3097 }
3098 };
3099
3100 FindConstantInAddMulChain F;
3101 SCEVTraversal<FindConstantInAddMulChain> ST(F);
3102 ST.visitAll(StartExpr);
3103 return F.FoundConstant;
3104}
3105
3106/// Get a canonical multiply expression, or something simpler if possible.
3107 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3108 SCEV::NoWrapFlags OrigFlags,
3109 unsigned Depth) {
3110 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3111 "only nuw or nsw allowed");
3112 assert(!Ops.empty() && "Cannot get empty mul!");
3113 if (Ops.size() == 1) return Ops[0];
3114#ifndef NDEBUG
3115 Type *ETy = Ops[0]->getType();
3116 assert(!ETy->isPointerTy());
3117 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3118 assert(Ops[i]->getType() == ETy &&
3119 "SCEVMulExpr operand types don't match!");
3120#endif
3121
3122 const SCEV *Folded = constantFoldAndGroupOps(
3123 *this, LI, DT, Ops,
3124 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3125 [](const APInt &C) { return C.isOne(); }, // identity
3126 [](const APInt &C) { return C.isZero(); }); // absorber
3127 if (Folded)
3128 return Folded;
3129
3130 // Delay expensive flag strengthening until necessary.
3131 auto ComputeFlags = [this, OrigFlags](ArrayRef<const SCEV *> Ops) {
3132 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3133 };
3134
3135 // Limit recursion call depth.
3136 if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3137 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3138
3139 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3140 // Don't strengthen flags if we have no new information.
3141 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3142 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3143 Mul->setNoWrapFlags(ComputeFlags(Ops));
3144 return S;
3145 }
3146
3147 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3148 if (Ops.size() == 2) {
3149 // C1*(C2+V) -> C1*C2 + C1*V
3150 // If any of Add's ops are Adds or Muls with a constant, apply this
3151 // transformation as well.
3152 //
3153 // TODO: There are some cases where this transformation is not
3154 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3155 // this transformation should be narrowed down.
3156 const SCEV *Op0, *Op1;
3157 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3158 containsConstantInAddMulChain(Ops[1])) {
3159 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3160 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3161 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3162 }
3163
3164 if (Ops[0]->isAllOnesValue()) {
3165 // If we have a mul by -1 of an add, try distributing the -1 among the
3166 // add operands.
3167 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3168 SmallVector<const SCEV *, 4> NewOps;
3169 bool AnyFolded = false;
3170 for (const SCEV *AddOp : Add->operands()) {
3171 const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3172 Depth + 1);
3173 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3174 NewOps.push_back(Mul);
3175 }
3176 if (AnyFolded)
3177 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3178 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3179 // Negation preserves a recurrence's no self-wrap property.
3180 SmallVector<const SCEV *, 4> Operands;
3181 for (const SCEV *AddRecOp : AddRec->operands())
3182 Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3183 Depth + 1));
3184 // Let M be the minimum representable signed value. AddRec with nsw
3185 // multiplied by -1 can have signed overflow if and only if it takes a
3186 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3187 // maximum signed value. In all other cases signed overflow is
3188 // impossible.
3189 auto FlagsMask = SCEV::FlagNW;
3190 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3191 auto MinInt =
3192 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3193 if (getSignedRangeMin(AddRec) != MinInt)
3194 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3195 }
3196 return getAddRecExpr(Operands, AddRec->getLoop(),
3197 AddRec->getNoWrapFlags(FlagsMask));
3198 }
3199 }
3200
3201 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3202 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3203 const SCEVAddExpr *InnerAdd;
3204 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3205 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3206 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3207 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3208 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3209                                           SCEV::FlagAnyWrap),
3210                 SCEV::FlagNUW)) {
3211 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3212 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3213 };
3214 }
3215
3216 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3217 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3218 // of C1, fold to (D /u (C2 /u C1)).
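      // Illustrative example (hypothetical values): with C1 = 4, C2 = 2 and D
      // divisible by 4, (4 * D) /u 2 folds to (2 * D); with C1 = 2, C2 = 8 and
      // D divisible by 8, it folds to (D /u 4) instead.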
3219 const SCEV *D;
3220 APInt C1V = LHSC->getAPInt();
3221 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3222 // as -1 * 1, as it won't enable additional folds.
3223 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3224 C1V = C1V.abs();
3225 const SCEVConstant *C2;
3226 if (C1V.isPowerOf2() &&
3227          match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3228          C2->getAPInt().isPowerOf2() &&
3229 C1V.logBase2() <= getMinTrailingZeros(D)) {
3230 const SCEV *NewMul = nullptr;
3231 if (C1V.uge(C2->getAPInt())) {
3232 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3233 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3234 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3235 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3236 }
3237 if (NewMul)
3238 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3239 }
3240 }
3241 }
3242
3243 // Skip over the add expression until we get to a multiply.
3244 unsigned Idx = 0;
3245 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3246 ++Idx;
3247
3248 // If there are mul operands inline them all into this expression.
3249 if (Idx < Ops.size()) {
3250 bool DeletedMul = false;
3251 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3252 if (Ops.size() > MulOpsInlineThreshold)
3253 break;
3254      // If we have a mul, expand the mul operands onto the end of the
3255 // operands list.
3256 Ops.erase(Ops.begin()+Idx);
3257 append_range(Ops, Mul->operands());
3258 DeletedMul = true;
3259 }
3260
3261 // If we deleted at least one mul, we added operands to the end of the
3262 // list, and they are not necessarily sorted. Recurse to resort and
3263 // resimplify any operands we just acquired.
3264 if (DeletedMul)
3265 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3266 }
3267
3268 // If there are any add recurrences in the operands list, see if any other
3269 // added values are loop invariant. If so, we can fold them into the
3270 // recurrence.
3271 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3272 ++Idx;
3273
3274 // Scan over all recurrences, trying to fold loop invariants into them.
3275 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3276 // Scan all of the other operands to this mul and add them to the vector
3277 // if they are loop invariant w.r.t. the recurrence.
3278    SmallVector<const SCEV *, 8> LIOps;
3279    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3280 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3281 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3282 LIOps.push_back(Ops[i]);
3283 Ops.erase(Ops.begin()+i);
3284 --i; --e;
3285 }
3286
3287 // If we found some loop invariants, fold them into the recurrence.
3288 if (!LIOps.empty()) {
3289 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3290      SmallVector<const SCEV *, 4> NewOps;
3291      NewOps.reserve(AddRec->getNumOperands());
3292 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3293
3294 // If both the mul and addrec are nuw, we can preserve nuw.
3295 // If both the mul and addrec are nsw, we can only preserve nsw if either
3296 // a) they are also nuw, or
3297 // b) all multiplications of addrec operands with scale are nsw.
3298 SCEV::NoWrapFlags Flags =
3299 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3300
3301 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3302 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3303 SCEV::FlagAnyWrap, Depth + 1));
3304
3305 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3306          ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
3307              Instruction::Mul, getSignedRange(Scale),
3308              OverflowingBinaryOperator::NoSignedWrap);
3309 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3310 Flags = clearFlags(Flags, SCEV::FlagNSW);
3311 }
3312 }
3313
3314 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3315
3316 // If all of the other operands were loop invariant, we are done.
3317 if (Ops.size() == 1) return NewRec;
3318
3319 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3320 for (unsigned i = 0;; ++i)
3321 if (Ops[i] == AddRec) {
3322 Ops[i] = NewRec;
3323 break;
3324 }
3325 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3326 }
3327
3328 // Okay, if there weren't any loop invariants to be folded, check to see
3329 // if there are multiple AddRec's with the same loop induction variable
3330 // being multiplied together. If so, we can fold them.
3331
3332 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3333 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3334 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3335 // ]]],+,...up to x=2n}.
3336 // Note that the arguments to choose() are always integers with values
3337 // known at compile time, never SCEV objects.
3338 //
3339 // The implementation avoids pointless extra computations when the two
3340 // addrec's are of different length (mathematically, it's equivalent to
3341 // an infinite stream of zeros on the right).
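    // Worked example (illustrative, not from the source): {1,+,1}<L> has the
    // value 1+n at iteration n, so {1,+,1}<L> * {1,+,1}<L> = (1+n)^2
    // = 1 + 3*C(n,1) + 2*C(n,2), i.e. the product folds to {1,+,3,+,2}<L>.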
3342 bool OpsModified = false;
3343 for (unsigned OtherIdx = Idx+1;
3344 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3345 ++OtherIdx) {
3346 const SCEVAddRecExpr *OtherAddRec =
3347 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3348 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3349 continue;
3350
3351 // Limit max number of arguments to avoid creation of unreasonably big
3352 // SCEVAddRecs with very complex operands.
3353 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3354 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3355 continue;
3356
3357 bool Overflow = false;
3358 Type *Ty = AddRec->getType();
3359 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3360      SmallVector<const SCEV *, 7> AddRecOps;
3361      for (int x = 0, xe = AddRec->getNumOperands() +
3362 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3363        SmallVector<const SCEV *, 7> SumOps;
3364        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3365 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3366 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3367 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3368 z < ze && !Overflow; ++z) {
3369 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3370 uint64_t Coeff;
3371 if (LargerThan64Bits)
3372 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3373 else
3374 Coeff = Coeff1*Coeff2;
3375 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3376 const SCEV *Term1 = AddRec->getOperand(y-z);
3377 const SCEV *Term2 = OtherAddRec->getOperand(z);
3378 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3379 SCEV::FlagAnyWrap, Depth + 1));
3380 }
3381 }
3382 if (SumOps.empty())
3383 SumOps.push_back(getZero(Ty));
3384 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3385 }
3386 if (!Overflow) {
3387 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3388                                            SCEV::FlagAnyWrap);
3389      if (Ops.size() == 2) return NewAddRec;
3390 Ops[Idx] = NewAddRec;
3391 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3392 OpsModified = true;
3393 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3394 if (!AddRec)
3395 break;
3396 }
3397 }
3398 if (OpsModified)
3399 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3400
3401    // Otherwise couldn't fold anything into this recurrence. Move on to the
3402 // next one.
3403 }
3404
3405  // Okay, it looks like we really DO need a mul expr. Check to see if we
3406 // already have one, otherwise create a new one.
3407 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3408}
3409
3410/// Represents an unsigned remainder expression based on unsigned division.
3411const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3412                                         const SCEV *RHS) {
3413 assert(getEffectiveSCEVType(LHS->getType()) ==
3414 getEffectiveSCEVType(RHS->getType()) &&
3415 "SCEVURemExpr operand types don't match!");
3416
3417 // Short-circuit easy cases
3418 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3419 // If constant is one, the result is trivial
3420 if (RHSC->getValue()->isOne())
3421 return getZero(LHS->getType()); // X urem 1 --> 0
3422
3423 // If constant is a power of two, fold into a zext(trunc(LHS)).
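    // Illustrative example (assuming an i32 %x): (%x urem 8) becomes
    // (zext i3 (trunc i32 %x to i3) to i32), keeping only the low
    // log2(8) = 3 bits of %x.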
3424 if (RHSC->getAPInt().isPowerOf2()) {
3425 Type *FullTy = LHS->getType();
3426 Type *TruncTy =
3427 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3428 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3429 }
3430 }
3431
3432 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3433 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3434 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3435 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3436}
3437
3438/// Get a canonical unsigned division expression, or something simpler if
3439/// possible.
3440const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3441                                         const SCEV *RHS) {
3442 assert(!LHS->getType()->isPointerTy() &&
3443 "SCEVUDivExpr operand can't be pointer!");
3444 assert(LHS->getType() == RHS->getType() &&
3445 "SCEVUDivExpr operand types don't match!");
3446
3447  FoldingSetNodeID ID;
3448  ID.AddInteger(scUDivExpr);
3449 ID.AddPointer(LHS);
3450 ID.AddPointer(RHS);
3451 void *IP = nullptr;
3452 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3453 return S;
3454
3455 // 0 udiv Y == 0
3456 if (match(LHS, m_scev_Zero()))
3457 return LHS;
3458
3459 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3460 if (RHSC->getValue()->isOne())
3461 return LHS; // X udiv 1 --> x
3462 // If the denominator is zero, the result of the udiv is undefined. Don't
3463 // try to analyze it, because the resolution chosen here may differ from
3464 // the resolution chosen in other parts of the compiler.
3465 if (!RHSC->getValue()->isZero()) {
3466      // Determine if the division can be distributed over the operands of
3467      // the LHS expression.
3468 // TODO: Generalize this to non-constants by using known-bits information.
3469 Type *Ty = LHS->getType();
3470 unsigned LZ = RHSC->getAPInt().countl_zero();
3471 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3472 // For non-power-of-two values, effectively round the value up to the
3473 // nearest power of two.
3474 if (!RHSC->getAPInt().isPowerOf2())
3475 ++MaxShiftAmt;
3476 IntegerType *ExtTy =
3477 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3478 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3479 if (const SCEVConstant *Step =
3480 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3481 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
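          // Illustrative example (hypothetical values): {8,+,4}<L> /u 2 folds
          // to {4,+,2}<L>, provided the zero-extend comparison below shows the
          // recurrence does not wrap in the wider type.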
3482 const APInt &StepInt = Step->getAPInt();
3483 const APInt &DivInt = RHSC->getAPInt();
3484 if (!StepInt.urem(DivInt) &&
3485 getZeroExtendExpr(AR, ExtTy) ==
3486 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3487 getZeroExtendExpr(Step, ExtTy),
3488 AR->getLoop(), SCEV::FlagAnyWrap)) {
3489            SmallVector<const SCEV *, 4> Operands;
3490            for (const SCEV *Op : AR->operands())
3491 Operands.push_back(getUDivExpr(Op, RHS));
3492 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3493 }
3494 /// Get a canonical UDivExpr for a recurrence.
3495 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
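          // Illustrative example (hypothetical values): for {5,+,4}<L> /u 4,
          // X%N = 5 urem 4 = 1, so the start is rewritten to 4 and the
          // division becomes {4,+,4}<L> /u 4, which the rule above can then
          // fold to {1,+,1}<L>.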
3496 const APInt *StartRem;
3497 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3498 m_scev_APInt(StartRem))) {
3499 bool NoWrap =
3500 getZeroExtendExpr(AR, ExtTy) ==
3501 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3502 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3503                                    SCEV::FlagAnyWrap);
3504
3505 // With N <= C and both N, C as powers-of-2, the transformation
3506 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3507 // if wrapping occurs, as the division results remain equivalent for
3508              // all offsets in [(X - X%N), X).
3509 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3510 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3511 // Only fold if the subtraction can be folded in the start
3512 // expression.
3513 const SCEV *NewStart =
3514 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3515 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3516 !isa<SCEVAddExpr>(NewStart)) {
3517 const SCEV *NewLHS =
3518 getAddRecExpr(NewStart, Step, AR->getLoop(),
3519 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3520 if (LHS != NewLHS) {
3521 LHS = NewLHS;
3522
3523 // Reset the ID to include the new LHS, and check if it is
3524 // already cached.
3525 ID.clear();
3526 ID.AddInteger(scUDivExpr);
3527 ID.AddPointer(LHS);
3528 ID.AddPointer(RHS);
3529 IP = nullptr;
3530 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3531 return S;
3532 }
3533 }
3534 }
3535 }
3536 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3537 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3538        SmallVector<const SCEV *, 4> Operands;
3539        for (const SCEV *Op : M->operands())
3540 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3541 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3542 // Find an operand that's safely divisible.
3543 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3544 const SCEV *Op = M->getOperand(i);
3545 const SCEV *Div = getUDivExpr(Op, RHSC);
3546 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3547 Operands = SmallVector<const SCEV *, 4>(M->operands());
3548 Operands[i] = Div;
3549 return getMulExpr(Operands);
3550 }
3551 }
3552 }
3553
3554 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3555 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3556 if (auto *DivisorConstant =
3557 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3558 bool Overflow = false;
3559 APInt NewRHS =
3560 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3561 if (Overflow) {
3562 return getConstant(RHSC->getType(), 0, false);
3563 }
3564 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3565 }
3566 }
3567
3568 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3569 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3570        SmallVector<const SCEV *, 4> Operands;
3571        for (const SCEV *Op : A->operands())
3572 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3573 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3574 Operands.clear();
3575 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3576 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3577 if (isa<SCEVUDivExpr>(Op) ||
3578 getMulExpr(Op, RHS) != A->getOperand(i))
3579 break;
3580 Operands.push_back(Op);
3581 }
3582 if (Operands.size() == A->getNumOperands())
3583 return getAddExpr(Operands);
3584 }
3585 }
3586
3587 // Fold if both operands are constant.
3588 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3589 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3590 }
3591 }
3592
3593 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3594 const APInt *NegC, *C;
3595 if (match(LHS,
3596            m_scev_Add(m_scev_APInt(NegC),
3597                       m_scev_SMax(m_scev_APInt(C), m_scev_Specific(RHS)))) &&
3598      NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3599 return getZero(LHS->getType());
3600
3601 // TODO: Generalize to handle any common factors.
3602 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3603 const SCEV *NewLHS, *NewRHS;
3604 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3605 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3606 return getUDivExpr(NewLHS, NewRHS);
3607
3608 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3609 // changes). Make sure we get a new one.
3610 IP = nullptr;
3611 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3612 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3613 LHS, RHS);
3614 UniqueSCEVs.InsertNode(S, IP);
3615 registerUser(S, {LHS, RHS});
3616 return S;
3617}
3618
3619APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3620 APInt A = C1->getAPInt().abs();
3621 APInt B = C2->getAPInt().abs();
3622 uint32_t ABW = A.getBitWidth();
3623 uint32_t BBW = B.getBitWidth();
3624
3625 if (ABW > BBW)
3626 B = B.zext(ABW);
3627 else if (ABW < BBW)
3628 A = A.zext(BBW);
3629
3630 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3631}
3632
3633/// Get a canonical unsigned division expression, or something simpler if
3634/// possible. There is no representation for an exact udiv in SCEV IR, but we
3635/// can attempt to remove factors from the LHS and RHS. We can't do this when
3636/// it's not exact because the udiv may be clearing bits.
3637const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3638                                              const SCEV *RHS) {
3639 // TODO: we could try to find factors in all sorts of things, but for now we
3640 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3641 // end of this file for inspiration.
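  // Illustrative example (hypothetical operands): for LHS = (8 * %x)<nuw> and
  // RHS = 4, the gcd-based path below strips the common factor and the result
  // simplifies to (2 * %x); a plain udiv would not attempt this factoring.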
3642
3643  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3644  if (!Mul || !Mul->hasNoUnsignedWrap())
3645 return getUDivExpr(LHS, RHS);
3646
3647 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3648 // If the mulexpr multiplies by a constant, then that constant must be the
3649 // first element of the mulexpr.
3650 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3651 if (LHSCst == RHSCst) {
3652 SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3653 return getMulExpr(Operands);
3654 }
3655
3656 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3657 // that there's a factor provided by one of the other terms. We need to
3658 // check.
3659 APInt Factor = gcd(LHSCst, RHSCst);
3660 if (!Factor.isIntN(1)) {
3661 LHSCst =
3662 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3663 RHSCst =
3664 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3665        SmallVector<const SCEV *, 2> Operands;
3666        Operands.push_back(LHSCst);
3667 append_range(Operands, Mul->operands().drop_front());
3668 LHS = getMulExpr(Operands);
3669 RHS = RHSCst;
3670        Mul = dyn_cast<SCEVMulExpr>(LHS);
3671        if (!Mul)
3672 return getUDivExactExpr(LHS, RHS);
3673 }
3674 }
3675 }
3676
3677 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3678 if (Mul->getOperand(i) == RHS) {
3679      SmallVector<const SCEV *, 2> Operands;
3680      append_range(Operands, Mul->operands().take_front(i));
3681 append_range(Operands, Mul->operands().drop_front(i + 1));
3682 return getMulExpr(Operands);
3683 }
3684 }
3685
3686 return getUDivExpr(LHS, RHS);
3687}
3688
3689/// Get an add recurrence expression for the specified loop. Simplify the
3690/// expression as much as possible.
3691const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3692 const Loop *L,
3693 SCEV::NoWrapFlags Flags) {
3694  SmallVector<const SCEV *, 4> Operands;
3695  Operands.push_back(Start);
3696 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3697 if (StepChrec->getLoop() == L) {
3698 append_range(Operands, StepChrec->operands());
3699 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3700 }
3701
3702 Operands.push_back(Step);
3703 return getAddRecExpr(Operands, L, Flags);
3704}
3705
3706/// Get an add recurrence expression for the specified loop. Simplify the
3707/// expression as much as possible.
3708const SCEV *
3709ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3710                               const Loop *L, SCEV::NoWrapFlags Flags) {
3711 if (Operands.size() == 1) return Operands[0];
3712#ifndef NDEBUG
3713 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3714 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3715 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3716 "SCEVAddRecExpr operand types don't match!");
3717 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3718 }
3719 for (const SCEV *Op : Operands)
3721 "SCEVAddRecExpr operand is not available at loop entry!");
3722#endif
3723
3724 if (Operands.back()->isZero()) {
3725 Operands.pop_back();
3726 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3727 }
3728
3729  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3730 // use that information to infer NUW and NSW flags. However, computing a
3731 // BE count requires calling getAddRecExpr, so we may not yet have a
3732 // meaningful BE count at this point (and if we don't, we'd be stuck
3733 // with a SCEVCouldNotCompute as the cached BE count).
3734
3735 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3736
3737  // Canonicalize nested AddRecs by nesting them in order of loop depth.
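  // Illustrative example (hypothetical loops, with L2 nested inside L1):
  // {{X,+,Y}<L2>,+,Z}<L1> is rewritten below to {{X,+,Z}<L1>,+,Y}<L2>, so the
  // recurrence for the deeper loop becomes the outermost expression.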
3738 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3739 const Loop *NestedLoop = NestedAR->getLoop();
3740 if (L->contains(NestedLoop)
3741 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3742 : (!NestedLoop->contains(L) &&
3743 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3744 SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3745 Operands[0] = NestedAR->getStart();
3746 // AddRecs require their operands be loop-invariant with respect to their
3747 // loops. Don't perform this transformation if it would break this
3748 // requirement.
3749 bool AllInvariant = all_of(
3750 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3751
3752 if (AllInvariant) {
3753 // Create a recurrence for the outer loop with the same step size.
3754 //
3755 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3756 // inner recurrence has the same property.
3757 SCEV::NoWrapFlags OuterFlags =
3758 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3759
3760 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3761 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3762 return isLoopInvariant(Op, NestedLoop);
3763 });
3764
3765 if (AllInvariant) {
3766 // Ok, both add recurrences are valid after the transformation.
3767 //
3768 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3769 // the outer recurrence has the same property.
3770 SCEV::NoWrapFlags InnerFlags =
3771 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3772 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3773 }
3774 }
3775 // Reset Operands to its original state.
3776 Operands[0] = NestedAR;
3777 }
3778 }
3779
3780 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3781 // already have one, otherwise create a new one.
3782 return getOrCreateAddRecExpr(Operands, L, Flags);
3783}
3784
3785const SCEV *ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3786                                        ArrayRef<const SCEV *> IndexExprs) {
3787 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3788 // getSCEV(Base)->getType() has the same address space as Base->getType()
3789 // because SCEV::getType() preserves the address space.
3790 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3791 if (NW != GEPNoWrapFlags::none()) {
3792 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3793 // but to do that, we have to ensure that said flag is valid in the entire
3794 // defined scope of the SCEV.
3795 // TODO: non-instructions have global scope. We might be able to prove
3796 // some global scope cases
3797 auto *GEPI = dyn_cast<Instruction>(GEP);
3798 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3799 NW = GEPNoWrapFlags::none();
3800 }
3801
3802 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3803}
3804
3805const SCEV *ScalarEvolution::getGEPExpr(const SCEV *BaseExpr,
3806                                        ArrayRef<const SCEV *> IndexExprs,
3807 Type *SrcElementTy, GEPNoWrapFlags NW) {
3808  SCEV::NoWrapFlags OffsetWrap = SCEV::FlagAnyWrap;
3809  if (NW.hasNoUnsignedSignedWrap())
3810 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3811 if (NW.hasNoUnsignedWrap())
3812 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3813
3814 Type *CurTy = BaseExpr->getType();
3815 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3816 bool FirstIter = true;
3817  SmallVector<const SCEV *, 4> Offsets;
3818  for (const SCEV *IndexExpr : IndexExprs) {
3819 // Compute the (potentially symbolic) offset in bytes for this index.
3820 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3821 // For a struct, add the member offset.
3822 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3823 unsigned FieldNo = Index->getZExtValue();
3824 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3825 Offsets.push_back(FieldOffset);
3826
3827 // Update CurTy to the type of the field at Index.
3828 CurTy = STy->getTypeAtIndex(Index);
3829 } else {
3830 // Update CurTy to its element type.
3831 if (FirstIter) {
3832 assert(isa<PointerType>(CurTy) &&
3833 "The first index of a GEP indexes a pointer");
3834 CurTy = SrcElementTy;
3835 FirstIter = false;
3836 } else {
3837        CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3838      }
3839 // For an array, add the element offset, explicitly scaled.
3840 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3841 // Getelementptr indices are signed.
3842 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3843
3844 // Multiply the index by the element size to compute the element offset.
3845 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3846 Offsets.push_back(LocalOffset);
3847 }
3848 }
3849
3850 // Handle degenerate case of GEP without offsets.
3851 if (Offsets.empty())
3852 return BaseExpr;
3853
3854 // Add the offsets together, assuming nsw if inbounds.
3855 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3856 // Add the base address and the offset. We cannot use the nsw flag, as the
3857 // base address is unsigned. However, if we know that the offset is
3858 // non-negative, we can use nuw.
3859 bool NUW = NW.hasNoUnsignedWrap() ||
3860             (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Offset));
3861  SCEV::NoWrapFlags BaseWrap = NUW ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3862  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3863 assert(BaseExpr->getType() == GEPExpr->getType() &&
3864 "GEP should not change type mid-flight.");
3865 return GEPExpr;
3866}
3867
3868SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3869                                               ArrayRef<const SCEV *> Ops) {
3870  FoldingSetNodeID ID;
3871  ID.AddInteger(SCEVType);
3872 for (const SCEV *Op : Ops)
3873 ID.AddPointer(Op);
3874 void *IP = nullptr;
3875 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3876}
3877
3878const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3879  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3880  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3881}
3882
3883const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3884                                           SmallVectorImpl<const SCEV *> &Ops) {
3885  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3886 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3887 if (Ops.size() == 1) return Ops[0];
3888#ifndef NDEBUG
3889 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3890 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3891 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3892 "Operand types don't match!");
3893 assert(Ops[0]->getType()->isPointerTy() ==
3894 Ops[i]->getType()->isPointerTy() &&
3895 "min/max should be consistently pointerish");
3896 }
3897#endif
3898
3899 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3900 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3901
3902 const SCEV *Folded = constantFoldAndGroupOps(
3903 *this, LI, DT, Ops,
3904 [&](const APInt &C1, const APInt &C2) {
3905 switch (Kind) {
3906 case scSMaxExpr:
3907 return APIntOps::smax(C1, C2);
3908 case scSMinExpr:
3909 return APIntOps::smin(C1, C2);
3910 case scUMaxExpr:
3911 return APIntOps::umax(C1, C2);
3912 case scUMinExpr:
3913 return APIntOps::umin(C1, C2);
3914 default:
3915 llvm_unreachable("Unknown SCEV min/max opcode");
3916 }
3917 },
3918 [&](const APInt &C) {
3919 // identity
3920 if (IsMax)
3921 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3922 else
3923 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3924 },
3925 [&](const APInt &C) {
3926 // absorber
3927 if (IsMax)
3928 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
3929 else
3930 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
3931 });
3932 if (Folded)
3933 return Folded;
3934
3935 // Check if we have created the same expression before.
3936 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3937 return S;
3938 }
3939
3940 // Find the first operation of the same kind
3941 unsigned Idx = 0;
3942 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3943 ++Idx;
3944
3945 // Check to see if one of the operands is of the same kind. If so, expand its
3946 // operands onto our operand list, and recurse to simplify.
3947 if (Idx < Ops.size()) {
3948 bool DeletedAny = false;
3949 while (Ops[Idx]->getSCEVType() == Kind) {
3950 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3951 Ops.erase(Ops.begin()+Idx);
3952 append_range(Ops, SMME->operands());
3953 DeletedAny = true;
3954 }
3955
3956 if (DeletedAny)
3957 return getMinMaxExpr(Kind, Ops);
3958 }
3959
3960 // Okay, check to see if the same value occurs in the operand list twice. If
3961 // so, delete one. Since we sorted the list, these values are required to
3962 // be adjacent.
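  // Illustrative example (hypothetical operands): smax(%a, %b, %b) reduces to
  // smax(%a, %b), and umax(%a, %b) reduces to %a when %a uge %b is already
  // known.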
3963  llvm::CmpInst::Predicate GEPred =
3964      IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3965  llvm::CmpInst::Predicate LEPred =
3966      IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3967  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3968 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3969 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3970 if (Ops[i] == Ops[i + 1] ||
3971 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3972 // X op Y op Y --> X op Y
3973 // X op Y --> X, if we know X, Y are ordered appropriately
3974 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3975 --i;
3976 --e;
3977 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3978 Ops[i + 1])) {
3979 // X op Y --> Y, if we know X, Y are ordered appropriately
3980 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3981 --i;
3982 --e;
3983 }
3984 }
3985
3986 if (Ops.size() == 1) return Ops[0];
3987
3988 assert(!Ops.empty() && "Reduced smax down to nothing!");
3989
3990 // Okay, it looks like we really DO need an expr. Check to see if we
3991 // already have one, otherwise create a new one.
3992  FoldingSetNodeID ID;
3993  ID.AddInteger(Kind);
3994 for (const SCEV *Op : Ops)
3995 ID.AddPointer(Op);
3996 void *IP = nullptr;
3997 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3998 if (ExistingSCEV)
3999 return ExistingSCEV;
4000 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4001  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4002  SCEV *S = new (SCEVAllocator)
4003 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4004
4005 UniqueSCEVs.InsertNode(S, IP);
4006 registerUser(S, Ops);
4007 return S;
4008}
4009
4010namespace {
4011
4012class SCEVSequentialMinMaxDeduplicatingVisitor final
4013 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4014 std::optional<const SCEV *>> {
4015 using RetVal = std::optional<const SCEV *>;
4016  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
4017
4018 ScalarEvolution &SE;
4019 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4020 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4021  SmallPtrSet<const SCEV *, 16> SeenOps;
4022
4023 bool canRecurseInto(SCEVTypes Kind) const {
4024 // We can only recurse into the SCEV expression of the same effective type
4025 // as the type of our root SCEV expression.
4026 return RootKind == Kind || NonSequentialRootKind == Kind;
4027 };
4028
4029 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4031 "Only for min/max expressions.");
4032 SCEVTypes Kind = S->getSCEVType();
4033
4034 if (!canRecurseInto(Kind))
4035 return S;
4036
4037 auto *NAry = cast<SCEVNAryExpr>(S);
4038    SmallVector<const SCEV *> NewOps;
4039    bool Changed = visit(Kind, NAry->operands(), NewOps);
4040
4041 if (!Changed)
4042 return S;
4043 if (NewOps.empty())
4044 return std::nullopt;
4045
4046    return isa<SCEVSequentialMinMaxExpr>(S)
4047               ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4048 : SE.getMinMaxExpr(Kind, NewOps);
4049 }
4050
4051 RetVal visit(const SCEV *S) {
4052 // Has the whole operand been seen already?
4053 if (!SeenOps.insert(S).second)
4054 return std::nullopt;
4055 return Base::visit(S);
4056 }
4057
4058public:
4059 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4060 SCEVTypes RootKind)
4061 : SE(SE), RootKind(RootKind),
4062 NonSequentialRootKind(
4063 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4064 RootKind)) {}
4065
4066 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
4067 SmallVectorImpl<const SCEV *> &NewOps) {
4068 bool Changed = false;
4069    SmallVector<const SCEV *> Ops;
4070    Ops.reserve(OrigOps.size());
4071
4072 for (const SCEV *Op : OrigOps) {
4073 RetVal NewOp = visit(Op);
4074 if (NewOp != Op)
4075 Changed = true;
4076 if (NewOp)
4077 Ops.emplace_back(*NewOp);
4078 }
4079
4080 if (Changed)
4081 NewOps = std::move(Ops);
4082 return Changed;
4083 }
4084
4085 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4086
4087 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4088
4089 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4090
4091 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4092
4093 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4094
4095 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4096
4097 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4098
4099 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4100
4101 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4102
4103 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4104
4105 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4106 return visitAnyMinMaxExpr(Expr);
4107 }
4108
4109 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4110 return visitAnyMinMaxExpr(Expr);
4111 }
4112
4113 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4114 return visitAnyMinMaxExpr(Expr);
4115 }
4116
4117 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4118 return visitAnyMinMaxExpr(Expr);
4119 }
4120
4121 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4122 return visitAnyMinMaxExpr(Expr);
4123 }
4124
4125 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4126
4127 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4128};
4129
4130} // namespace
4131
4132static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind) {
4133  switch (Kind) {
4134 case scConstant:
4135 case scVScale:
4136 case scTruncate:
4137 case scZeroExtend:
4138 case scSignExtend:
4139 case scPtrToInt:
4140 case scAddExpr:
4141 case scMulExpr:
4142 case scUDivExpr:
4143 case scAddRecExpr:
4144 case scUMaxExpr:
4145 case scSMaxExpr:
4146 case scUMinExpr:
4147 case scSMinExpr:
4148 case scUnknown:
4149 // If any operand is poison, the whole expression is poison.
4150 return true;
4151  case scSequentialUMinExpr:
4152    // FIXME: if the *first* operand is poison, the whole expression is poison.
4153 return false; // Pessimistically, say that it does not propagate poison.
4154 case scCouldNotCompute:
4155 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4156 }
4157 llvm_unreachable("Unknown SCEV kind!");
4158}
4159
4160namespace {
4161// The only way poison may be introduced in a SCEV expression is from a
4162// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4163// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4164// introduce poison -- they encode guaranteed, non-speculated knowledge.
4165//
4166// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4167// with the notable exception of umin_seq, where only poison from the first
4168// operand is (unconditionally) propagated.
4169struct SCEVPoisonCollector {
4170 bool LookThroughMaybePoisonBlocking;
4171 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4172 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4173 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4174
4175 bool follow(const SCEV *S) {
4176 if (!LookThroughMaybePoisonBlocking &&
4177        !scevUnconditionallyPropagatesPoisonFromOperands(S->getSCEVType()))
4178      return false;
4179
4180 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4181 if (!isGuaranteedNotToBePoison(SU->getValue()))
4182 MaybePoison.insert(SU);
4183 }
4184 return true;
4185 }
4186 bool isDone() const { return false; }
4187};
4188} // namespace
4189
4190/// Return true if V is poison given that AssumedPoison is already poison.
4191static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4192 // First collect all SCEVs that might result in AssumedPoison to be poison.
4193 // We need to look through potentially poison-blocking operations here,
4194 // because we want to find all SCEVs that *might* result in poison, not only
4195 // those that are *required* to.
4196 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4197 visitAll(AssumedPoison, PC1);
4198
4199 // AssumedPoison is never poison. As the assumption is false, the implication
4200 // is true. Don't bother walking the other SCEV in this case.
4201 if (PC1.MaybePoison.empty())
4202 return true;
4203
4204 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4205 // as well. We cannot look through potentially poison-blocking operations
4206 // here, as their arguments only *may* make the result poison.
4207 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4208 visitAll(S, PC2);
4209
4210 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4211 // it will also make S poison by being part of PC2.MaybePoison.
4212 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4213}
4214
4215void ScalarEvolution::getPoisonGeneratingValues(
4216    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4217 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4218 visitAll(S, PC);
4219 for (const SCEVUnknown *SU : PC.MaybePoison)
4220 Result.insert(SU->getValue());
4221}
4222
4223bool ScalarEvolution::canReuseInstruction(
4224    const SCEV *S, Instruction *I,
4225 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4226 // If the instruction cannot be poison, it's always safe to reuse.
4227  if (isGuaranteedNotToBePoison(I))
4228    return true;
4229
4230  // Otherwise, it is possible that I is more poisonous than S. Collect the
4231 // poison-contributors of S, and then check whether I has any additional
4232 // poison-contributors. Poison that is contributed through poison-generating
4233 // flags is handled by dropping those flags instead.
4234  SmallPtrSet<const Value *, 8> PoisonVals;
4235  getPoisonGeneratingValues(PoisonVals, S);
4236
4237 SmallVector<Value *> Worklist;
4238  SmallPtrSet<Value *, 8> Visited;
4239  Worklist.push_back(I);
4240 while (!Worklist.empty()) {
4241 Value *V = Worklist.pop_back_val();
4242 if (!Visited.insert(V).second)
4243 continue;
4244
4245 // Avoid walking large instruction graphs.
4246 if (Visited.size() > 16)
4247 return false;
4248
4249 // Either the value can't be poison, or the S would also be poison if it
4250 // is.
4251 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4252 continue;
4253
4254 auto *I = dyn_cast<Instruction>(V);
4255 if (!I)
4256 return false;
4257
4258 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4259 // can't replace an arbitrary add with disjoint or, even if we drop the
4260 // flag. We would need to convert the or into an add.
4261 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4262 if (PDI->isDisjoint())
4263 return false;
4264
4265 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4266 // because SCEV currently assumes it can't be poison. Remove this special
4267    // case once we properly model when vscale can be poison.
4268 if (auto *II = dyn_cast<IntrinsicInst>(I);
4269 II && II->getIntrinsicID() == Intrinsic::vscale)
4270 continue;
4271
4272 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4273 return false;
4274
4275 // If the instruction can't create poison, we can recurse to its operands.
4276 if (I->hasPoisonGeneratingAnnotations())
4277 DropPoisonGeneratingInsts.push_back(I);
4278
4279 llvm::append_range(Worklist, I->operands());
4280 }
4281 return true;
4282}
4283
4284const SCEV *
4285ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4286                                         SmallVectorImpl<const SCEV *> &Ops) {
4287  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4288 "Not a SCEVSequentialMinMaxExpr!");
4289 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4290 if (Ops.size() == 1)
4291 return Ops[0];
4292#ifndef NDEBUG
4293 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4294 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4295 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4296 "Operand types don't match!");
4297 assert(Ops[0]->getType()->isPointerTy() ==
4298 Ops[i]->getType()->isPointerTy() &&
4299 "min/max should be consistently pointerish");
4300 }
4301#endif
4302
4303 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4304 // so we can *NOT* do any kind of sorting of the expressions!
4305
4306 // Check if we have created the same expression before.
4307 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4308 return S;
4309
4310 // FIXME: there are *some* simplifications that we can do here.
4311
4312 // Keep only the first instance of an operand.
4313 {
4314 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4315 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4316 if (Changed)
4317 return getSequentialMinMaxExpr(Kind, Ops);
4318 }
4319
4320 // Check to see if one of the operands is of the same kind. If so, expand its
4321 // operands onto our operand list, and recurse to simplify.
4322 {
4323 unsigned Idx = 0;
4324 bool DeletedAny = false;
4325 while (Idx < Ops.size()) {
4326 if (Ops[Idx]->getSCEVType() != Kind) {
4327 ++Idx;
4328 continue;
4329 }
4330 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4331 Ops.erase(Ops.begin() + Idx);
4332 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4333 SMME->operands().end());
4334 DeletedAny = true;
4335 }
4336
4337 if (DeletedAny)
4338 return getSequentialMinMaxExpr(Kind, Ops);
4339 }
4340
4341 const SCEV *SaturationPoint;
4342  ICmpInst::Predicate Pred;
4343  switch (Kind) {
4344  case scSequentialUMinExpr:
4345    SaturationPoint = getZero(Ops[0]->getType());
4346 Pred = ICmpInst::ICMP_ULE;
4347 break;
4348 default:
4349 llvm_unreachable("Not a sequential min/max type.");
4350 }
4351
4352 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4353 if (!isGuaranteedNotToCauseUB(Ops[i]))
4354 continue;
4355 // We can replace %x umin_seq %y with %x umin %y if either:
4356 // * %y being poison implies %x is also poison.
4357 // * %x cannot be the saturating value (e.g. zero for umin).
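    // Illustrative example (hypothetical operands): in (%a umin_seq %b), if %a
    // is known to be non-zero, it can never hit the saturating value, so the
    // pair is safely rewritten to (%a umin %b) below.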
4358 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4359 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4360 SaturationPoint)) {
4361 SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4362 Ops[i - 1] = getMinMaxExpr(
4363          SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4364          SeqOps);
4365 Ops.erase(Ops.begin() + i);
4366 return getSequentialMinMaxExpr(Kind, Ops);
4367 }
4368 // Fold %x umin_seq %y to %x if %x ule %y.
4369 // TODO: We might be able to prove the predicate for a later operand.
4370 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4371 Ops.erase(Ops.begin() + i);
4372 return getSequentialMinMaxExpr(Kind, Ops);
4373 }
4374 }
4375
4376 // Okay, it looks like we really DO need an expr. Check to see if we
4377 // already have one, otherwise create a new one.
4378  FoldingSetNodeID ID;
4379  ID.AddInteger(Kind);
4380 for (const SCEV *Op : Ops)
4381 ID.AddPointer(Op);
4382 void *IP = nullptr;
4383 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4384 if (ExistingSCEV)
4385 return ExistingSCEV;
4386
4387 const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4388  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4389  SCEV *S = new (SCEVAllocator)
4390 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4391
4392 UniqueSCEVs.InsertNode(S, IP);
4393 registerUser(S, Ops);
4394 return S;
4395}
4396
4397const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4398 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4399 return getSMaxExpr(Ops);
4400}
4401
4402const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4403  return getMinMaxExpr(scSMaxExpr, Ops);
4404}
4405
4406const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4407 SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4408 return getUMaxExpr(Ops);
4409}
4410
4411const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4412  return getMinMaxExpr(scUMaxExpr, Ops);
4413}
4414
4415const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4416                                         const SCEV *RHS) {
4417 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4418 return getSMinExpr(Ops);
4419}
4420
4421const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4422  return getMinMaxExpr(scSMinExpr, Ops);
4423}
4424
4425const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4426 bool Sequential) {
4427 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4428 return getUMinExpr(Ops, Sequential);
4429}
4430
4431const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4432                                         bool Sequential) {
4433  return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4434                    : getMinMaxExpr(scUMinExpr, Ops);
4435}
4436
4437const SCEV *
4438ScalarEvolution::getSizeOfExpr(Type *IntTy, TypeSize Size) {
4439  const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4440 if (Size.isScalable())
4441 Res = getMulExpr(Res, getVScale(IntTy));
4442 return Res;
4443}
4444
4445const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4446  return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4447}
4448
4449const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4450  return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4451}
4452
4453const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4454                                             StructType *STy,
4455 unsigned FieldNo) {
4456 // We can bypass creating a target-independent constant expression and then
4457 // folding it back into a ConstantInt. This is just a compile-time
4458 // optimization.
4459 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4460 assert(!SL->getSizeInBits().isScalable() &&
4461 "Cannot get offset for structure containing scalable vector types");
4462 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4463}
4464
4465const SCEV *ScalarEvolution::getUnknown(Value *V) {
4466  // Don't attempt to do anything other than create a SCEVUnknown object
4467 // here. createSCEV only calls getUnknown after checking for all other
4468 // interesting possibilities, and any other code that calls getUnknown
4469 // is doing so in order to hide a value from SCEV canonicalization.
4470
4471  FoldingSetNodeID ID;
4472  ID.AddInteger(scUnknown);
4473 ID.AddPointer(V);
4474 void *IP = nullptr;
4475 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4476 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4477 "Stale SCEVUnknown in uniquing map!");
4478 return S;
4479 }
4480 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4481 FirstUnknown);
4482 FirstUnknown = cast<SCEVUnknown>(S);
4483 UniqueSCEVs.InsertNode(S, IP);
4484 return S;
4485}
4486
4487//===----------------------------------------------------------------------===//
4488// Basic SCEV Analysis and PHI Idiom Recognition Code
4489//
4490
4491/// Test if values of the given type are analyzable within the SCEV
4492/// framework. This primarily includes integer types, and it can optionally
4493/// include pointer types if the ScalarEvolution class has access to
4494/// target-specific information.
4495bool ScalarEvolution::isSCEVable(Type *Ty) const {
4496  // Integers and pointers are always SCEVable.
4497 return Ty->isIntOrPtrTy();
4498}
4499
4500/// Return the size in bits of the specified type, for which isSCEVable must
4501/// return true.
4502uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4503  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4504 if (Ty->isPointerTy())
4505    return getDataLayout().getIndexTypeSizeInBits(Ty);
4506  return getDataLayout().getTypeSizeInBits(Ty);
4507}
4508
4509/// Return a type with the same bitwidth as the given type and which represents
4510/// how SCEV will treat the given type, for which isSCEVable must return
4511/// true. For pointer types, this is the pointer index sized integer type.
4512Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4513  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4514
4515 if (Ty->isIntegerTy())
4516 return Ty;
4517
4518  // The only other supported type is pointer.
4519 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4520 return getDataLayout().getIndexType(Ty);
4521}
4522
4523Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4524  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4525}
4526
4527bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4528                                                        const SCEV *B) {
4529 /// For a valid use point to exist, the defining scope of one operand
4530 /// must dominate the other.
4531 bool PreciseA, PreciseB;
4532 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4533 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4534 if (!PreciseA || !PreciseB)
4535 // Can't tell.
4536 return false;
4537 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4538 DT.dominates(ScopeB, ScopeA);
4539}
4540
4540
4541const SCEV *ScalarEvolution::getCouldNotCompute() {
4542  return CouldNotCompute.get();
4543}
4544
4545bool ScalarEvolution::checkValidity(const SCEV *S) const {
4546 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4547 auto *SU = dyn_cast<SCEVUnknown>(S);
4548 return SU && SU->getValue() == nullptr;
4549 });
4550
4551 return !ContainsNulls;
4552}
4553
4554bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4555  HasRecMapType::iterator I = HasRecMap.find(S);
4556 if (I != HasRecMap.end())
4557 return I->second;
4558
4559 bool FoundAddRec =
4560 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4561 HasRecMap.insert({S, FoundAddRec});
4562 return FoundAddRec;
4563}
4564
4565/// Return the ValueOffsetPair set for \p S. \p S can be represented
4566/// by the value and offset from any ValueOffsetPair in the set.
4567ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4568 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4569 if (SI == ExprValueMap.end())
4570 return {};
4571 return SI->second.getArrayRef();
4572}
4573
4574/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4575/// cannot be used separately. eraseValueFromMap should be used to remove
4576/// V from ValueExprMap and ExprValueMap at the same time.
4577void ScalarEvolution::eraseValueFromMap(Value *V) {
4578 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4579 if (I != ValueExprMap.end()) {
4580 auto EVIt = ExprValueMap.find(I->second);
4581 bool Removed = EVIt->second.remove(V);
4582 (void) Removed;
4583 assert(Removed && "Value not in ExprValueMap?");
4584 ValueExprMap.erase(I);
4585 }
4586}
4587
4588void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4589 // A recursive query may have already computed the SCEV. It should be
4590 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4591 // inferred nowrap flags.
4592 auto It = ValueExprMap.find_as(V);
4593 if (It == ValueExprMap.end()) {
4594 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4595 ExprValueMap[S].insert(V);
4596 }
4597}
4598
4599/// Return an existing SCEV if it exists, otherwise analyze the expression and
4600/// create a new one.
4601const SCEV *ScalarEvolution::getSCEV(Value *V) {
4602  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4603
4604 if (const SCEV *S = getExistingSCEV(V))
4605 return S;
4606 return createSCEVIter(V);
4607}
4608
4609const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4610  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4611
4612 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4613 if (I != ValueExprMap.end()) {
4614 const SCEV *S = I->second;
4615 assert(checkValidity(S) &&
4616 "existing SCEV has not been properly invalidated");
4617 return S;
4618 }
4619 return nullptr;
4620}
4621
4622/// Return a SCEV corresponding to -V = -1*V
4623const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4624                                             SCEV::NoWrapFlags Flags) {
4625 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4626 return getConstant(
4627 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4628
4629 Type *Ty = V->getType();
4630 Ty = getEffectiveSCEVType(Ty);
4631 return getMulExpr(V, getMinusOne(Ty), Flags);
4632}
4633
4634/// If Expr computes ~A, return A else return nullptr
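/// For example (illustrative), the SCEV (-1 + (-1 * %a)) computes ~%a, so the
/// matcher below returns %a for it.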
4635static const SCEV *MatchNotExpr(const SCEV *Expr) {
4636 const SCEV *MulOp;
4637 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4638 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4639 return MulOp;
4640 return nullptr;
4641}
4642
4643/// Return a SCEV corresponding to ~V = -1-V
4644const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4645  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4646
4647 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4648 return getConstant(
4649 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4650
4651 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4652 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4653 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4654 SmallVector<const SCEV *, 2> MatchedOperands;
4655 for (const SCEV *Operand : MME->operands()) {
4656 const SCEV *Matched = MatchNotExpr(Operand);
4657 if (!Matched)
4658 return (const SCEV *)nullptr;
4659 MatchedOperands.push_back(Matched);
4660 }
4661 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4662 MatchedOperands);
4663 };
4664 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4665 return Replaced;
4666 }
4667
4668 Type *Ty = V->getType();
4669 Ty = getEffectiveSCEVType(Ty);
4670 return getMinusSCEV(getMinusOne(Ty), V);
4671}
4672
4673const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4674  assert(P->getType()->isPointerTy());
4675
4676 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4677 // The base of an AddRec is the first operand.
4678 SmallVector<const SCEV *> Ops{AddRec->operands()};
4679 Ops[0] = removePointerBase(Ops[0]);
4680 // Don't try to transfer nowrap flags for now. We could in some cases
4681 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4682 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4683 }
4684 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4685 // The base of an Add is the pointer operand.
4686 SmallVector<const SCEV *> Ops{Add->operands()};
4687 const SCEV **PtrOp = nullptr;
4688 for (const SCEV *&AddOp : Ops) {
4689 if (AddOp->getType()->isPointerTy()) {
4690 assert(!PtrOp && "Cannot have multiple pointer ops");
4691 PtrOp = &AddOp;
4692 }
4693 }
4694 *PtrOp = removePointerBase(*PtrOp);
4695 // Don't try to transfer nowrap flags for now. We could in some cases
4696 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4697 return getAddExpr(Ops);
4698 }
4699 // Any other expression must be a pointer base.
4700 return getZero(P->getType());
4701}
4702
4703const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4704 SCEV::NoWrapFlags Flags,
4705 unsigned Depth) {
4706 // Fast path: X - X --> 0.
4707 if (LHS == RHS)
4708 return getZero(LHS->getType());
4709
4710 // If we subtract two pointers with different pointer bases, bail.
4711 // Eventually, we're going to add an assertion to getMulExpr that we
4712 // can't multiply by a pointer.
4713 if (RHS->getType()->isPointerTy()) {
4714 if (!LHS->getType()->isPointerTy() ||
4715 getPointerBase(LHS) != getPointerBase(RHS))
4716 return getCouldNotCompute();
4717 LHS = removePointerBase(LHS);
4718 RHS = removePointerBase(RHS);
4719 }
4720
4721 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4722 // makes it so that we cannot make much use of NUW.
4723 auto AddFlags = SCEV::FlagAnyWrap;
4724 const bool RHSIsNotMinSigned =
4726 if (hasFlags(Flags, SCEV::FlagNSW)) {
4727 // Let M be the minimum representable signed value. Then (-1)*RHS
4728 // signed-wraps if and only if RHS is M. That can happen even for
4729 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4730 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4731 // (-1)*RHS, we need to prove that RHS != M.
4732 //
4733 // If LHS is non-negative and we know that LHS - RHS does not
4734 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4735 // either by proving that RHS > M or that LHS >= 0.
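    // Illustrative example (assuming i8 values): if RHS were M = -128, then
    // (-1) * RHS wraps back to -128, so NSW on LHS - RHS would not carry over
    // to LHS + (-1)*RHS unless RHS != M can be established.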
4736 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4737 AddFlags = SCEV::FlagNSW;
4738 }
4739 }
4740
4741 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4742 // RHS is NSW and LHS >= 0.
4743 //
4744 // The difficulty here is that the NSW flag may have been proven
4745 // relative to a loop that is to be found in a recurrence in LHS and
4746 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4747 // larger scope than intended.
4748 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4749
4750 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4751}
4752
4753const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4754                                                     unsigned Depth) {
4755 Type *SrcTy = V->getType();
4756 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4757 "Cannot truncate or zero extend with non-integer arguments!");
4758 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4759 return V; // No conversion
4760 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4761 return getTruncateExpr(V, Ty, Depth);
4762 return getZeroExtendExpr(V, Ty, Depth);
4763}
4764
4765const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4766                                                     unsigned Depth) {
4767 Type *SrcTy = V->getType();
4768 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4769 "Cannot truncate or zero extend with non-integer arguments!");
4770 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4771 return V; // No conversion
4772 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4773 return getTruncateExpr(V, Ty, Depth);
4774 return getSignExtendExpr(V, Ty, Depth);
4775}
4776
4777const SCEV *
4778ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4779  Type *SrcTy = V->getType();
4780 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4781 "Cannot noop or zero extend with non-integer arguments!");
4783 "getNoopOrZeroExtend cannot truncate!");
4784 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4785 return V; // No conversion
4786 return getZeroExtendExpr(V, Ty);
4787}
4788
4789const SCEV *
4790ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4791  Type *SrcTy = V->getType();
4792 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4793 "Cannot noop or sign extend with non-integer arguments!");
4795 "getNoopOrSignExtend cannot truncate!");
4796 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4797 return V; // No conversion
4798 return getSignExtendExpr(V, Ty);
4799}
4800
4801const SCEV *
4802ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4803  Type *SrcTy = V->getType();
4804 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4805 "Cannot noop or any extend with non-integer arguments!");
4807 "getNoopOrAnyExtend cannot truncate!");
4808 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4809 return V; // No conversion
4810 return getAnyExtendExpr(V, Ty);
4811}
4812
4813const SCEV *
4814ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4815  Type *SrcTy = V->getType();
4816 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4817 "Cannot truncate or noop with non-integer arguments!");
4819 "getTruncateOrNoop cannot extend!");
4820 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4821 return V; // No conversion
4822 return getTruncateExpr(V, Ty);
4823}
4824
4825const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4826 const SCEV *RHS) {
4827 const SCEV *PromotedLHS = LHS;
4828 const SCEV *PromotedRHS = RHS;
4829
4830 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4831 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4832 else
4833 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4834
4835 return getUMaxExpr(PromotedLHS, PromotedRHS);
4836}
4837
4838const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4839 const SCEV *RHS,
4840 bool Sequential) {
4841 SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4842 return getUMinFromMismatchedTypes(Ops, Sequential);
4843}
4844
4845const SCEV *
4846ScalarEvolution::getUMinFromMismatchedTypes(SmallVectorImpl<const SCEV *> &Ops,
4847 bool Sequential) {
4848 assert(!Ops.empty() && "At least one operand must be!");
4849 // Trivial case.
4850 if (Ops.size() == 1)
4851 return Ops[0];
4852
4853 // Find the max type first.
4854 Type *MaxType = nullptr;
4855 for (const auto *S : Ops)
4856 if (MaxType)
4857 MaxType = getWiderType(MaxType, S->getType());
4858 else
4859 MaxType = S->getType();
4860 assert(MaxType && "Failed to find maximum type!");
4861
4862 // Extend all ops to max type.
4863 SmallVector<const SCEV *, 2> PromotedOps;
4864 for (const auto *S : Ops)
4865 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4866
4867 // Generate umin.
4868 return getUMinExpr(PromotedOps, Sequential);
4869}
4870
4871const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
4872 // A pointer operand may evaluate to a nonpointer expression, such as null.
4873 if (!V->getType()->isPointerTy())
4874 return V;
4875
4876 while (true) {
4877 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4878 V = AddRec->getStart();
4879 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4880 const SCEV *PtrOp = nullptr;
4881 for (const SCEV *AddOp : Add->operands()) {
4882 if (AddOp->getType()->isPointerTy()) {
4883 assert(!PtrOp && "Cannot have multiple pointer ops");
4884 PtrOp = AddOp;
4885 }
4886 }
4887 assert(PtrOp && "Must have pointer op");
4888 V = PtrOp;
4889 } else // Not something we can look further into.
4890 return V;
4891 }
4892}
4893
4894/// Push users of the given Instruction onto the given Worklist.
4895static void PushDefUseChildren(Instruction *I,
4896 SmallVectorImpl<Instruction *> &Worklist,
4897 SmallPtrSetImpl<Instruction *> &Visited) {
4898 // Push the def-use children onto the Worklist stack.
4899 for (User *U : I->users()) {
4900 auto *UserInsn = cast<Instruction>(U);
4901 if (Visited.insert(UserInsn).second)
4902 Worklist.push_back(UserInsn);
4903 }
4904}
4905
4906namespace {
4907
4908/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4909/// use its start expression. If the AddRec's loop is not L, use the AddRec
4910/// itself when IgnoreOtherLoops is true; otherwise the rewrite cannot be done.
4911/// If the SCEV contains a SCEVUnknown that is not invariant in L, the rewrite
4912/// cannot be done.
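// Editorial example (not part of the original source): rewriting
// ({%a,+,%b}<%L> + %c) with respect to %L yields (%a + %c), provided %c is
// invariant in %L; with IgnoreOtherLoops == false, an AddRec belonging to a
// different loop makes the rewrite fail with SCEVCouldNotCompute.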
4913class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4914public:
4915 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4916 bool IgnoreOtherLoops = true) {
4917 SCEVInitRewriter Rewriter(L, SE);
4918 const SCEV *Result = Rewriter.visit(S);
4919 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4920 return SE.getCouldNotCompute();
4921 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4922 ? SE.getCouldNotCompute()
4923 : Result;
4924 }
4925
4926 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4927 if (!SE.isLoopInvariant(Expr, L))
4928 SeenLoopVariantSCEVUnknown = true;
4929 return Expr;
4930 }
4931
4932 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4933 // Only re-write AddRecExprs for this loop.
4934 if (Expr->getLoop() == L)
4935 return Expr->getStart();
4936 SeenOtherLoops = true;
4937 return Expr;
4938 }
4939
4940 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4941
4942 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4943
4944private:
4945 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4946 : SCEVRewriteVisitor(SE), L(L) {}
4947
4948 const Loop *L;
4949 bool SeenLoopVariantSCEVUnknown = false;
4950 bool SeenOtherLoops = false;
4951};
4952
4953/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
4954/// use its post-increment expression. If the AddRec's loop is not L, use the
4955/// AddRec itself. If the SCEV contains a SCEVUnknown that is not invariant in
4956/// L, the rewrite cannot be done.
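// Editorial example (not part of the original source): for loop %L,
// {%a,+,%b}<%L> is rewritten to its post-increment form {(%a + %b),+,%b}<%L>;
// AddRecs of other loops are left untouched.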
4957class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4958public:
4959 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4960 SCEVPostIncRewriter Rewriter(L, SE);
4961 const SCEV *Result = Rewriter.visit(S);
4962 return Rewriter.hasSeenLoopVariantSCEVUnknown()
4963 ? SE.getCouldNotCompute()
4964 : Result;
4965 }
4966
4967 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4968 if (!SE.isLoopInvariant(Expr, L))
4969 SeenLoopVariantSCEVUnknown = true;
4970 return Expr;
4971 }
4972
4973 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4974 // Only re-write AddRecExprs for this loop.
4975 if (Expr->getLoop() == L)
4976 return Expr->getPostIncExpr(SE);
4977 SeenOtherLoops = true;
4978 return Expr;
4979 }
4980
4981 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4982
4983 bool hasSeenOtherLoops() { return SeenOtherLoops; }
4984
4985private:
4986 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4987 : SCEVRewriteVisitor(SE), L(L) {}
4988
4989 const Loop *L;
4990 bool SeenLoopVariantSCEVUnknown = false;
4991 bool SeenOtherLoops = false;
4992};
4993
4994/// This class evaluates the compare condition by matching it against the
4995/// branch condition of the loop latch. If there is a match, we assume the
4996/// condition is true while building SCEV nodes.
4997class SCEVBackedgeConditionFolder
4998 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4999public:
5000 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5001 ScalarEvolution &SE) {
5002 bool IsPosBECond = false;
5003 Value *BECond = nullptr;
5004 if (BasicBlock *Latch = L->getLoopLatch()) {
5005 BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
5006 if (BI && BI->isConditional()) {
5007 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5008 "Both outgoing branches should not target same header!");
5009 BECond = BI->getCondition();
5010 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5011 } else {
5012 return S;
5013 }
5014 }
5015 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5016 return Rewriter.visit(S);
5017 }
5018
5019 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5020 const SCEV *Result = Expr;
5021 bool InvariantF = SE.isLoopInvariant(Expr, L);
5022
5023 if (!InvariantF) {
5024 Instruction *I = cast<Instruction>(Expr->getValue());
5025 switch (I->getOpcode()) {
5026 case Instruction::Select: {
5027 SelectInst *SI = cast<SelectInst>(I);
5028 std::optional<const SCEV *> Res =
5029 compareWithBackedgeCondition(SI->getCondition());
5030 if (Res) {
5031 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5032 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5033 }
5034 break;
5035 }
5036 default: {
5037 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5038 if (Res)
5039 Result = *Res;
5040 break;
5041 }
5042 }
5043 }
5044 return Result;
5045 }
5046
5047private:
5048 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5049 bool IsPosBECond, ScalarEvolution &SE)
5050 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5051 IsPositiveBECond(IsPosBECond) {}
5052
5053 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5054
5055 const Loop *L;
5056 /// Loop back condition.
5057 Value *BackedgeCond = nullptr;
5058 /// Set to true if loop back is on positive branch condition.
5059 bool IsPositiveBECond;
5060};
5061
5062std::optional<const SCEV *>
5063SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5064
5065 // If value matches the backedge condition for loop latch,
5066 // then return a constant evolution node based on loopback
5067 // branch taken.
5068 if (BackedgeCond == IC)
5069 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5070 : SE.getZero(Type::getInt1Ty(SE.getContext()));
5071 return std::nullopt;
5072}
5073
5074class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5075public:
5076 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5077 ScalarEvolution &SE) {
5078 SCEVShiftRewriter Rewriter(L, SE);
5079 const SCEV *Result = Rewriter.visit(S);
5080 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5081 }
5082
5083 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5084 // Only allow AddRecExprs for this loop.
5085 if (!SE.isLoopInvariant(Expr, L))
5086 Valid = false;
5087 return Expr;
5088 }
5089
5090 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5091 if (Expr->getLoop() == L && Expr->isAffine())
5092 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5093 Valid = false;
5094 return Expr;
5095 }
5096
5097 bool isValid() { return Valid; }
5098
5099private:
5100 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5101 : SCEVRewriteVisitor(SE), L(L) {}
5102
5103 const Loop *L;
5104 bool Valid = true;
5105};
5106
5107} // end anonymous namespace
5108
5109SCEV::NoWrapFlags
5110ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5111 if (!AR->isAffine())
5112 return SCEV::FlagAnyWrap;
5113
5114 using OBO = OverflowingBinaryOperator;
5115
5116 SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
5117
5118 if (!AR->hasNoSelfWrap()) {
5119 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5120 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5121 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5122 const APInt &BECountAP = BECountMax->getAPInt();
5123 unsigned NoOverflowBitWidth =
5124 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5125 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5126 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNW);
5127 }
5128 }
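// Editorial example (not part of the original source): if the constant maximum
// backedge-taken count is 255 (8 active bits) and the step is the constant 1
// (2 bits as a signed value), then 10 bits are enough for the recurrence never
// to revisit a value, so any AddRec of width 10 bits or more can be marked
// <nw> here.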
5129
5130 if (!AR->hasNoSignedWrap()) {
5131 ConstantRange AddRecRange = getSignedRange(AR);
5132 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5133
5134 auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5135 Instruction::Add, IncRange, OBO::NoSignedWrap);
5136 if (NSWRegion.contains(AddRecRange))
5137 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
5138 }
5139
5140 if (!AR->hasNoUnsignedWrap()) {
5141 ConstantRange AddRecRange = getUnsignedRange(AR);
5142 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5143
5144 auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
5145 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5146 if (NUWRegion.contains(AddRecRange))
5147 Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
5148 }
5149
5150 return Result;
5151}
5152
5153SCEV::NoWrapFlags
5154ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5155 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5156
5157 if (AR->hasNoSignedWrap())
5158 return Result;
5159
5160 if (!AR->isAffine())
5161 return Result;
5162
5163 // This function can be expensive, only try to prove NSW once per AddRec.
5164 if (!SignedWrapViaInductionTried.insert(AR).second)
5165 return Result;
5166
5167 const SCEV *Step = AR->getStepRecurrence(*this);
5168 const Loop *L = AR->getLoop();
5169
5170 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5171 // Note that this serves two purposes: It filters out loops that are
5172 // simply not analyzable, and it covers the case where this code is
5173 // being called from within backedge-taken count analysis, such that
5174 // attempting to ask for the backedge-taken count would likely result
5175 // in infinite recursion. In the latter case, the analysis code will
5176 // cope with a conservative value, and it will take care to purge
5177 // that value once it has finished.
5178 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5179
5180 // Normally, in the cases we can prove no-overflow via a
5181 // backedge guarding condition, we can also compute a backedge
5182 // taken count for the loop. The exceptions are assumptions and
5183 // guards present in the loop -- SCEV is not great at exploiting
5184 // these to compute max backedge taken counts, but can still use
5185 // these to prove lack of overflow. Use this fact to avoid
5186 // doing extra work that may not pay off.
5187
5188 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5189 AC.assumptions().empty())
5190 return Result;
5191
5192 // If the backedge is guarded by a comparison with the pre-inc value the
5193 // addrec is safe. Also, if the entry is guarded by a comparison with the
5194 // start value and the backedge is guarded by a comparison with the post-inc
5195 // value, the addrec is safe.
5196 ICmpInst::Predicate Pred;
5197 const SCEV *OverflowLimit =
5198 getSignedOverflowLimitForStep(Step, &Pred, this);
5199 if (OverflowLimit &&
5200 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5201 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5202 Result = setFlags(Result, SCEV::FlagNSW);
5203 }
5204 return Result;
5205}
5206SCEV::NoWrapFlags
5207ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5208 SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
5209
5210 if (AR->hasNoUnsignedWrap())
5211 return Result;
5212
5213 if (!AR->isAffine())
5214 return Result;
5215
5216 // This function can be expensive, only try to prove NUW once per AddRec.
5217 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5218 return Result;
5219
5220 const SCEV *Step = AR->getStepRecurrence(*this);
5221 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5222 const Loop *L = AR->getLoop();
5223
5224 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5225 // Note that this serves two purposes: It filters out loops that are
5226 // simply not analyzable, and it covers the case where this code is
5227 // being called from within backedge-taken count analysis, such that
5228 // attempting to ask for the backedge-taken count would likely result
5229 // in infinite recursion. In the latter case, the analysis code will
5230 // cope with a conservative value, and it will take care to purge
5231 // that value once it has finished.
5232 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5233
5234 // Normally, in the cases we can prove no-overflow via a
5235 // backedge guarding condition, we can also compute a backedge
5236 // taken count for the loop. The exceptions are assumptions and
5237 // guards present in the loop -- SCEV is not great at exploiting
5238 // these to compute max backedge taken counts, but can still use
5239 // these to prove lack of overflow. Use this fact to avoid
5240 // doing extra work that may not pay off.
5241
5242 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5243 AC.assumptions().empty())
5244 return Result;
5245
5246 // If the backedge is guarded by a comparison with the pre-inc value the
5247 // addrec is safe. Also, if the entry is guarded by a comparison with the
5248 // start value and the backedge is guarded by a comparison with the post-inc
5249 // value, the addrec is safe.
5250 if (isKnownPositive(Step)) {
5251 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5252 getUnsignedRangeMax(Step));
5253 if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
5254 isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
5255 Result = setFlags(Result, SCEV::FlagNUW);
5256 }
5257 }
5258
5259 return Result;
5260}
5261
5262namespace {
5263
5264/// Represents an abstract binary operation. This may exist as a
5265/// normal instruction or constant expression, or may have been
5266/// derived from an expression tree.
5267struct BinaryOp {
5268 unsigned Opcode;
5269 Value *LHS;
5270 Value *RHS;
5271 bool IsNSW = false;
5272 bool IsNUW = false;
5273
5274 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5275 /// constant expression.
5276 Operator *Op = nullptr;
5277
5278 explicit BinaryOp(Operator *Op)
5279 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5280 Op(Op) {
5281 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5282 IsNSW = OBO->hasNoSignedWrap();
5283 IsNUW = OBO->hasNoUnsignedWrap();
5284 }
5285 }
5286
5287 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5288 bool IsNUW = false)
5289 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5290};
5291
5292} // end anonymous namespace
5293
5294/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5295static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5296 AssumptionCache &AC,
5297 const DominatorTree &DT,
5298 const Instruction *CxtI) {
5299 auto *Op = dyn_cast<Operator>(V);
5300 if (!Op)
5301 return std::nullopt;
5302
5303 // Implementation detail: all the cleverness here should happen without
5304 // creating new SCEV expressions -- our caller knows tricks to avoid creating
5305 // SCEV expressions when possible, and we should not break that.
5306
5307 switch (Op->getOpcode()) {
5308 case Instruction::Add:
5309 case Instruction::Sub:
5310 case Instruction::Mul:
5311 case Instruction::UDiv:
5312 case Instruction::URem:
5313 case Instruction::And:
5314 case Instruction::AShr:
5315 case Instruction::Shl:
5316 return BinaryOp(Op);
5317
5318 case Instruction::Or: {
5319 // Convert or disjoint into add nuw nsw.
5320 if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
5321 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5322 /*IsNSW=*/true, /*IsNUW=*/true);
5323 return BinaryOp(Op);
5324 }
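// Editorial note (not part of the original source): "or disjoint" guarantees
// that the operands have no common set bits, so the result equals their sum
// and neither signed nor unsigned overflow is possible, which justifies the
// nuw/nsw flags on the synthesized add.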
5325
5326 case Instruction::Xor:
5327 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5328 // If the RHS of the xor is a signmask, then this is just an add.
5329 // Instcombine turns add of signmask into xor as a strength reduction step.
5330 if (RHSC->getValue().isSignMask())
5331 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5332 // Binary `xor` is a bit-wise `add`.
5333 if (V->getType()->isIntegerTy(1))
5334 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5335 return BinaryOp(Op);
5336
5337 case Instruction::LShr:
5338 // Turn logical shift right of a constant into an unsigned divide.
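// Editorial example (not part of the original source): "lshr i32 %x, 4" is
// modeled below as "udiv i32 %x, 16", i.e. X is the constant 1 << ShiftAmt.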
5339 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5340 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5341
5342 // If the shift count is not less than the bitwidth, the result of
5343 // the shift is undefined. Don't try to analyze it, because the
5344 // resolution chosen here may differ from the resolution chosen in
5345 // other parts of the compiler.
5346 if (SA->getValue().ult(BitWidth)) {
5347 Constant *X =
5348 ConstantInt::get(SA->getContext(),
5349 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5350 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5351 }
5352 }
5353 return BinaryOp(Op);
5354
5355 case Instruction::ExtractValue: {
5356 auto *EVI = cast<ExtractValueInst>(Op);
5357 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5358 break;
5359
5360 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5361 if (!WO)
5362 break;
5363
5364 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5365 bool Signed = WO->isSigned();
5366 // TODO: Should add nuw/nsw flags for mul as well.
5367 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5368 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5369
5370 // Now that we know that all uses of the arithmetic-result component of
5371 // CI are guarded by the overflow check, we can go ahead and pretend
5372 // that the arithmetic is non-overflowing.
5373 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5374 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5375 }
5376
5377 default:
5378 break;
5379 }
5380
5381 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5382 // semantics as a Sub, return a binary sub expression.
5383 if (auto *II = dyn_cast<IntrinsicInst>(V))
5384 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5385 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5386
5387 return std::nullopt;
5388}
5389
5390/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5391/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5392/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5393/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5394/// follows one of the following patterns:
5395/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5396/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5397/// If the SCEV expression of \p Op conforms with one of the expected patterns
5398/// we return the type of the truncation operation, and indicate whether the
5399/// truncated type should be treated as signed/unsigned by setting
5400/// \p Signed to true/false, respectively.
5401static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5402 bool &Signed, ScalarEvolution &SE) {
5403 // The case where Op == SymbolicPHI (that is, with no type conversions on
5404 // the way) is handled by the regular add recurrence creating logic and
5405 // would have already been triggered in createAddRecForPHI. Reaching it here
5406 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5407 // because one of the other operands of the SCEVAddExpr updating this PHI is
5408 // not invariant).
5409 //
5410 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5411 // this case predicates that allow us to prove that Op == SymbolicPHI will
5412 // be added.
5413 if (Op == SymbolicPHI)
5414 return nullptr;
5415
5416 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5417 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5418 if (SourceBits != NewBits)
5419 return nullptr;
5420
5421 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5422 Signed = true;
5423 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5424 }
5425 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5426 Signed = false;
5427 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5428 }
5429 return nullptr;
5430}
5431
5432static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5433 if (!PN->getType()->isIntegerTy())
5434 return nullptr;
5435 const Loop *L = LI.getLoopFor(PN->getParent());
5436 if (!L || L->getHeader() != PN->getParent())
5437 return nullptr;
5438 return L;
5439}
5440
5441// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5442// computation that updates the phi follows the following pattern:
5443// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5444// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5445// If so, try to see if it can be rewritten as an AddRecExpr under some
5446// Predicates. If successful, return them as a pair. Also cache the results
5447// of the analysis.
5448//
5449// Example usage scenario:
5450// Say the Rewriter is called for the following SCEV:
5451// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5452// where:
5453// %X = phi i64 (%Start, %BEValue)
5454// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5455// and call this function with %SymbolicPHI = %X.
5456//
5457// The analysis will find that the value coming around the backedge has
5458// the following SCEV:
5459// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5460// Upon concluding that this matches the desired pattern, the function
5461// will return the pair {NewAddRec, SmallPredsVec} where:
5462// NewAddRec = {%Start,+,%Step}
5463// SmallPredsVec = {P1, P2, P3} as follows:
5464// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5465// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5466// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5467// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5468// under the predicates {P1,P2,P3}.
5469// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5470// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5471//
5472// TODO's:
5473//
5474// 1) Extend the Induction descriptor to also support inductions that involve
5475// casts: When needed (namely, when we are called in the context of the
5476// vectorizer induction analysis), a Set of cast instructions will be
5477// populated by this method, and provided back to isInductionPHI. This is
5478// needed to allow the vectorizer to properly record them to be ignored by
5479// the cost model and to avoid vectorizing them (otherwise these casts,
5480// which are redundant under the runtime overflow checks, will be
5481// vectorized, which can be costly).
5482//
5483// 2) Support additional induction/PHISCEV patterns: We also want to support
5484// inductions where the sext-trunc / zext-trunc operations (partly) occur
5485// after the induction update operation (the induction increment):
5486//
5487// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5488// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5489//
5490// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5491// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5492//
5493// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5494std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5495ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5496 SmallVector<const SCEVPredicate *, 3> Predicates;
5497
5498 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5499 // return an AddRec expression under some predicate.
5500
5501 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5502 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5503 assert(L && "Expecting an integer loop header phi");
5504
5505 // The loop may have multiple entrances or multiple exits; we can analyze
5506 // this phi as an addrec if it has a unique entry value and a unique
5507 // backedge value.
5508 Value *BEValueV = nullptr, *StartValueV = nullptr;
5509 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5510 Value *V = PN->getIncomingValue(i);
5511 if (L->contains(PN->getIncomingBlock(i))) {
5512 if (!BEValueV) {
5513 BEValueV = V;
5514 } else if (BEValueV != V) {
5515 BEValueV = nullptr;
5516 break;
5517 }
5518 } else if (!StartValueV) {
5519 StartValueV = V;
5520 } else if (StartValueV != V) {
5521 StartValueV = nullptr;
5522 break;
5523 }
5524 }
5525 if (!BEValueV || !StartValueV)
5526 return std::nullopt;
5527
5528 const SCEV *BEValue = getSCEV(BEValueV);
5529
5530 // If the value coming around the backedge is an add with the symbolic
5531 // value we just inserted, possibly with casts that we can ignore under
5532 // an appropriate runtime guard, then we found a simple induction variable!
5533 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5534 if (!Add)
5535 return std::nullopt;
5536
5537 // If there is a single occurrence of the symbolic value, possibly
5538 // casted, replace it with a recurrence.
5539 unsigned FoundIndex = Add->getNumOperands();
5540 Type *TruncTy = nullptr;
5541 bool Signed;
5542 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5543 if ((TruncTy =
5544 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5545 if (FoundIndex == e) {
5546 FoundIndex = i;
5547 break;
5548 }
5549
5550 if (FoundIndex == Add->getNumOperands())
5551 return std::nullopt;
5552
5553 // Create an add with everything but the specified operand.
5554 SmallVector<const SCEV *, 8> Ops;
5555 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5556 if (i != FoundIndex)
5557 Ops.push_back(Add->getOperand(i));
5558 const SCEV *Accum = getAddExpr(Ops);
5559
5560 // The runtime checks will not be valid if the step amount is
5561 // varying inside the loop.
5562 if (!isLoopInvariant(Accum, L))
5563 return std::nullopt;
5564
5565 // *** Part2: Create the predicates
5566
5567 // Analysis was successful: we have a phi-with-cast pattern for which we
5568 // can return an AddRec expression under the following predicates:
5569 //
5570 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5571 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5572 // P2: An Equal predicate that guarantees that
5573 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5574 // P3: An Equal predicate that guarantees that
5575 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5576 //
5577 // As we next prove, the above predicates guarantee that:
5578 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5579 //
5580 //
5581 // More formally, we want to prove that:
5582 // Expr(i+1) = Start + (i+1) * Accum
5583 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5584 //
5585 // Given that:
5586 // 1) Expr(0) = Start
5587 // 2) Expr(1) = Start + Accum
5588 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5589 // 3) Induction hypothesis (step i):
5590 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5591 //
5592 // Proof:
5593 // Expr(i+1) =
5594 // = Start + (i+1)*Accum
5595 // = (Start + i*Accum) + Accum
5596 // = Expr(i) + Accum
5597 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5598 // :: from step i
5599 //
5600 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5601 //
5602 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5603 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5604 // + Accum :: from P3
5605 //
5606 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5607 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5608 //
5609 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5610 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5611 //
5612 // By induction, the same applies to all iterations 1<=i<n:
5613 //
5614
5615 // Create a truncated addrec for which we will add a no overflow check (P1).
5616 const SCEV *StartVal = getSCEV(StartValueV);
5617 const SCEV *PHISCEV =
5618 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5619 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5620
5621 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5622 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5623 // will be constant.
5624 //
5625 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5626 // add P1.
5627 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5628 SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
5629 Signed ? SCEVWrapPredicate::IncrementNSSW
5630 : SCEVWrapPredicate::IncrementNUSW;
5631 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5632 Predicates.push_back(AddRecPred);
5633 }
5634
5635 // Create the Equal Predicates P2,P3:
5636
5637 // It is possible that the predicates P2 and/or P3 are computable at
5638 // compile time due to StartVal and/or Accum being constants.
5639 // If either one is, then we can check that now and escape if either P2
5640 // or P3 is false.
5641
5642 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5643 // for each of StartVal and Accum
5644 auto getExtendedExpr = [&](const SCEV *Expr,
5645 bool CreateSignExtend) -> const SCEV * {
5646 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5647 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5648 const SCEV *ExtendedExpr =
5649 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5650 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5651 return ExtendedExpr;
5652 };
5653
5654 // Given:
5655 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5656 // = getExtendedExpr(Expr)
5657 // Determine whether the predicate P: Expr == ExtendedExpr
5658 // is known to be false at compile time
5659 auto PredIsKnownFalse = [&](const SCEV *Expr,
5660 const SCEV *ExtendedExpr) -> bool {
5661 return Expr != ExtendedExpr &&
5662 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5663 };
5664
5665 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5666 if (PredIsKnownFalse(StartVal, StartExtended)) {
5667 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5668 return std::nullopt;
5669 }
5670
5671 // The Step is always Signed (because the overflow checks are either
5672 // NSSW or NUSW)
5673 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5674 if (PredIsKnownFalse(Accum, AccumExtended)) {
5675 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5676 return std::nullopt;
5677 }
5678
5679 auto AppendPredicate = [&](const SCEV *Expr,
5680 const SCEV *ExtendedExpr) -> void {
5681 if (Expr != ExtendedExpr &&
5682 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5683 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5684 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5685 Predicates.push_back(Pred);
5686 }
5687 };
5688
5689 AppendPredicate(StartVal, StartExtended);
5690 AppendPredicate(Accum, AccumExtended);
5691
5692 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5693 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5694 // into NewAR if it will also add the runtime overflow checks specified in
5695 // Predicates.
5696 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5697
5698 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5699 std::make_pair(NewAR, Predicates);
5700 // Remember the result of the analysis for this SCEV at this location.
5701 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5702 return PredRewrite;
5703}
5704
5705std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5706ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
5707 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5708 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5709 if (!L)
5710 return std::nullopt;
5711
5712 // Check to see if we already analyzed this PHI.
5713 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5714 if (I != PredicatedSCEVRewrites.end()) {
5715 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5716 I->second;
5717 // Analysis was done before and failed to create an AddRec:
5718 if (Rewrite.first == SymbolicPHI)
5719 return std::nullopt;
5720 // Analysis was done before and succeeded in creating an AddRec under
5721 // a predicate:
5722 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5723 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5724 return Rewrite;
5725 }
5726
5727 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5728 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5729
5730 // Record in the cache that the analysis failed
5731 if (!Rewrite) {
5732 SmallVector<const SCEVPredicate *, 3> Predicates;
5733 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5734 return std::nullopt;
5735 }
5736
5737 return Rewrite;
5738}
5739
5740// FIXME: This utility is currently required because the Rewriter currently
5741// does not rewrite this expression:
5742// {0, +, (sext ix (trunc iy to ix) to iy)}
5743// into {0, +, %step},
5744// even when the following Equal predicate exists:
5745// "%step == (sext ix (trunc iy to ix) to iy)".
5746bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
5747 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5748 if (AR1 == AR2)
5749 return true;
5750
5751 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5752 if (Expr1 != Expr2 &&
5753 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5754 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5755 return false;
5756 return true;
5757 };
5758
5759 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5760 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5761 return false;
5762 return true;
5763}
5764
5765/// A helper function for createAddRecFromPHI to handle simple cases.
5766///
5767/// This function tries to find an AddRec expression for the simplest (yet most
5768/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5769/// If it fails, createAddRecFromPHI will use a more general, but slow,
5770/// technique for finding the AddRec expression.
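// Editorial example (not part of the original source): for
//   %iv = phi i32 [ 0, %entry ], [ %iv.next, %loop ]
//   %iv.next = add nsw i32 %iv, 4
// BEValueV is %iv.next, StartValueV is 0, and the function returns the AddRec
// {0,+,4}<nsw><%loop>.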
5771const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5772 Value *BEValueV,
5773 Value *StartValueV) {
5774 const Loop *L = LI.getLoopFor(PN->getParent());
5775 assert(L && L->getHeader() == PN->getParent());
5776 assert(BEValueV && StartValueV);
5777
5778 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5779 if (!BO)
5780 return nullptr;
5781
5782 if (BO->Opcode != Instruction::Add)
5783 return nullptr;
5784
5785 const SCEV *Accum = nullptr;
5786 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5787 Accum = getSCEV(BO->RHS);
5788 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5789 Accum = getSCEV(BO->LHS);
5790
5791 if (!Accum)
5792 return nullptr;
5793
5794 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5795 if (BO->IsNUW)
5796 Flags = setFlags(Flags, SCEV::FlagNUW);
5797 if (BO->IsNSW)
5798 Flags = setFlags(Flags, SCEV::FlagNSW);
5799
5800 const SCEV *StartVal = getSCEV(StartValueV);
5801 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5802 insertValueToMap(PN, PHISCEV);
5803
5804 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5805 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5806 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5807 proveNoWrapViaConstantRanges(AR)));
5808 }
5809
5810 // We can add Flags to the post-inc expression only if we
5811 // know that it is *undefined behavior* for BEValueV to
5812 // overflow.
5813 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5814 assert(isLoopInvariant(Accum, L) &&
5815 "Accum is defined outside L, but is not invariant?");
5816 if (isAddRecNeverPoison(BEInst, L))
5817 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5818 }
5819
5820 return PHISCEV;
5821}
5822
5823const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5824 const Loop *L = LI.getLoopFor(PN->getParent());
5825 if (!L || L->getHeader() != PN->getParent())
5826 return nullptr;
5827
5828 // The loop may have multiple entrances or multiple exits; we can analyze
5829 // this phi as an addrec if it has a unique entry value and a unique
5830 // backedge value.
5831 Value *BEValueV = nullptr, *StartValueV = nullptr;
5832 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5833 Value *V = PN->getIncomingValue(i);
5834 if (L->contains(PN->getIncomingBlock(i))) {
5835 if (!BEValueV) {
5836 BEValueV = V;
5837 } else if (BEValueV != V) {
5838 BEValueV = nullptr;
5839 break;
5840 }
5841 } else if (!StartValueV) {
5842 StartValueV = V;
5843 } else if (StartValueV != V) {
5844 StartValueV = nullptr;
5845 break;
5846 }
5847 }
5848 if (!BEValueV || !StartValueV)
5849 return nullptr;
5850
5851 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5852 "PHI node already processed?");
5853
5854 // First, try to find an AddRec expression without creating a fictitious symbolic
5855 // value for PN.
5856 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5857 return S;
5858
5859 // Handle PHI node value symbolically.
5860 const SCEV *SymbolicName = getUnknown(PN);
5861 insertValueToMap(PN, SymbolicName);
5862
5863 // Using this symbolic name for the PHI, analyze the value coming around
5864 // the back-edge.
5865 const SCEV *BEValue = getSCEV(BEValueV);
5866
5867 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5868 // has a special value for the first iteration of the loop.
5869
5870 // If the value coming around the backedge is an add with the symbolic
5871 // value we just inserted, then we found a simple induction variable!
5872 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5873 // If there is a single occurrence of the symbolic value, replace it
5874 // with a recurrence.
5875 unsigned FoundIndex = Add->getNumOperands();
5876 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5877 if (Add->getOperand(i) == SymbolicName)
5878 if (FoundIndex == e) {
5879 FoundIndex = i;
5880 break;
5881 }
5882
5883 if (FoundIndex != Add->getNumOperands()) {
5884 // Create an add with everything but the specified operand.
5885 SmallVector<const SCEV *, 8> Ops;
5886 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5887 if (i != FoundIndex)
5888 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
5889 L, *this));
5890 const SCEV *Accum = getAddExpr(Ops);
5891
5892 // This is not a valid addrec if the step amount is varying each
5893 // loop iteration, but is not itself an addrec in this loop.
5894 if (isLoopInvariant(Accum, L) ||
5895 (isa<SCEVAddRecExpr>(Accum) &&
5896 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
5897 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
5898
5899 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
5900 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
5901 if (BO->IsNUW)
5902 Flags = setFlags(Flags, SCEV::FlagNUW);
5903 if (BO->IsNSW)
5904 Flags = setFlags(Flags, SCEV::FlagNSW);
5905 }
5906 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
5907 if (GEP->getOperand(0) == PN) {
5908 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
5909 // If the increment has any nowrap flags, then we know the address
5910 // space cannot be wrapped around.
5911 if (NW != GEPNoWrapFlags::none())
5912 Flags = setFlags(Flags, SCEV::FlagNW);
5913 // If the GEP is nuw or nusw with non-negative offset, we know that
5914 // no unsigned wrap occurs. We cannot set the nsw flag as only the
5915 // offset is treated as signed, while the base is unsigned.
5916 if (NW.hasNoUnsignedWrap() ||
5917 (NW.hasNoUnsignedSignedWrap() && isKnownNonNegative(Accum)))
5918 Flags = setFlags(Flags, SCEV::FlagNUW);
5919 }
5920
5921 // We cannot transfer nuw and nsw flags from subtraction
5922 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
5923 // for instance.
5924 }
5925
5926 const SCEV *StartVal = getSCEV(StartValueV);
5927 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5928
5929 // Okay, for the entire analysis of this edge we assumed the PHI
5930 // to be symbolic. We now need to go back and purge all of the
5931 // entries for the scalars that use the symbolic expression.
5932 forgetMemoizedResults(SymbolicName);
5933 insertValueToMap(PN, PHISCEV);
5934
5935 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5936 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5937 (SCEV::NoWrapFlags)(AR->getNoWrapFlags() |
5938 proveNoWrapViaConstantRanges(AR)));
5939 }
5940
5941 // We can add Flags to the post-inc expression only if we
5942 // know that it is *undefined behavior* for BEValueV to
5943 // overflow.
5944 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
5945 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
5946 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5947
5948 return PHISCEV;
5949 }
5950 }
5951 } else {
5952 // Otherwise, this could be a loop like this:
5953 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
5954 // In this case, j = {1,+,1} and BEValue is j.
5955 // Because the other in-value of i (0) fits the evolution of BEValue,
5956 // i really is an addrec evolution.
5957 //
5958 // We can generalize this saying that i is the shifted value of BEValue
5959 // by one iteration:
5960 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
5961
5962 // Do not allow refinement in rewriting of BEValue.
5963 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
5964 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
5965 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
5966 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
5967 const SCEV *StartVal = getSCEV(StartValueV);
5968 if (Start == StartVal) {
5969 // Okay, for the entire analysis of this edge we assumed the PHI
5970 // to be symbolic. We now need to go back and purge all of the
5971 // entries for the scalars that use the symbolic expression.
5972 forgetMemoizedResults(SymbolicName);
5973 insertValueToMap(PN, Shifted);
5974 return Shifted;
5975 }
5976 }
5977 }
5978
5979 // Remove the temporary PHI node SCEV that has been inserted while intending
5980 // to create an AddRecExpr for this PHI node. We cannot keep this temporary
5981 // as it will prevent later (possibly simpler) SCEV expressions from being
5982 // added to the ValueExprMap.
5983 eraseValueFromMap(PN);
5984
5985 return nullptr;
5986}
5987
5988// Try to match a control flow sequence that branches out at BI and merges back
5989// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
5990// match.
5991static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,
5992 Value *&C, Value *&LHS, Value *&RHS) {
5993 C = BI->getCondition();
5994
5995 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
5996 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
5997
5998 if (!LeftEdge.isSingleEdge())
5999 return false;
6000
6001 assert(RightEdge.isSingleEdge() && "Follows from LeftEdge.isSingleEdge()");
6002
6003 Use &LeftUse = Merge->getOperandUse(0);
6004 Use &RightUse = Merge->getOperandUse(1);
6005
6006 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6007 LHS = LeftUse;
6008 RHS = RightUse;
6009 return true;
6010 }
6011
6012 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6013 LHS = RightUse;
6014 RHS = LeftUse;
6015 return true;
6016 }
6017
6018 return false;
6019}
6020
6021const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6022 auto IsReachable =
6023 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6024 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6025 // Try to match
6026 //
6027 // br %cond, label %left, label %right
6028 // left:
6029 // br label %merge
6030 // right:
6031 // br label %merge
6032 // merge:
6033 // V = phi [ %x, %left ], [ %y, %right ]
6034 //
6035 // as "select %cond, %x, %y"
6036
6037 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6038 assert(IDom && "At least the entry block should dominate PN");
6039
6040 auto *BI = dyn_cast<BranchInst>(IDom->getTerminator());
6041 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6042
6043 if (BI && BI->isConditional() &&
6044 BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS) &&
6045 properlyDominates(getSCEV(LHS), PN->getParent()) &&
6046 properlyDominates(getSCEV(RHS), PN->getParent()))
6047 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6048 }
6049
6050 return nullptr;
6051}
6052
6053/// Returns SCEV for the first operand of a phi if all phi operands have
6054/// identical opcodes and operands
6055/// eg.
6056/// a: %add = %a + %b
6057/// br %c
6058/// b: %add1 = %a + %b
6059/// br %c
6060/// c: %phi = phi [%add, a], [%add1, b]
6061/// scev(%phi) => scev(%add)
6062const SCEV *
6063ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6064 BinaryOperator *CommonInst = nullptr;
6065 // Check if instructions are identical.
6066 for (Value *Incoming : PN->incoming_values()) {
6067 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6068 if (!IncomingInst)
6069 return nullptr;
6070 if (CommonInst) {
6071 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6072 return nullptr; // Not identical, give up
6073 } else {
6074 // Remember binary operator
6075 CommonInst = IncomingInst;
6076 }
6077 }
6078 if (!CommonInst)
6079 return nullptr;
6080
6081 // Check if SCEV exprs for instructions are identical.
6082 const SCEV *CommonSCEV = getSCEV(CommonInst);
6083 bool SCEVExprsIdentical =
6084 all_of(drop_begin(PN->incoming_values()),
6085 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6086 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6087}
6088
6089const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6090 if (const SCEV *S = createAddRecFromPHI(PN))
6091 return S;
6092
6093 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6094 // phi node for X.
6095 if (Value *V = simplifyInstruction(
6096 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6097 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6098 return getSCEV(V);
6099
6100 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6101 return S;
6102
6103 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6104 return S;
6105
6106 // If it's not a loop phi, we can't handle it yet.
6107 return getUnknown(PN);
6108}
6109
6110bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6111 SCEVTypes RootKind) {
6112 struct FindClosure {
6113 const SCEV *OperandToFind;
6114 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6115 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6116
6117 bool Found = false;
6118
6119 bool canRecurseInto(SCEVTypes Kind) const {
6120 // We can only recurse into the SCEV expression of the same effective type
6121 // as the type of our root SCEV expression, and into zero-extensions.
6122 return RootKind == Kind || NonSequentialRootKind == Kind ||
6123 scZeroExtend == Kind;
6124 };
6125
6126 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6127 : OperandToFind(OperandToFind), RootKind(RootKind),
6128 NonSequentialRootKind(
6129 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
6130 RootKind)) {}
6131
6132 bool follow(const SCEV *S) {
6133 Found = S == OperandToFind;
6134
6135 return !isDone() && canRecurseInto(S->getSCEVType());
6136 }
6137
6138 bool isDone() const { return Found; }
6139 };
6140
6141 FindClosure FC(OperandToFind, RootKind);
6142 visitAll(Root, FC);
6143 return FC.Found;
6144}
6145
6146std::optional<const SCEV *>
6147ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6148 ICmpInst *Cond,
6149 Value *TrueVal,
6150 Value *FalseVal) {
6151 // Try to match some simple smax or umax patterns.
6152 auto *ICI = Cond;
6153
6154 Value *LHS = ICI->getOperand(0);
6155 Value *RHS = ICI->getOperand(1);
6156
6157 switch (ICI->getPredicate()) {
6158 case ICmpInst::ICMP_SLT:
6159 case ICmpInst::ICMP_SLE:
6160 case ICmpInst::ICMP_ULT:
6161 case ICmpInst::ICMP_ULE:
6162 std::swap(LHS, RHS);
6163 [[fallthrough]];
6164 case ICmpInst::ICMP_SGT:
6165 case ICmpInst::ICMP_SGE:
6166 case ICmpInst::ICMP_UGT:
6167 case ICmpInst::ICMP_UGE:
6168 // a > b ? a+x : b+x -> max(a, b)+x
6169 // a > b ? b+x : a+x -> min(a, b)+x
6170 {
6171 bool Signed = ICI->isSigned();
6172 const SCEV *LA = getSCEV(TrueVal);
6173 const SCEV *RA = getSCEV(FalseVal);
6174 const SCEV *LS = getSCEV(LHS);
6175 const SCEV *RS = getSCEV(RHS);
6176 if (LA->getType()->isPointerTy()) {
6177 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6178 // Need to make sure we can't produce weird expressions involving
6179 // negated pointers.
6180 if (LA == LS && RA == RS)
6181 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6182 if (LA == RS && RA == LS)
6183 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6184 }
6185 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6186 if (Op->getType()->isPointerTy()) {
6187 Op = getLosslessPtrToIntExpr(Op);
6188 if (isa<SCEVCouldNotCompute>(Op))
6189 return Op;
6190 }
6191 if (Signed)
6192 Op = getNoopOrSignExtend(Op, Ty);
6193 else
6194 Op = getNoopOrZeroExtend(Op, Ty);
6195 return Op;
6196 };
6197 LS = CoerceOperand(LS);
6198 RS = CoerceOperand(RS);
6199 if (LS->getType() != RS->getType())
6200 break;
6201 const SCEV *LDiff = getMinusSCEV(LA, LS);
6202 const SCEV *RDiff = getMinusSCEV(RA, RS);
6203 if (LDiff == RDiff)
6204 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6205 LDiff);
6206 LDiff = getMinusSCEV(LA, RS);
6207 RDiff = getMinusSCEV(RA, LS);
6208 if (LDiff == RDiff)
6209 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6210 LDiff);
6211 }
6212 break;
6213 case ICmpInst::ICMP_NE:
6214 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6215 std::swap(TrueVal, FalseVal);
6216 [[fallthrough]];
6217 case ICmpInst::ICMP_EQ:
6218 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6219 if (getTypeSizeInBits(LHS->getType()) <= getTypeSizeInBits(Ty) &&
6220 isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()) {
6221 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6222 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6223 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6224 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6225 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6226 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6227 return getAddExpr(getUMaxExpr(X, C), Y);
6228 }
6229 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6230 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6231 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6232 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6233 if (isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
6234 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6235 const SCEV *X = getSCEV(LHS);
6236 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6237 X = ZExt->getOperand();
6238 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6239 const SCEV *FalseValExpr = getSCEV(FalseVal);
6240 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6241 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6242 /*Sequential=*/true);
6243 }
6244 }
6245 break;
6246 default:
6247 break;
6248 }
6249
6250 return std::nullopt;
6251}
6252
6253static std::optional<const SCEV *>
6254createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr,
6255 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6256 assert(CondExpr->getType()->isIntegerTy(1) &&
6257 TrueExpr->getType() == FalseExpr->getType() &&
6258 TrueExpr->getType()->isIntegerTy(1) &&
6259 "Unexpected operands of a select.");
6260
6261 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6262 // --> C + (umin_seq cond, x - C)
6263 //
6264 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6265 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6266 // --> C + (umin_seq ~cond, x - C)
6267
6268 // FIXME: while we can't legally model the case where both of the hands
6269 // are fully variable, we only require that the *difference* is constant.
6270 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6271 return std::nullopt;
6272
6273 const SCEV *X, *C;
6274 if (isa<SCEVConstant>(TrueExpr)) {
6275 CondExpr = SE->getNotSCEV(CondExpr);
6276 X = FalseExpr;
6277 C = TrueExpr;
6278 } else {
6279 X = TrueExpr;
6280 C = FalseExpr;
6281 }
6282 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6283 /*Sequential=*/true));
6284}
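// Editorial example (not part of the original source): "select i1 %c, i1 %x,
// i1 false" has C = 0, so the rewrite yields 0 + (%c umin_seq %x), i.e. a
// poison-safe form of "%c & %x".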
6285
6286static std::optional<const SCEV *>
6287createNodeForSelectViaUMinSeq(ScalarEvolution *SE, Value *Cond, Value *TrueVal,
6288 Value *FalseVal) {
6289 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6290 return std::nullopt;
6291
6292 const auto *SECond = SE->getSCEV(Cond);
6293 const auto *SETrue = SE->getSCEV(TrueVal);
6294 const auto *SEFalse = SE->getSCEV(FalseVal);
6295 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6296}
6297
6298const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6299 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6300 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6301 assert(TrueVal->getType() == FalseVal->getType() &&
6302 V->getType() == TrueVal->getType() &&
6303 "Types of select hands and of the result must match.");
6304
6305 // For now, only deal with i1-typed `select`s.
6306 if (!V->getType()->isIntegerTy(1))
6307 return getUnknown(V);
6308
6309 if (std::optional<const SCEV *> S =
6310 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6311 return *S;
6312
6313 return getUnknown(V);
6314}
6315
6316const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6317 Value *TrueVal,
6318 Value *FalseVal) {
6319 // Handle "constant" branch or select. This can occur for instance when a
6320 // loop pass transforms an inner loop and moves on to process the outer loop.
6321 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6322 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6323
6324 if (auto *I = dyn_cast<Instruction>(V)) {
6325 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6326 if (std::optional<const SCEV *> S =
6327 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6328 TrueVal, FalseVal))
6329 return *S;
6330 }
6331 }
6332
6333 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6334}
6335
6336/// Expand GEP instructions into add and multiply operations. This allows them
6337/// to be analyzed by regular SCEV code.
6338const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6339 assert(GEP->getSourceElementType()->isSized() &&
6340 "GEP source element type must be sized");
6341
6342 SmallVector<const SCEV *, 4> IndexExprs;
6343 for (Value *Index : GEP->indices())
6344 IndexExprs.push_back(getSCEV(Index));
6345 return getGEPExpr(GEP, IndexExprs);
6346}
6347
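// Editorial note (not part of the original source): this computes the largest
// constant known to divide S; e.g. a SCEVConstant 12 yields 12, and a <nuw>
// multiply whose operands have multiples 4 and 2 yields 8.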
6348APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6349 const Instruction *CtxI) {
6350 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6351 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6352 return TrailingZeros >= BitWidth
6353 ? APInt::getZero(BitWidth)
6354 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6355 };
6356 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6357 // The result is GCD of all operands results.
6358 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6359 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6360 Res = APIntOps::GreatestCommonDivisor(
6361 Res, getConstantMultiple(N->getOperand(I), CtxI));
6362 return Res;
6363 };
6364
6365 switch (S->getSCEVType()) {
6366 case scConstant:
6367 return cast<SCEVConstant>(S)->getAPInt();
6368 case scPtrToInt:
6369 return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand(), CtxI);
6370 case scUDivExpr:
6371 case scVScale:
6372 return APInt(BitWidth, 1);
6373 case scTruncate: {
6374 // Only multiples that are a power of 2 will hold after truncation.
6375 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6376 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6377 return GetShiftedByZeros(TZ);
6378 }
6379 case scZeroExtend: {
6380 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6381 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6382 }
6383 case scSignExtend: {
6384 // Only multiples that are a power of 2 will hold after sext.
6385 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6386 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6387 return GetShiftedByZeros(TZ);
6388 }
6389 case scMulExpr: {
6390 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6391 if (M->hasNoUnsignedWrap()) {
6392 // The result is the product of all operand results.
6393 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6394 for (const SCEV *Operand : M->operands().drop_front())
6395 Res = Res * getConstantMultiple(Operand, CtxI);
6396 return Res;
6397 }
6398
6399 // If there are no wrap guarantees, find the trailing zeros, which is the
6400 // sum of trailing zeros for all its operands.
6401 uint32_t TZ = 0;
6402 for (const SCEV *Operand : M->operands())
6403 TZ += getMinTrailingZeros(Operand, CtxI);
6404 return GetShiftedByZeros(TZ);
6405 }
6406 case scAddExpr:
6407 case scAddRecExpr: {
6408 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6409 if (N->hasNoUnsignedWrap())
6410 return GetGCDMultiple(N);
6411 // Find the trailing zero bits, which is the minimum over all operands.
6412 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6413 for (const SCEV *Operand : N->operands().drop_front())
6414 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6415 return GetShiftedByZeros(TZ);
6416 }
6417 case scUMaxExpr:
6418 case scSMaxExpr:
6419 case scUMinExpr:
6420 case scSMinExpr:
6421 case scSequentialUMinExpr:
6422 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6423 case scUnknown: {
6424 // Ask ValueTracking for known bits. SCEVUnknowns only become available at
6425 // the point their underlying IR instruction has been defined. If CtxI was
6426 // not provided, use:
6427 // * the first instruction in the entry block if it is an argument
6428 // * the instruction itself otherwise.
6429 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6430 if (!CtxI) {
6431 if (isa<Argument>(U->getValue()))
6432 CtxI = &*F.getEntryBlock().begin();
6433 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6434 CtxI = I;
6435 }
6436 unsigned Known =
6437 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6438 .countMinTrailingZeros();
6439 return GetShiftedByZeros(Known);
6440 }
6441 case scCouldNotCompute:
6442 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6443 }
6444 llvm_unreachable("Unknown SCEV kind!");
6445}
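// Worked example (assuming nothing is known about %x beyond its type): for
// (8 + (12 * %x)<nuw>)<nuw>, the multiple of (12 * %x)<nuw> is 12 * 1 = 12,
// and GCD(8, 12) = 4, so every value of the expression is a multiple of 4.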
6446
6447APInt ScalarEvolution::getConstantMultiple(const SCEV *S,
6448 const Instruction *CtxI) {
6449 // Skip looking up and updating the cache if there is a context instruction,
6450 // as the result will only be valid in the specified context.
6451 if (CtxI)
6452 return getConstantMultipleImpl(S, CtxI);
6453
6454 auto I = ConstantMultipleCache.find(S);
6455 if (I != ConstantMultipleCache.end())
6456 return I->second;
6457
6458 APInt Result = getConstantMultipleImpl(S, CtxI);
6459 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6460 assert(InsertPair.second && "Should insert a new key");
6461 return InsertPair.first->second;
6462}
6463
6464APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
6465 APInt Multiple = getConstantMultiple(S);
6466 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6467}
6468
6469uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S,
6470 const Instruction *CtxI) {
6471 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6472 (unsigned)getTypeSizeInBits(S->getType()));
6473}
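// For example, a constant multiple of 24 (0b11000) gives a minimum of 3
// trailing zero bits, capped at the bit width of the expression's type.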
6474
6475/// Helper method to assign a range to V from metadata present in the IR.
6476static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6477 if (Instruction *I = dyn_cast<Instruction>(V)) {
6478 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6479 return getConstantRangeFromMetadata(*MD);
6480 if (const auto *CB = dyn_cast<CallBase>(V))
6481 if (std::optional<ConstantRange> Range = CB->getRange())
6482 return Range;
6483 }
6484 if (auto *A = dyn_cast<Argument>(V))
6485 if (std::optional<ConstantRange> Range = A->getRange())
6486 return Range;
6487
6488 return std::nullopt;
6489}
6490
6491void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
6492 SCEV::NoWrapFlags Flags) {
6493 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6494 AddRec->setNoWrapFlags(Flags);
6495 UnsignedRanges.erase(AddRec);
6496 SignedRanges.erase(AddRec);
6497 ConstantMultipleCache.erase(AddRec);
6498 }
6499}
6500
6501ConstantRange ScalarEvolution::
6502getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6503 const DataLayout &DL = getDataLayout();
6504
6505 unsigned BitWidth = getTypeSizeInBits(U->getType());
6506 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6507
6508 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6509 // use information about the trip count to improve our available range. Note
6510 // that the trip count independent cases are already handled by known bits.
6511 // WARNING: The definition of recurrence used here is subtly different than
6512 // the one used by AddRec (and thus most of this file). Step is allowed to
6513 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6514 // and other addrecs in the same loop (for non-affine addrecs). The code
6515 // below intentionally handles the case where step is not loop invariant.
6516 auto *P = dyn_cast<PHINode>(U->getValue());
6517 if (!P)
6518 return FullSet;
6519
6520 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6521 // even the values that are not available in these blocks may come from them,
6522 // and this leads to a false-positive recurrence test.
6523 for (auto *Pred : predecessors(P->getParent()))
6524 if (!DT.isReachableFromEntry(Pred))
6525 return FullSet;
6526
6527 BinaryOperator *BO;
6528 Value *Start, *Step;
6529 if (!matchSimpleRecurrence(P, BO, Start, Step))
6530 return FullSet;
6531
6532 // If we found a recurrence in reachable code, we must be in a loop. Note
6533 // that BO might be in some subloop of L, and that's completely okay.
6534 auto *L = LI.getLoopFor(P->getParent());
6535 assert(L && L->getHeader() == P->getParent());
6536 if (!L->contains(BO->getParent()))
6537 // NOTE: This bailout should be an assert instead. However, asserting
6538 // the condition here exposes a case where LoopFusion is querying SCEV
6539 // with malformed loop information during the midst of the transform.
6540 // There doesn't appear to be an obvious fix, so for the moment bailout
6541 // until the caller issue can be fixed. PR49566 tracks the bug.
6542 return FullSet;
6543
6544 // TODO: Extend to other opcodes such as mul, and div
6545 switch (BO->getOpcode()) {
6546 default:
6547 return FullSet;
6548 case Instruction::AShr:
6549 case Instruction::LShr:
6550 case Instruction::Shl:
6551 break;
6552 };
6553
6554 if (BO->getOperand(0) != P)
6555 // TODO: Handle the power function forms some day.
6556 return FullSet;
6557
6558 unsigned TC = getSmallConstantMaxTripCount(L);
6559 if (!TC || TC >= BitWidth)
6560 return FullSet;
6561
6562 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6563 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6564 assert(KnownStart.getBitWidth() == BitWidth &&
6565 KnownStep.getBitWidth() == BitWidth);
6566
6567 // Compute total shift amount, being careful of overflow and bitwidths.
6568 auto MaxShiftAmt = KnownStep.getMaxValue();
6569 APInt TCAP(BitWidth, TC-1);
6570 bool Overflow = false;
6571 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6572 if (Overflow)
6573 return FullSet;
6574
6575 switch (BO->getOpcode()) {
6576 default:
6577 llvm_unreachable("filtered out above");
6578 case Instruction::AShr: {
6579 // For each ashr, three cases:
6580 // shift = 0 => unchanged value
6581 // saturation => 0 or -1
6582 // other => a value closer to zero (of the same sign)
6583 // Thus, the end value is closer to zero than the start.
6584 auto KnownEnd = KnownBits::ashr(KnownStart,
6585 KnownBits::makeConstant(TotalShift));
6586 if (KnownStart.isNonNegative())
6587 // Analogous to lshr (simply not yet canonicalized)
6588 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6589 KnownStart.getMaxValue() + 1);
6590 if (KnownStart.isNegative())
6591 // End >=u Start && End <=s Start
6592 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6593 KnownEnd.getMaxValue() + 1);
6594 break;
6595 }
6596 case Instruction::LShr: {
6597 // For each lshr, three cases:
6598 // shift = 0 => unchanged value
6599 // saturation => 0
6600 // other => a smaller positive number
6601 // Thus, the low end of the unsigned range is the last value produced.
6602 auto KnownEnd = KnownBits::lshr(KnownStart,
6603 KnownBits::makeConstant(TotalShift));
6604 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6605 KnownStart.getMaxValue() + 1);
6606 }
6607 case Instruction::Shl: {
6608 // Iff no bits are shifted out, value increases on every shift.
6609 auto KnownEnd = KnownBits::shl(KnownStart,
6610 KnownBits::makeConstant(TotalShift));
6611 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6612 return ConstantRange(KnownStart.getMinValue(),
6613 KnownEnd.getMaxValue() + 1);
6614 break;
6615 }
6616 };
6617 return FullSet;
6618}
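// Illustrative example (hypothetical IR; note this recurrence is not an
// AddRec):
//   %iv      = phi i8 [ %start, %entry ], [ %iv.next, %loop ]
//   %iv.next = lshr i8 %iv, 1
// With a constant max trip count TC, the total shift is at most TC - 1, so
// the final value is no larger than the known maximum of %start, which can be
// tighter than what known bits alone provide.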
6619
6620const ConstantRange &
6621ScalarEvolution::getRangeRefIter(const SCEV *S,
6622 ScalarEvolution::RangeSignHint SignHint) {
6623 DenseMap<const SCEV *, ConstantRange> &Cache =
6624 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6625 : SignedRanges;
6626 SmallVector<const SCEV *> WorkList;
6627 SmallPtrSet<const SCEV *, 8> Seen;
6628
6629 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6630 // SCEVUnknown PHI node.
6631 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6632 if (!Seen.insert(Expr).second)
6633 return;
6634 if (Cache.contains(Expr))
6635 return;
6636 switch (Expr->getSCEVType()) {
6637 case scUnknown:
6638 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6639 break;
6640 [[fallthrough]];
6641 case scConstant:
6642 case scVScale:
6643 case scTruncate:
6644 case scZeroExtend:
6645 case scSignExtend:
6646 case scPtrToInt:
6647 case scAddExpr:
6648 case scMulExpr:
6649 case scUDivExpr:
6650 case scAddRecExpr:
6651 case scUMaxExpr:
6652 case scSMaxExpr:
6653 case scUMinExpr:
6654 case scSMinExpr:
6655 case scSequentialUMinExpr:
6656 WorkList.push_back(Expr);
6657 break;
6658 case scCouldNotCompute:
6659 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6660 }
6661 };
6662 AddToWorklist(S);
6663
6664 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6665 for (unsigned I = 0; I != WorkList.size(); ++I) {
6666 const SCEV *P = WorkList[I];
6667 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6668 // If it is not a `SCEVUnknown`, just recurse into operands.
6669 if (!UnknownS) {
6670 for (const SCEV *Op : P->operands())
6671 AddToWorklist(Op);
6672 continue;
6673 }
6674 // `SCEVUnknown`'s require special treatment.
6675 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6676 if (!PendingPhiRangesIter.insert(P).second)
6677 continue;
6678 for (auto &Op : reverse(P->operands()))
6679 AddToWorklist(getSCEV(Op));
6680 }
6681 }
6682
6683 if (!WorkList.empty()) {
6684 // Use getRangeRef to compute ranges for items in the worklist in reverse
6685 // order. This will force ranges for earlier operands to be computed before
6686 // their users in most cases.
6687 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6688 getRangeRef(P, SignHint);
6689
6690 if (auto *UnknownS = dyn_cast<SCEVUnknown>(P))
6691 if (const PHINode *P = dyn_cast<PHINode>(UnknownS->getValue()))
6692 PendingPhiRangesIter.erase(P);
6693 }
6694 }
6695
6696 return getRangeRef(S, SignHint, 0);
6697}
6698
6699/// Determine the range for a particular SCEV. If SignHint is
6700/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6701/// with a "cleaner" unsigned (resp. signed) representation.
6702const ConstantRange &ScalarEvolution::getRangeRef(
6703 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6704 DenseMap<const SCEV *, ConstantRange> &Cache =
6705 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6706 : SignedRanges;
6707 ConstantRange::PreferredRangeType RangeType =
6708 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6709 : ConstantRange::Signed;
6710
6711 // See if we've computed this range already.
6712 DenseMap<const SCEV *, ConstantRange>::iterator I = Cache.find(S);
6713 if (I != Cache.end())
6714 return I->second;
6715
6716 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6717 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6718
6719 // Switch to iteratively computing the range for S, if it is part of a deeply
6720 // nested expression.
6721 if (Depth > RangeIterThreshold)
6722 return getRangeRefIter(S, SignHint);
6723
6724 unsigned BitWidth = getTypeSizeInBits(S->getType());
6725 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6726 using OBO = OverflowingBinaryOperator;
6727
6728 // If the value has known zeros, the maximum value will have those known zeros
6729 // as well.
6730 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6731 APInt Multiple = getNonZeroConstantMultiple(S);
6732 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6733 if (!Remainder.isZero())
6734 ConservativeResult =
6735 ConstantRange(APInt::getMinValue(BitWidth),
6736 APInt::getMaxValue(BitWidth) - Remainder + 1);
6737 }
6738 else {
6739 uint32_t TZ = getMinTrailingZeros(S);
6740 if (TZ != 0) {
6741 ConservativeResult = ConstantRange(
6742 APInt::getSignedMinValue(BitWidth),
6743 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6744 }
6745 }
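  // Worked example for the unsigned case (illustrative): in i8, if S is known
  // to be a multiple of 16, then 255 urem 16 == 15 and the largest attainable
  // value is 240, so ConservativeResult becomes [0, 241).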
6746
6747 switch (S->getSCEVType()) {
6748 case scConstant:
6749 llvm_unreachable("Already handled above.");
6750 case scVScale:
6751 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6752 case scTruncate: {
6753 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6754 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6755 return setRange(
6756 Trunc, SignHint,
6757 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6758 }
6759 case scZeroExtend: {
6760 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6761 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6762 return setRange(
6763 ZExt, SignHint,
6764 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6765 }
6766 case scSignExtend: {
6767 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6768 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6769 return setRange(
6770 SExt, SignHint,
6771 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6772 }
6773 case scPtrToInt: {
6774 const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(S);
6775 ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1);
6776 return setRange(PtrToInt, SignHint, X);
6777 }
6778 case scAddExpr: {
6779 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6780 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6781 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6782 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6783 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6784 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6785 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6786 ConservativeResult =
6787 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6788 }
6789 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6790 unsigned WrapType = OBO::AnyWrap;
6791 if (Add->hasNoSignedWrap())
6792 WrapType |= OBO::NoSignedWrap;
6793 if (Add->hasNoUnsignedWrap())
6794 WrapType |= OBO::NoUnsignedWrap;
6795 for (const SCEV *Op : drop_begin(Add->operands()))
6796 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6797 RangeType);
6798 return setRange(Add, SignHint,
6799 ConservativeResult.intersectWith(X, RangeType));
6800 }
6801 case scMulExpr: {
6802 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6803 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6804 for (const SCEV *Op : drop_begin(Mul->operands()))
6805 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6806 return setRange(Mul, SignHint,
6807 ConservativeResult.intersectWith(X, RangeType));
6808 }
6809 case scUDivExpr: {
6810 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6811 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6812 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6813 return setRange(UDiv, SignHint,
6814 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6815 }
6816 case scAddRecExpr: {
6817 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6818 // If there's no unsigned wrap, the value will never be less than its
6819 // initial value.
6820 if (AddRec->hasNoUnsignedWrap()) {
6821 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6822 if (!UnsignedMinValue.isZero())
6823 ConservativeResult = ConservativeResult.intersectWith(
6824 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6825 }
6826
6827 // If there's no signed wrap, and all the operands except initial value have
6828 // the same sign or zero, the value won't ever be:
6829 // 1: smaller than initial value if operands are non negative,
6830 // 2: bigger than initial value if operands are non positive.
6831 // For both cases, value can not cross signed min/max boundary.
6832 if (AddRec->hasNoSignedWrap()) {
6833 bool AllNonNeg = true;
6834 bool AllNonPos = true;
6835 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6836 if (!isKnownNonNegative(AddRec->getOperand(i)))
6837 AllNonNeg = false;
6838 if (!isKnownNonPositive(AddRec->getOperand(i)))
6839 AllNonPos = false;
6840 }
6841 if (AllNonNeg)
6842 ConservativeResult = ConservativeResult.intersectWith(
6843 ConstantRange::getNonEmpty(getSignedRangeMin(AddRec->getStart()),
6844 APInt::getSignedMinValue(BitWidth)),
6845 RangeType);
6846 else if (AllNonPos)
6847 ConservativeResult = ConservativeResult.intersectWith(
6848 ConstantRange::getNonEmpty(APInt::getSignedMinValue(BitWidth),
6849 getSignedRangeMax(AddRec->getStart()) +
6850 1),
6851 RangeType);
6852 }
6853
6854 // TODO: non-affine addrec
6855 if (AddRec->isAffine()) {
6856 const SCEV *MaxBEScev =
6857 getConstantMaxBackedgeTakenCount(AddRec->getLoop());
6858 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6859 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6860
6861 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6862 // MaxBECount's active bits are all <= AddRec's bit width.
6863 if (MaxBECount.getBitWidth() > BitWidth &&
6864 MaxBECount.getActiveBits() <= BitWidth)
6865 MaxBECount = MaxBECount.trunc(BitWidth);
6866 else if (MaxBECount.getBitWidth() < BitWidth)
6867 MaxBECount = MaxBECount.zext(BitWidth);
6868
6869 if (MaxBECount.getBitWidth() == BitWidth) {
6870 auto RangeFromAffine = getRangeForAffineAR(
6871 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6872 ConservativeResult =
6873 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
6874
6875 auto RangeFromFactoring = getRangeViaFactoring(
6876 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
6877 ConservativeResult =
6878 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
6879 }
6880 }
6881
6882 // Now try symbolic BE count and more powerful methods.
6883 if (UseExpensiveRangeSharpening) {
6884 const SCEV *SymbolicMaxBECount =
6885 getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
6886 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
6887 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
6888 AddRec->hasNoSelfWrap()) {
6889 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
6890 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
6891 ConservativeResult =
6892 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
6893 }
6894 }
6895 }
6896
6897 return setRange(AddRec, SignHint, std::move(ConservativeResult));
6898 }
6899 case scUMaxExpr:
6900 case scSMaxExpr:
6901 case scUMinExpr:
6902 case scSMinExpr:
6903 case scSequentialUMinExpr: {
6904 Intrinsic::ID ID;
6905 switch (S->getSCEVType()) {
6906 case scUMaxExpr:
6907 ID = Intrinsic::umax;
6908 break;
6909 case scSMaxExpr:
6910 ID = Intrinsic::smax;
6911 break;
6912 case scUMinExpr:
6913 case scSequentialUMinExpr:
6914 ID = Intrinsic::umin;
6915 break;
6916 case scSMinExpr:
6917 ID = Intrinsic::smin;
6918 break;
6919 default:
6920 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
6921 }
6922
6923 const auto *NAry = cast<SCEVNAryExpr>(S);
6924 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
6925 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
6926 X = X.intrinsic(
6927 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
6928 return setRange(S, SignHint,
6929 ConservativeResult.intersectWith(X, RangeType));
6930 }
6931 case scUnknown: {
6932 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6933 Value *V = U->getValue();
6934
6935 // Check if the IR explicitly contains !range metadata.
6936 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
6937 if (MDRange)
6938 ConservativeResult =
6939 ConservativeResult.intersectWith(*MDRange, RangeType);
6940
6941 // Use facts about recurrences in the underlying IR. Note that add
6942 // recurrences are AddRecExprs and thus don't hit this path. This
6943 // primarily handles shift recurrences.
6944 auto CR = getRangeForUnknownRecurrence(U);
6945 ConservativeResult = ConservativeResult.intersectWith(CR);
6946
6947 // See if ValueTracking can give us a useful range.
6948 const DataLayout &DL = getDataLayout();
6949 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
6950 if (Known.getBitWidth() != BitWidth)
6951 Known = Known.zextOrTrunc(BitWidth);
6952
6953 // ValueTracking may be able to compute a tighter result for the number of
6954 // sign bits than for the value of those sign bits.
6955 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
6956 if (U->getType()->isPointerTy()) {
6957 // If the pointer size is larger than the index size type, this can cause
6958 // NS to be larger than BitWidth. So compensate for this.
6959 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
6960 int ptrIdxDiff = ptrSize - BitWidth;
6961 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
6962 NS -= ptrIdxDiff;
6963 }
6964
6965 if (NS > 1) {
6966 // If we know any of the sign bits, we know all of the sign bits.
6967 if (!Known.Zero.getHiBits(NS).isZero())
6968 Known.Zero.setHighBits(NS);
6969 if (!Known.One.getHiBits(NS).isZero())
6970 Known.One.setHighBits(NS);
6971 }
6972
6973 if (Known.getMinValue() != Known.getMaxValue() + 1)
6974 ConservativeResult = ConservativeResult.intersectWith(
6975 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
6976 RangeType);
6977 if (NS > 1)
6978 ConservativeResult = ConservativeResult.intersectWith(
6979 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
6980 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
6981 RangeType);
6982
6983 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
6984 // Strengthen the range if the underlying IR value is a
6985 // global/alloca/heap allocation using the size of the object.
6986 bool CanBeNull, CanBeFreed;
6987 uint64_t DerefBytes =
6988 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
6989 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
6990 // The highest address the object can start is DerefBytes bytes before
6991 // the end (unsigned max value). If this value is not a multiple of the
6992 // alignment, the last possible start value is the next lowest multiple
6993 // of the alignment. Note: The computations below cannot overflow,
6994 // because if they would there's no possible start address for the
6995 // object.
6996 APInt MaxVal =
6997 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
6998 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
6999 uint64_t Rem = MaxVal.urem(Align);
7000 MaxVal -= APInt(BitWidth, Rem);
7001 APInt MinVal = APInt::getZero(BitWidth);
7002 if (llvm::isKnownNonZero(V, DL))
7003 MinVal = Align;
7004 ConservativeResult = ConservativeResult.intersectWith(
7005 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7006 }
7007 }
7008
7009 // A range of Phi is a subset of union of all ranges of its input.
7010 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7011 // Make sure that we do not run over cycled Phis.
7012 if (PendingPhiRanges.insert(Phi).second) {
7013 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7014
7015 for (const auto &Op : Phi->operands()) {
7016 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7017 RangeFromOps = RangeFromOps.unionWith(OpRange);
7018 // No point to continue if we already have a full set.
7019 if (RangeFromOps.isFullSet())
7020 break;
7021 }
7022 ConservativeResult =
7023 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7024 bool Erased = PendingPhiRanges.erase(Phi);
7025 assert(Erased && "Failed to erase Phi properly?");
7026 (void)Erased;
7027 }
7028 }
7029
7030 // vscale can't be equal to zero
7031 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7032 if (II->getIntrinsicID() == Intrinsic::vscale) {
7033 ConstantRange Disallowed = APInt::getZero(BitWidth);
7034 ConservativeResult = ConservativeResult.difference(Disallowed);
7035 }
7036
7037 return setRange(U, SignHint, std::move(ConservativeResult));
7038 }
7039 case scCouldNotCompute:
7040 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7041 }
7042
7043 return setRange(S, SignHint, std::move(ConservativeResult));
7044}
7045
7046// Given a StartRange, Step and MaxBECount for an expression compute a range of
7047// values that the expression can take. Initially, the expression has a value
7048// from StartRange and then is changed by Step up to MaxBECount times. Signed
7049// argument defines if we treat Step as signed or unsigned.
7050static ConstantRange getRangeForAffineARHelper(APInt Step,
7051 const ConstantRange &StartRange,
7052 const APInt &MaxBECount,
7053 bool Signed) {
7054 unsigned BitWidth = Step.getBitWidth();
7055 assert(BitWidth == StartRange.getBitWidth() &&
7056 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7057 // If either Step or MaxBECount is 0, then the expression won't change, and we
7058 // just need to return the initial range.
7059 if (Step == 0 || MaxBECount == 0)
7060 return StartRange;
7061
7062 // If we don't know anything about the initial value (i.e. StartRange is
7063 // FullRange), then we don't know anything about the final range either.
7064 // Return FullRange.
7065 if (StartRange.isFullSet())
7066 return ConstantRange::getFull(BitWidth);
7067
7068 // If Step is signed and negative, then we use its absolute value, but we also
7069 // note that we're moving in the opposite direction.
7070 bool Descending = Signed && Step.isNegative();
7071
7072 if (Signed)
7073 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7074 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7075 // This equation holds true due to the well-defined wrap-around behavior of
7076 // APInt.
7077 Step = Step.abs();
7078
7079 // Check if Offset is more than full span of BitWidth. If it is, the
7080 // expression is guaranteed to overflow.
7081 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7082 return ConstantRange::getFull(BitWidth);
7083
7084 // Offset is by how much the expression can change. Checks above guarantee no
7085 // overflow here.
7086 APInt Offset = Step * MaxBECount;
7087
7088 // Minimum value of the final range will match the minimal value of StartRange
7089 // if the expression is increasing and will be decreased by Offset otherwise.
7090 // Maximum value of the final range will match the maximal value of StartRange
7091 // if the expression is decreasing and will be increased by Offset otherwise.
7092 APInt StartLower = StartRange.getLower();
7093 APInt StartUpper = StartRange.getUpper() - 1;
7094 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7095 : (StartUpper + std::move(Offset));
7096
7097 // It's possible that the new minimum/maximum value will fall into the initial
7098 // range (due to wrap around). This means that the expression can take any
7099 // value in this bitwidth, and we have to return full range.
7100 if (StartRange.contains(MovedBoundary))
7101 return ConstantRange::getFull(BitWidth);
7102
7103 APInt NewLower =
7104 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7105 APInt NewUpper =
7106 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7107 NewUpper += 1;
7108
7109 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7110 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7111}
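// Worked example (illustrative): Step = 3, StartRange = [0, 10), and
// MaxBECount = 5, treated as unsigned. The overflow check passes, Offset is
// 3 * 5 = 15, the moved boundary is 9 + 15 = 24, which is not contained in
// [0, 10), so the helper returns [0, 25).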
7112
7113ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7114 const SCEV *Step,
7115 const APInt &MaxBECount) {
7116 assert(getTypeSizeInBits(Start->getType()) ==
7117 getTypeSizeInBits(Step->getType()) &&
7118 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7119 "mismatched bit widths");
7120
7121 // First, consider step signed.
7122 ConstantRange StartSRange = getSignedRange(Start);
7123 ConstantRange StepSRange = getSignedRange(Step);
7124
7125 // If Step can be both positive and negative, we need to find ranges for the
7126 // maximum absolute step values in both directions and union them.
7127 ConstantRange SR = getRangeForAffineARHelper(
7128 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7129 SR = SR.unionWith(getRangeForAffineARHelper(StepSRange.getSignedMax(),
7130 StartSRange, MaxBECount,
7131 /* Signed = */ true));
7132
7133 // Next, consider step unsigned.
7134 ConstantRange UR = getRangeForAffineARHelper(
7135 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7136 /* Signed = */ false);
7137
7138 // Finally, intersect signed and unsigned ranges.
7139 return SR.intersectWith(UR, ConstantRange::Smallest);
7140}
7141
7142ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7143 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7144 ScalarEvolution::RangeSignHint SignHint) {
7145 assert(AddRec->isAffine() && "Non-affine AddRecs are not supported!\n");
7146 assert(AddRec->hasNoSelfWrap() &&
7147 "This only works for non-self-wrapping AddRecs!");
7148 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7149 const SCEV *Step = AddRec->getStepRecurrence(*this);
7150 // Only deal with constant step to save compile time.
7151 if (!isa<SCEVConstant>(Step))
7152 return ConstantRange::getFull(BitWidth);
7153 // Let's make sure that we can prove that we do not self-wrap during
7154 // MaxBECount iterations. We need this because MaxBECount is a maximum
7155 // iteration count estimate, and we might infer nw from some exit for which we
7156 // do not know max exit count (or any other side reasoning).
7157 // TODO: Turn into assert at some point.
7158 if (getTypeSizeInBits(MaxBECount->getType()) >
7159 getTypeSizeInBits(AddRec->getType()))
7160 return ConstantRange::getFull(BitWidth);
7161 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7162 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7163 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7164 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7165 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7166 MaxItersWithoutWrap))
7167 return ConstantRange::getFull(BitWidth);
7168
7169 ICmpInst::Predicate LEPred =
7170 IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
7171 ICmpInst::Predicate GEPred =
7172 IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
7173 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7174
7175 // We know that there is no self-wrap. Let's take Start and End values and
7176 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7177 // the iteration. They either lie inside the range [Min(Start, End),
7178 // Max(Start, End)] or outside it:
7179 //
7180 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7181 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7182 //
7183 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7184 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7185 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7186 // Start <= End and step is positive, or Start >= End and step is negative.
7187 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7188 ConstantRange StartRange = getRangeRef(Start, SignHint);
7189 ConstantRange EndRange = getRangeRef(End, SignHint);
7190 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7191 // If they already cover full iteration space, we will know nothing useful
7192 // even if we prove what we want to prove.
7193 if (RangeBetween.isFullSet())
7194 return RangeBetween;
7195 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7196 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7197 : RangeBetween.isWrappedSet();
7198 if (IsWrappedSet)
7199 return ConstantRange::getFull(BitWidth);
7200
7201 if (isKnownPositive(Step) &&
7202 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7203 return RangeBetween;
7204 if (isKnownNegative(Step) &&
7205 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7206 return RangeBetween;
7207 return ConstantRange::getFull(BitWidth);
7208}
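// Illustrative sketch: for {%n,+,1}<nw> with symbolic max backedge-taken
// count %m, End is %n + %m. If Start u<= End can be shown via constant ranges
// (Case 1 above), the returned range is the hull of range(Start) and
// range(End) rather than the full set.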
7209
7210ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7211 const SCEV *Step,
7212 const APInt &MaxBECount) {
7213 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7214 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7215
7216 unsigned BitWidth = MaxBECount.getBitWidth();
7217 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7218 getTypeSizeInBits(Step->getType()) == BitWidth &&
7219 "mismatched bit widths");
7220
7221 struct SelectPattern {
7222 Value *Condition = nullptr;
7223 APInt TrueValue;
7224 APInt FalseValue;
7225
7226 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7227 const SCEV *S) {
7228 std::optional<unsigned> CastOp;
7229 APInt Offset(BitWidth, 0);
7230
7232 "Should be!");
7233
7234 // Peel off a constant offset. In the future we could consider being
7235 // smarter here and handle {Start+Step,+,Step} too.
7236 const APInt *Off;
7237 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7238 Offset = *Off;
7239
7240 // Peel off a cast operation
7241 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7242 CastOp = SCast->getSCEVType();
7243 S = SCast->getOperand();
7244 }
7245
7246 using namespace llvm::PatternMatch;
7247
7248 auto *SU = dyn_cast<SCEVUnknown>(S);
7249 const APInt *TrueVal, *FalseVal;
7250 if (!SU ||
7251 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7252 m_APInt(FalseVal)))) {
7253 Condition = nullptr;
7254 return;
7255 }
7256
7257 TrueValue = *TrueVal;
7258 FalseValue = *FalseVal;
7259
7260 // Re-apply the cast we peeled off earlier
7261 if (CastOp)
7262 switch (*CastOp) {
7263 default:
7264 llvm_unreachable("Unknown SCEV cast type!");
7265
7266 case scTruncate:
7267 TrueValue = TrueValue.trunc(BitWidth);
7268 FalseValue = FalseValue.trunc(BitWidth);
7269 break;
7270 case scZeroExtend:
7271 TrueValue = TrueValue.zext(BitWidth);
7272 FalseValue = FalseValue.zext(BitWidth);
7273 break;
7274 case scSignExtend:
7275 TrueValue = TrueValue.sext(BitWidth);
7276 FalseValue = FalseValue.sext(BitWidth);
7277 break;
7278 }
7279
7280 // Re-apply the constant offset we peeled off earlier
7281 TrueValue += Offset;
7282 FalseValue += Offset;
7283 }
7284
7285 bool isRecognized() { return Condition != nullptr; }
7286 };
7287
7288 SelectPattern StartPattern(*this, BitWidth, Start);
7289 if (!StartPattern.isRecognized())
7290 return ConstantRange::getFull(BitWidth);
7291
7292 SelectPattern StepPattern(*this, BitWidth, Step);
7293 if (!StepPattern.isRecognized())
7294 return ConstantRange::getFull(BitWidth);
7295
7296 if (StartPattern.Condition != StepPattern.Condition) {
7297 // We don't handle this case today; but we could, by considering four
7298 // possibilities below instead of two. I'm not sure if there are cases where
7299 // that will help over what getRange already does, though.
7300 return ConstantRange::getFull(BitWidth);
7301 }
7302
7303 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7304 // construct arbitrary general SCEV expressions here. This function is called
7305 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7306 // say) can end up caching a suboptimal value.
7307
7308 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7309 // C2352 and C2512 (otherwise it isn't needed).
7310
7311 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7312 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7313 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7314 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7315
7316 ConstantRange TrueRange =
7317 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7318 ConstantRange FalseRange =
7319 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7320
7321 return TrueRange.unionWith(FalseRange);
7322}
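// Worked example (illustrative, i32, MaxBECount = 10):
//   Start = %s, %s = select i1 %c, i32 0, i32 100
//   Step  = %t, %t = select i1 %c, i32 1, i32 -1
// factors into {0,+,1} and {100,+,-1}, whose ranges [0, 11) and [90, 101)
// union to [0, 101), instead of giving up on the select-shaped operands.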
7323
7324SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7325 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7326 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7327
7328 // Return early if there are no flags to propagate to the SCEV.
7329 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7330 if (BinOp->hasNoUnsignedWrap())
7331 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
7332 if (BinOp->hasNoSignedWrap())
7333 Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
7334 if (Flags == SCEV::FlagAnyWrap)
7335 return SCEV::FlagAnyWrap;
7336
7337 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7338}
7339
7340const Instruction *
7341ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7342 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7343 return &*AddRec->getLoop()->getHeader()->begin();
7344 if (auto *U = dyn_cast<SCEVUnknown>(S))
7345 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7346 return I;
7347 return nullptr;
7348}
7349
7350const Instruction *
7351ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops,
7352 bool &Precise) {
7353 Precise = true;
7354 // Do a bounded search of the def relation of the requested SCEVs.
7355 SmallPtrSet<const SCEV *, 16> Visited;
7356 SmallVector<const SCEV *> Worklist;
7357 auto pushOp = [&](const SCEV *S) {
7358 if (!Visited.insert(S).second)
7359 return;
7360 // Threshold of 30 here is arbitrary.
7361 if (Visited.size() > 30) {
7362 Precise = false;
7363 return;
7364 }
7365 Worklist.push_back(S);
7366 };
7367
7368 for (const auto *S : Ops)
7369 pushOp(S);
7370
7371 const Instruction *Bound = nullptr;
7372 while (!Worklist.empty()) {
7373 auto *S = Worklist.pop_back_val();
7374 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7375 if (!Bound || DT.dominates(Bound, DefI))
7376 Bound = DefI;
7377 } else {
7378 for (const auto *Op : S->operands())
7379 pushOp(Op);
7380 }
7381 }
7382 return Bound ? Bound : &*F.getEntryBlock().begin();
7383}
7384
7385const Instruction *
7386ScalarEvolution::getDefiningScopeBound(ArrayRef<const SCEV *> Ops) {
7387 bool Discard;
7388 return getDefiningScopeBound(Ops, Discard);
7389}
7390
7391bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7392 const Instruction *B) {
7393 if (A->getParent() == B->getParent() &&
7394 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7395 B->getIterator()))
7396 return true;
7397
7398 auto *BLoop = LI.getLoopFor(B->getParent());
7399 if (BLoop && BLoop->getHeader() == B->getParent() &&
7400 BLoop->getLoopPreheader() == A->getParent() &&
7401 isGuaranteedToTransferExecutionToSuccessor(A->getIterator(),
7402 A->getParent()->end()) &&
7403 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7404 B->getIterator()))
7405 return true;
7406 return false;
7407}
7408
7409bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7410 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7411 visitAll(Op, PC);
7412 return PC.MaybePoison.empty();
7413}
7414
7415bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7416 return !SCEVExprContains(Op, [this](const SCEV *S) {
7417 const SCEV *Op1;
7418 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7419 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7420 // is a non-zero constant, we have to assume the UDiv may be UB.
7421 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7422 });
7423}
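// For example, (%a /u %b) may cause UB when %b can be zero or poison, so the
// predicate above fires and the function returns false; (%a /u 8) has a
// non-zero constant divisor and is therefore fine.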
7424
7425bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7426 // Only proceed if we can prove that I does not yield poison.
7427 if (!programUndefinedIfPoison(I))
7428 return false;
7429
7430 // At this point we know that if I is executed, then it does not wrap
7431 // according to at least one of NSW or NUW. If I is not executed, then we do
7432 // not know if the calculation that I represents would wrap. Multiple
7433 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7434 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7435 // derived from other instructions that map to the same SCEV. We cannot make
7436 // that guarantee for cases where I is not executed. So we need to find a
7437 // upper bound on the defining scope for the SCEV, and prove that I is
7438 // executed every time we enter that scope. When the bounding scope is a
7439 // loop (the common case), this is equivalent to proving I executes on every
7440 // iteration of that loop.
7441 SmallVector<const SCEV *> SCEVOps;
7442 for (const Use &Op : I->operands()) {
7443 // I could be an extractvalue from a call to an overflow intrinsic.
7444 // TODO: We can do better here in some cases.
7445 if (isSCEVable(Op->getType()))
7446 SCEVOps.push_back(getSCEV(Op));
7447 }
7448 auto *DefI = getDefiningScopeBound(SCEVOps);
7449 return isGuaranteedToTransferExecutionTo(DefI, I);
7450}
7451
7452bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7453 // If we know that \c I can never be poison period, then that's enough.
7454 if (isSCEVExprNeverPoison(I))
7455 return true;
7456
7457 // If the loop only has one exit, then we know that, if the loop is entered,
7458 // any instruction dominating that exit will be executed. If any such
7459 // instruction would result in UB, the addrec cannot be poison.
7460 //
7461 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7462 // also handles uses outside the loop header (they just need to dominate the
7463 // single exit).
7464
7465 auto *ExitingBB = L->getExitingBlock();
7466 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7467 return false;
7468
7469 SmallPtrSet<const Value *, 16> KnownPoison;
7470 SmallVector<const Instruction *, 8> Worklist;
7471
7472 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7473 // things that are known to be poison under that assumption go on the
7474 // Worklist.
7475 KnownPoison.insert(I);
7476 Worklist.push_back(I);
7477
7478 while (!Worklist.empty()) {
7479 const Instruction *Poison = Worklist.pop_back_val();
7480
7481 for (const Use &U : Poison->uses()) {
7482 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7483 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7484 DT.dominates(PoisonUser->getParent(), ExitingBB))
7485 return true;
7486
7487 if (propagatesPoison(U) && L->contains(PoisonUser))
7488 if (KnownPoison.insert(PoisonUser).second)
7489 Worklist.push_back(PoisonUser);
7490 }
7491 }
7492
7493 return false;
7494}
7495
7496ScalarEvolution::LoopProperties
7497ScalarEvolution::getLoopProperties(const Loop *L) {
7498 using LoopProperties = ScalarEvolution::LoopProperties;
7499
7500 auto Itr = LoopPropertiesCache.find(L);
7501 if (Itr == LoopPropertiesCache.end()) {
7502 auto HasSideEffects = [](Instruction *I) {
7503 if (auto *SI = dyn_cast<StoreInst>(I))
7504 return !SI->isSimple();
7505
7506 if (I->mayThrow())
7507 return true;
7508
7509 // Non-volatile memset / memcpy do not count as side-effect for forward
7510 // progress.
7511 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7512 return false;
7513
7514 return I->mayWriteToMemory();
7515 };
7516
7517 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7518 /*HasNoSideEffects*/ true};
7519
7520 for (auto *BB : L->getBlocks())
7521 for (auto &I : *BB) {
7522 if (!isGuaranteedToTransferExecutionToSuccessor(&I))
7523 LP.HasNoAbnormalExits = false;
7524 if (HasSideEffects(&I))
7525 LP.HasNoSideEffects = false;
7526 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7527 break; // We're already as pessimistic as we can get.
7528 }
7529
7530 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7531 assert(InsertPair.second && "We just checked!");
7532 Itr = InsertPair.first;
7533 }
7534
7535 return Itr->second;
7536}
7537
7538bool ScalarEvolution::loopIsFiniteByAssumption(const Loop *L) {
7539 // A mustprogress loop without side effects must be finite.
7540 // TODO: The check used here is very conservative. It's only *specific*
7541 // side effects which are well defined in infinite loops.
7542 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7543}
7544
7545const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7546 // Worklist item with a Value and a bool indicating whether all operands have
7547 // been visited already.
7548 using PointerTy = PointerIntPair<Value *, 1, bool>;
7549 SmallVector<PointerTy> Stack;
7550
7551 Stack.emplace_back(V, true);
7552 Stack.emplace_back(V, false);
7553 while (!Stack.empty()) {
7554 auto E = Stack.pop_back_val();
7555 Value *CurV = E.getPointer();
7556
7557 if (getExistingSCEV(CurV))
7558 continue;
7559
7560 SmallVector<Value *> Ops;
7561 const SCEV *CreatedSCEV = nullptr;
7562 // If all operands have been visited already, create the SCEV.
7563 if (E.getInt()) {
7564 CreatedSCEV = createSCEV(CurV);
7565 } else {
7566 // Otherwise get the operands we need to create SCEV's for before creating
7567 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7568 // just use it.
7569 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7570 }
7571
7572 if (CreatedSCEV) {
7573 insertValueToMap(CurV, CreatedSCEV);
7574 } else {
7575 // Queue CurV for SCEV creation, followed by its operands which need to
7576 // be constructed first.
7577 Stack.emplace_back(CurV, true);
7578 for (Value *Op : Ops)
7579 Stack.emplace_back(Op, false);
7580 }
7581 }
7582
7583 return getExistingSCEV(V);
7584}
7585
7586const SCEV *
7587ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7588 if (!isSCEVable(V->getType()))
7589 return getUnknown(V);
7590
7591 if (Instruction *I = dyn_cast<Instruction>(V)) {
7592 // Don't attempt to analyze instructions in blocks that aren't
7593 // reachable. Such instructions don't matter, and they aren't required
7594 // to obey basic rules for definitions dominating uses which this
7595 // analysis depends on.
7596 if (!DT.isReachableFromEntry(I->getParent()))
7597 return getUnknown(PoisonValue::get(V->getType()));
7598 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7599 return getConstant(CI);
7600 else if (isa<GlobalAlias>(V))
7601 return getUnknown(V);
7602 else if (!isa<ConstantExpr>(V))
7603 return getUnknown(V);
7604
7605 Operator *U = cast<Operator>(V);
7606 if (auto BO =
7607 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7608 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7609 switch (BO->Opcode) {
7610 case Instruction::Add:
7611 case Instruction::Mul: {
7612 // For additions and multiplications, traverse add/mul chains for which we
7613 // can potentially create a single SCEV, to reduce the number of
7614 // get{Add,Mul}Expr calls.
7615 do {
7616 if (BO->Op) {
7617 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7618 Ops.push_back(BO->Op);
7619 break;
7620 }
7621 }
7622 Ops.push_back(BO->RHS);
7623 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7624 dyn_cast<Instruction>(V));
7625 if (!NewBO ||
7626 (BO->Opcode == Instruction::Add &&
7627 (NewBO->Opcode != Instruction::Add &&
7628 NewBO->Opcode != Instruction::Sub)) ||
7629 (BO->Opcode == Instruction::Mul &&
7630 NewBO->Opcode != Instruction::Mul)) {
7631 Ops.push_back(BO->LHS);
7632 break;
7633 }
7634 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7635 // requires a SCEV for the LHS.
7636 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7637 auto *I = dyn_cast<Instruction>(BO->Op);
7638 if (I && programUndefinedIfPoison(I)) {
7639 Ops.push_back(BO->LHS);
7640 break;
7641 }
7642 }
7643 BO = NewBO;
7644 } while (true);
7645 return nullptr;
7646 }
7647 case Instruction::Sub:
7648 case Instruction::UDiv:
7649 case Instruction::URem:
7650 break;
7651 case Instruction::AShr:
7652 case Instruction::Shl:
7653 case Instruction::Xor:
7654 if (!IsConstArg)
7655 return nullptr;
7656 break;
7657 case Instruction::And:
7658 case Instruction::Or:
7659 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7660 return nullptr;
7661 break;
7662 case Instruction::LShr:
7663 return getUnknown(V);
7664 default:
7665 llvm_unreachable("Unhandled binop");
7666 break;
7667 }
7668
7669 Ops.push_back(BO->LHS);
7670 Ops.push_back(BO->RHS);
7671 return nullptr;
7672 }
7673
7674 switch (U->getOpcode()) {
7675 case Instruction::Trunc:
7676 case Instruction::ZExt:
7677 case Instruction::SExt:
7678 case Instruction::PtrToInt:
7679 Ops.push_back(U->getOperand(0));
7680 return nullptr;
7681
7682 case Instruction::BitCast:
7683 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7684 Ops.push_back(U->getOperand(0));
7685 return nullptr;
7686 }
7687 return getUnknown(V);
7688
7689 case Instruction::SDiv:
7690 case Instruction::SRem:
7691 Ops.push_back(U->getOperand(0));
7692 Ops.push_back(U->getOperand(1));
7693 return nullptr;
7694
7695 case Instruction::GetElementPtr:
7696 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7697 "GEP source element type must be sized");
7698 llvm::append_range(Ops, U->operands());
7699 return nullptr;
7700
7701 case Instruction::IntToPtr:
7702 return getUnknown(V);
7703
7704 case Instruction::PHI:
7705 // Keep constructing SCEVs for phis recursively for now.
7706 return nullptr;
7707
7708 case Instruction::Select: {
7709 // Check if U is a select that can be simplified to a SCEVUnknown.
7710 auto CanSimplifyToUnknown = [this, U]() {
7711 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7712 return false;
7713
7714 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7715 if (!ICI)
7716 return false;
7717 Value *LHS = ICI->getOperand(0);
7718 Value *RHS = ICI->getOperand(1);
7719 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7720 ICI->getPredicate() == CmpInst::ICMP_NE) {
7721 if (!(isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero()))
7722 return true;
7723 } else if (getTypeSizeInBits(LHS->getType()) >
7724 getTypeSizeInBits(U->getType()))
7725 return true;
7726 return false;
7727 };
7728 if (CanSimplifyToUnknown())
7729 return getUnknown(U);
7730
7731 llvm::append_range(Ops, U->operands());
7732 return nullptr;
7733 break;
7734 }
7735 case Instruction::Call:
7736 case Instruction::Invoke:
7737 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7738 Ops.push_back(RV);
7739 return nullptr;
7740 }
7741
7742 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7743 switch (II->getIntrinsicID()) {
7744 case Intrinsic::abs:
7745 Ops.push_back(II->getArgOperand(0));
7746 return nullptr;
7747 case Intrinsic::umax:
7748 case Intrinsic::umin:
7749 case Intrinsic::smax:
7750 case Intrinsic::smin:
7751 case Intrinsic::usub_sat:
7752 case Intrinsic::uadd_sat:
7753 Ops.push_back(II->getArgOperand(0));
7754 Ops.push_back(II->getArgOperand(1));
7755 return nullptr;
7756 case Intrinsic::start_loop_iterations:
7757 case Intrinsic::annotation:
7758 case Intrinsic::ptr_annotation:
7759 Ops.push_back(II->getArgOperand(0));
7760 return nullptr;
7761 default:
7762 break;
7763 }
7764 }
7765 break;
7766 }
7767
7768 return nullptr;
7769}
7770
7771const SCEV *ScalarEvolution::createSCEV(Value *V) {
7772 if (!isSCEVable(V->getType()))
7773 return getUnknown(V);
7774
7775 if (Instruction *I = dyn_cast<Instruction>(V)) {
7776 // Don't attempt to analyze instructions in blocks that aren't
7777 // reachable. Such instructions don't matter, and they aren't required
7778 // to obey basic rules for definitions dominating uses which this
7779 // analysis depends on.
7780 if (!DT.isReachableFromEntry(I->getParent()))
7781 return getUnknown(PoisonValue::get(V->getType()));
7782 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7783 return getConstant(CI);
7784 else if (isa<GlobalAlias>(V))
7785 return getUnknown(V);
7786 else if (!isa<ConstantExpr>(V))
7787 return getUnknown(V);
7788
7789 const SCEV *LHS;
7790 const SCEV *RHS;
7791
7792 Operator *U = cast<Operator>(V);
7793 if (auto BO =
7794 MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
7795 switch (BO->Opcode) {
7796 case Instruction::Add: {
7797 // The simple thing to do would be to just call getSCEV on both operands
7798 // and call getAddExpr with the result. However if we're looking at a
7799 // bunch of things all added together, this can be quite inefficient,
7800 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7801 // Instead, gather up all the operands and make a single getAddExpr call.
7802 // LLVM IR canonical form means we need only traverse the left operands.
7803 SmallVector<const SCEV *, 4> AddOps;
7804 do {
7805 if (BO->Op) {
7806 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7807 AddOps.push_back(OpSCEV);
7808 break;
7809 }
7810
7811 // If a NUW or NSW flag can be applied to the SCEV for this
7812 // addition, then compute the SCEV for this addition by itself
7813 // with a separate call to getAddExpr. We need to do that
7814 // instead of pushing the operands of the addition onto AddOps,
7815 // since the flags are only known to apply to this particular
7816 // addition - they may not apply to other additions that can be
7817 // formed with operands from AddOps.
7818 const SCEV *RHS = getSCEV(BO->RHS);
7819 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7820 if (Flags != SCEV::FlagAnyWrap) {
7821 const SCEV *LHS = getSCEV(BO->LHS);
7822 if (BO->Opcode == Instruction::Sub)
7823 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
7824 else
7825 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
7826 break;
7827 }
7828 }
7829
7830 if (BO->Opcode == Instruction::Sub)
7831 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
7832 else
7833 AddOps.push_back(getSCEV(BO->RHS));
7834
7835 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7836 dyn_cast<Instruction>(V));
7837 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
7838 NewBO->Opcode != Instruction::Sub)) {
7839 AddOps.push_back(getSCEV(BO->LHS));
7840 break;
7841 }
7842 BO = NewBO;
7843 } while (true);
7844
7845 return getAddExpr(AddOps);
7846 }
7847
7848 case Instruction::Mul: {
7849 SmallVector<const SCEV *, 4> MulOps;
7850 do {
7851 if (BO->Op) {
7852 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7853 MulOps.push_back(OpSCEV);
7854 break;
7855 }
7856
7857 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
7858 if (Flags != SCEV::FlagAnyWrap) {
7859 LHS = getSCEV(BO->LHS);
7860 RHS = getSCEV(BO->RHS);
7861 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
7862 break;
7863 }
7864 }
7865
7866 MulOps.push_back(getSCEV(BO->RHS));
7867 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7868 dyn_cast<Instruction>(V));
7869 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
7870 MulOps.push_back(getSCEV(BO->LHS));
7871 break;
7872 }
7873 BO = NewBO;
7874 } while (true);
7875
7876 return getMulExpr(MulOps);
7877 }
7878 case Instruction::UDiv:
7879 LHS = getSCEV(BO->LHS);
7880 RHS = getSCEV(BO->RHS);
7881 return getUDivExpr(LHS, RHS);
7882 case Instruction::URem:
7883 LHS = getSCEV(BO->LHS);
7884 RHS = getSCEV(BO->RHS);
7885 return getURemExpr(LHS, RHS);
7886 case Instruction::Sub: {
7887 SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
7888 if (BO->Op)
7889 Flags = getNoWrapFlagsFromUB(BO->Op);
7890 LHS = getSCEV(BO->LHS);
7891 RHS = getSCEV(BO->RHS);
7892 return getMinusSCEV(LHS, RHS, Flags);
7893 }
7894 case Instruction::And:
7895 // For an expression like x&255 that merely masks off the high bits,
7896 // use zext(trunc(x)) as the SCEV expression.
7897 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7898 if (CI->isZero())
7899 return getSCEV(BO->RHS);
7900 if (CI->isMinusOne())
7901 return getSCEV(BO->LHS);
7902 const APInt &A = CI->getValue();
7903
7904 // Instcombine's ShrinkDemandedConstant may strip bits out of
7905 // constants, obscuring what would otherwise be a low-bits mask.
7906 // Use computeKnownBits to compute what ShrinkDemandedConstant
7907 // knew about to reconstruct a low-bits mask value.
7908 unsigned LZ = A.countl_zero();
7909 unsigned TZ = A.countr_zero();
7910 unsigned BitWidth = A.getBitWidth();
7911 KnownBits Known(BitWidth);
7912 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
7913
7914 APInt EffectiveMask =
7915 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
7916 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
7917 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
7918 const SCEV *LHS = getSCEV(BO->LHS);
7919 const SCEV *ShiftedLHS = nullptr;
7920 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
7921 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
7922 // For an expression like (x * 8) & 8, simplify the multiply.
7923 unsigned MulZeros = OpC->getAPInt().countr_zero();
7924 unsigned GCD = std::min(MulZeros, TZ);
7925 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
7926 SmallVector<const SCEV *, 4> MulOps;
7927 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
7928 append_range(MulOps, LHSMul->operands().drop_front());
7929 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
7930 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
7931 }
7932 }
7933 if (!ShiftedLHS)
7934 ShiftedLHS = getUDivExpr(LHS, MulCount);
7935 return getMulExpr(
7936 getZeroExtendExpr(
7937 getTruncateExpr(ShiftedLHS,
7938 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
7939 BO->LHS->getType()),
7940 MulCount);
7941 }
7942 }
7943 // Binary `and` is a bit-wise `umin`.
7944 if (BO->LHS->getType()->isIntegerTy(1)) {
7945 LHS = getSCEV(BO->LHS);
7946 RHS = getSCEV(BO->RHS);
7947 return getUMinExpr(LHS, RHS);
7948 }
7949 break;
7950
7951 case Instruction::Or:
7952 // Binary `or` is a bit-wise `umax`.
7953 if (BO->LHS->getType()->isIntegerTy(1)) {
7954 LHS = getSCEV(BO->LHS);
7955 RHS = getSCEV(BO->RHS);
7956 return getUMaxExpr(LHS, RHS);
7957 }
7958 break;
7959
7960 case Instruction::Xor:
7961 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
7962 // If the RHS of xor is -1, then this is a not operation.
7963 if (CI->isMinusOne())
7964 return getNotSCEV(getSCEV(BO->LHS));
7965
7966 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
7967 // This is a variant of the check for xor with -1, and it handles
7968 // the case where instcombine has trimmed non-demanded bits out
7969 // of an xor with -1.
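// For example, (x & 7) ^ 7 flips exactly the low three bits, which is the
// same as (~x) & 7.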
7970 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
7971 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
7972 if (LBO->getOpcode() == Instruction::And &&
7973 LCI->getValue() == CI->getValue())
7974 if (const SCEVZeroExtendExpr *Z =
7975 dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) {
7976 Type *UTy = BO->LHS->getType();
7977 const SCEV *Z0 = Z->getOperand();
7978 Type *Z0Ty = Z0->getType();
7979 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
7980
7981 // If C is a low-bits mask, the zero extend is serving to
7982 // mask off the high bits. Complement the operand and
7983 // re-apply the zext.
7984 if (CI->getValue().isMask(Z0TySize))
7985 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
7986
7987 // If C is a single bit, it may be in the sign-bit position
7988 // before the zero-extend. In this case, represent the xor
7989 // using an add, which is equivalent, and re-apply the zext.
7990 APInt Trunc = CI->getValue().trunc(Z0TySize);
7991 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
7992 Trunc.isSignMask())
7993 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
7994 UTy);
7995 }
7996 }
7997 break;
7998
7999 case Instruction::Shl:
8000 // Turn shift left of a constant amount into a multiply.
8001 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8002 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8003
8004 // If the shift count is not less than the bitwidth, the result of
8005 // the shift is undefined. Don't try to analyze it, because the
8006 // resolution chosen here may differ from the resolution chosen in
8007 // other parts of the compiler.
8008 if (SA->getValue().uge(BitWidth))
8009 break;
8010
8011 // We can safely preserve the nuw flag in all cases. It's also safe to
8012 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8013 // requires special handling. It can be preserved as long as we're not
8014 // left shifting by bitwidth - 1.
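// For example, in i8 "shl nsw %x, 7" with %x = -1 yields -128 without
// signed overflow, but the corresponding "mul nsw %x, -128" (1 << 7 is
// -128 in i8) would overflow, so nsw cannot be carried over in that case.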
8015 auto Flags = SCEV::FlagAnyWrap;
8016 if (BO->Op) {
8017 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8018 if ((MulFlags & SCEV::FlagNSW) &&
8019 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8020 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNSW);
8021 if (MulFlags & SCEV::FlagNUW)
8022 Flags = (SCEV::NoWrapFlags)(Flags | SCEV::FlagNUW);
8023 }
8024
8025 ConstantInt *X = ConstantInt::get(
8026 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8027 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8028 }
8029 break;
8030
8031 case Instruction::AShr:
8032 // AShr X, C, where C is a constant.
8033 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8034 if (!CI)
8035 break;
8036
8037 Type *OuterTy = BO->LHS->getType();
8038 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8039 // If the shift count is not less than the bitwidth, the result of
8040 // the shift is undefined. Don't try to analyze it, because the
8041 // resolution chosen here may differ from the resolution chosen in
8042 // other parts of the compiler.
8043 if (CI->getValue().uge(BitWidth))
8044 break;
8045
8046 if (CI->isZero())
8047 return getSCEV(BO->LHS); // shift by zero --> noop
8048
8049 uint64_t AShrAmt = CI->getZExtValue();
8050 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8051
8052 Operator *L = dyn_cast<Operator>(BO->LHS);
8053 const SCEV *AddTruncateExpr = nullptr;
8054 ConstantInt *ShlAmtCI = nullptr;
8055 const SCEV *AddConstant = nullptr;
8056
8057 if (L && L->getOpcode() == Instruction::Add) {
8058 // X = Shl A, n
8059 // Y = Add X, c
8060 // Z = AShr Y, m
8061 // n, c and m are constants.
8062
8063 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8064 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8065 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8066 if (AddOperandCI) {
8067 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8068 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8069 // Since we truncate to TruncTy, the AddConstant should be of the
8070 // same type, so create a new constant whose type is TruncTy.
8071 // Also, the add constant should be shifted right by the AShr amount.
8072 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8073 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8074 // We model the expression as sext(add(trunc(A), c << n)). Since the
8075 // sext(trunc) part is already handled below, we create an
8076 // AddExpr(TruncExp) which will be used later.
8077 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8078 }
8079 }
8080 } else if (L && L->getOpcode() == Instruction::Shl) {
8081 // X = Shl A, n
8082 // Y = AShr X, m
8083 // Both n and m are constant.
8084
8085 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8086 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8087 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8088 }
8089
8090 if (AddTruncateExpr && ShlAmtCI) {
8091 // We can merge the two given cases into a single SCEV statement;
8092 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8093 // a simpler case. The following code handles the two cases:
8094 //
8095 // 1) For a two-shift sext-inreg, i.e. n = m,
8096 // use sext(trunc(x)) as the SCEV expression.
8097 //
8098 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8099 // expression. We already checked that ShlAmt < BitWidth, so
8100 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8101 // ShlAmt - AShrAmt < Amt.
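// For example, on i32 with n = m = 24 this is the usual sext_inreg idiom
// and becomes sext(trunc(A)) with TruncTy = i8; with n = 27, m = 24 it
// becomes sext(mul(trunc(A), 8)), where 8 = 2^(27 - 24).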
8102 const APInt &ShlAmt = ShlAmtCI->getValue();
8103 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8104 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8105 ShlAmtCI->getZExtValue() - AShrAmt);
8106 const SCEV *CompositeExpr =
8107 getMulExpr(AddTruncateExpr, getConstant(Mul));
8108 if (L->getOpcode() != Instruction::Shl)
8109 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8110
8111 return getSignExtendExpr(CompositeExpr, OuterTy);
8112 }
8113 }
8114 break;
8115 }
8116 }
8117
8118 switch (U->getOpcode()) {
8119 case Instruction::Trunc:
8120 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8121
8122 case Instruction::ZExt:
8123 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8124
8125 case Instruction::SExt:
8126 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8127 dyn_cast<Instruction>(V))) {
8128 // The NSW flag of a subtract does not always survive the conversion to
8129 // A + (-1)*B. By pushing sign extension onto its operands we are much
8130 // more likely to preserve NSW and allow later AddRec optimisations.
8131 //
8132 // NOTE: This is effectively duplicating this logic from getSignExtend:
8133 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8134 // but by that point the NSW information has potentially been lost.
8135 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8136 Type *Ty = U->getType();
8137 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8138 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8139 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8140 }
8141 }
8142 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8143
8144 case Instruction::BitCast:
8145 // BitCasts are no-op casts so we just eliminate the cast.
8146 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8147 return getSCEV(U->getOperand(0));
8148 break;
8149
8150 case Instruction::PtrToInt: {
8151 // Pointer to integer cast is straight-forward, so do model it.
8152 const SCEV *Op = getSCEV(U->getOperand(0));
8153 Type *DstIntTy = U->getType();
8154 // But only if effective SCEV (integer) type is wide enough to represent
8155 // all possible pointer values.
8156 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8157 if (isa<SCEVCouldNotCompute>(IntOp))
8158 return getUnknown(V);
8159 return IntOp;
8160 }
8161 case Instruction::IntToPtr:
8162 // Just don't deal with inttoptr casts.
8163 return getUnknown(V);
8164
8165 case Instruction::SDiv:
8166 // If both operands are non-negative, this is just an udiv.
8167 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8168 isKnownNonNegative(getSCEV(U->getOperand(1))))
8169 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8170 break;
8171
8172 case Instruction::SRem:
8173 // If both operands are non-negative, this is just an urem.
8174 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8175 isKnownNonNegative(getSCEV(U->getOperand(1))))
8176 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8177 break;
8178
8179 case Instruction::GetElementPtr:
8180 return createNodeForGEP(cast<GEPOperator>(U));
8181
8182 case Instruction::PHI:
8183 return createNodeForPHI(cast<PHINode>(U));
8184
8185 case Instruction::Select:
8186 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8187 U->getOperand(2));
8188
8189 case Instruction::Call:
8190 case Instruction::Invoke:
8191 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8192 return getSCEV(RV);
8193
8194 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8195 switch (II->getIntrinsicID()) {
8196 case Intrinsic::abs:
8197 return getAbsExpr(
8198 getSCEV(II->getArgOperand(0)),
8199 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8200 case Intrinsic::umax:
8201 LHS = getSCEV(II->getArgOperand(0));
8202 RHS = getSCEV(II->getArgOperand(1));
8203 return getUMaxExpr(LHS, RHS);
8204 case Intrinsic::umin:
8205 LHS = getSCEV(II->getArgOperand(0));
8206 RHS = getSCEV(II->getArgOperand(1));
8207 return getUMinExpr(LHS, RHS);
8208 case Intrinsic::smax:
8209 LHS = getSCEV(II->getArgOperand(0));
8210 RHS = getSCEV(II->getArgOperand(1));
8211 return getSMaxExpr(LHS, RHS);
8212 case Intrinsic::smin:
8213 LHS = getSCEV(II->getArgOperand(0));
8214 RHS = getSCEV(II->getArgOperand(1));
8215 return getSMinExpr(LHS, RHS);
8216 case Intrinsic::usub_sat: {
8217 const SCEV *X = getSCEV(II->getArgOperand(0));
8218 const SCEV *Y = getSCEV(II->getArgOperand(1));
8219 const SCEV *ClampedY = getUMinExpr(X, Y);
8220 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8221 }
8222 case Intrinsic::uadd_sat: {
8223 const SCEV *X = getSCEV(II->getArgOperand(0));
8224 const SCEV *Y = getSCEV(II->getArgOperand(1));
8225 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8226 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8227 }
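// Note on the saturating models above: usub.sat(X, Y) is X - umin(X, Y),
// which is X - Y when Y <= X and 0 otherwise; uadd.sat(X, Y) is
// umin(X, ~Y) + Y, which is X + Y when that does not overflow and the
// all-ones (saturated) value otherwise.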
8228 case Intrinsic::start_loop_iterations:
8229 case Intrinsic::annotation:
8230 case Intrinsic::ptr_annotation:
8231 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8232 // just equivalent to the first operand for SCEV purposes.
8233 return getSCEV(II->getArgOperand(0));
8234 case Intrinsic::vscale:
8235 return getVScale(II->getType());
8236 default:
8237 break;
8238 }
8239 }
8240 break;
8241 }
8242
8243 return getUnknown(V);
8244}
8245
8246//===----------------------------------------------------------------------===//
8247// Iteration Count Computation Code
8248//
8249
8250 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
8251 if (isa<SCEVCouldNotCompute>(ExitCount))
8252 return getCouldNotCompute();
8253
8254 auto *ExitCountType = ExitCount->getType();
8255 assert(ExitCountType->isIntegerTy());
8256 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8257 1 + ExitCountType->getScalarSizeInBits());
8258 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8259}
8260
8261 const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
8262 Type *EvalTy,
8263 const Loop *L) {
8264 if (isa<SCEVCouldNotCompute>(ExitCount))
8265 return getCouldNotCompute();
8266
8267 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8268 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8269
8270 auto CanAddOneWithoutOverflow = [&]() {
8271 ConstantRange ExitCountRange =
8272 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8273 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8274 return true;
8275
8276 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8277 getMinusOne(ExitCount->getType()));
8278 };
8279
8280 // If we need to zero extend the backedge count, check if we can add one to
8281 // it prior to zero extending without overflow. Provided this is safe, it
8282 // allows better simplification of the +1.
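// For example, an i8 exit count that may equal 255 must be widened to i9
// before adding 1 (255 + 1 == 256 only in the wider type); if 255 is
// excluded, or the loop guard proves the count is not -1, the +1 can be
// done in i8 first and then zero extended.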
8283 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8284 return getZeroExtendExpr(
8285 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8286
8287 // Get the total trip count from the count by adding 1. This may wrap.
8288 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8289}
8290
8291static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8292 if (!ExitCount)
8293 return 0;
8294
8295 ConstantInt *ExitConst = ExitCount->getValue();
8296
8297 // Guard against huge trip counts.
8298 if (ExitConst->getValue().getActiveBits() > 32)
8299 return 0;
8300
8301 // In case of integer overflow, this returns 0, which is correct.
8302 return ((unsigned)ExitConst->getZExtValue()) + 1;
8303}
8304
8305 unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
8306 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8307 return getConstantTripCount(ExitCount);
8308}
8309
8310unsigned
8311 ScalarEvolution::getSmallConstantTripCount(const Loop *L,
8312 const BasicBlock *ExitingBlock) {
8313 assert(ExitingBlock && "Must pass a non-null exiting block!");
8314 assert(L->isLoopExiting(ExitingBlock) &&
8315 "Exiting block must actually branch out of the loop!");
8316 const SCEVConstant *ExitCount =
8317 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8318 return getConstantTripCount(ExitCount);
8319}
8320
8321 unsigned ScalarEvolution::getSmallConstantMaxTripCount(
8322 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8323
8324 const auto *MaxExitCount =
8325 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8326 : getConstantMaxBackedgeTakenCount(L);
8327 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8328}
8329
8330 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
8331 SmallVector<BasicBlock *, 8> ExitingBlocks;
8332 L->getExitingBlocks(ExitingBlocks);
8333
8334 std::optional<unsigned> Res;
8335 for (auto *ExitingBB : ExitingBlocks) {
8336 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8337 if (!Res)
8338 Res = Multiple;
8339 Res = std::gcd(*Res, Multiple);
8340 }
8341 return Res.value_or(1);
8342}
8343
8344 unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8345 const SCEV *ExitCount) {
8346 if (isa<SCEVCouldNotCompute>(ExitCount))
8347 return 1;
8348
8349 // Get the trip count
8350 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8351
8352 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8353 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8354 // the greatest power of 2 divisor less than 2^32.
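// For example, a constant multiple of 24 is returned as-is, while a
// multiple of 3 * 2^40 has more than 32 active bits and is reported as
// 1 << min(31, 40) == 2^31.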
8355 return Multiple.getActiveBits() > 32
8356 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8357 : (unsigned)Multiple.getZExtValue();
8358}
8359
8360/// Returns the largest constant divisor of the trip count of this loop as a
8361/// normal unsigned value, if possible. This means that the actual trip count is
8362/// always a multiple of the returned value (don't forget the trip count could
8363/// very well be zero as well!).
8364///
8365/// Returns 1 if the trip count is unknown or not guaranteed to be the
8366/// multiple of a constant (which is also the case if the trip count is simply
8367 /// constant; use getSmallConstantTripCount for that case). Will also return 1
8368/// if the trip count is very large (>= 2^32).
8369///
8370/// As explained in the comments for getSmallConstantTripCount, this assumes
8371/// that control exits the loop via ExitingBlock.
8372unsigned
8373 ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
8374 const BasicBlock *ExitingBlock) {
8375 assert(ExitingBlock && "Must pass a non-null exiting block!");
8376 assert(L->isLoopExiting(ExitingBlock) &&
8377 "Exiting block must actually branch out of the loop!");
8378 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8379 return getSmallConstantTripMultiple(L, ExitCount);
8380}
8381
8382 const SCEV *ScalarEvolution::getExitCount(const Loop *L,
8383 const BasicBlock *ExitingBlock,
8384 ExitCountKind Kind) {
8385 switch (Kind) {
8386 case Exact:
8387 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8388 case SymbolicMaximum:
8389 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8390 case ConstantMaximum:
8391 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8392 };
8393 llvm_unreachable("Invalid ExitCountKind!");
8394}
8395
8396 const SCEV *ScalarEvolution::getPredicatedExitCount(
8397 const Loop *L, const BasicBlock *ExitingBlock,
8398 SmallVectorImpl<const SCEVPredicate *> *Predicates, ExitCountKind Kind) {
8399 switch (Kind) {
8400 case Exact:
8401 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8402 Predicates);
8403 case SymbolicMaximum:
8404 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8405 Predicates);
8406 case ConstantMaximum:
8407 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8408 Predicates);
8409 };
8410 llvm_unreachable("Invalid ExitCountKind!");
8411}
8412
8413 const SCEV *ScalarEvolution::getPredicatedBackedgeTakenCount(
8414 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8415 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8416}
8417
8418 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
8419 ExitCountKind Kind) {
8420 switch (Kind) {
8421 case Exact:
8422 return getBackedgeTakenInfo(L).getExact(L, this);
8423 case ConstantMaximum:
8424 return getBackedgeTakenInfo(L).getConstantMax(this);
8425 case SymbolicMaximum:
8426 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8427 };
8428 llvm_unreachable("Invalid ExitCountKind!");
8429}
8430
8431 const SCEV *ScalarEvolution::getPredicatedSymbolicMaxBackedgeTakenCount(
8432 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8433 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8434}
8435
8436 const SCEV *ScalarEvolution::getPredicatedConstantMaxBackedgeTakenCount(
8437 const Loop *L, SmallVectorImpl<const SCEVPredicate *> &Preds) {
8438 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8439}
8440
8441 bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
8442 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8443}
8444
8445/// Push PHI nodes in the header of the given loop onto the given Worklist.
8446static void PushLoopPHIs(const Loop *L,
8447 SmallVectorImpl<Instruction *> &Worklist,
8448 SmallPtrSetImpl<Instruction *> &Visited) {
8449 BasicBlock *Header = L->getHeader();
8450
8451 // Push all Loop-header PHIs onto the Worklist stack.
8452 for (PHINode &PN : Header->phis())
8453 if (Visited.insert(&PN).second)
8454 Worklist.push_back(&PN);
8455}
8456
8457ScalarEvolution::BackedgeTakenInfo &
8458ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8459 auto &BTI = getBackedgeTakenInfo(L);
8460 if (BTI.hasFullInfo())
8461 return BTI;
8462
8463 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8464
8465 if (!Pair.second)
8466 return Pair.first->second;
8467
8468 BackedgeTakenInfo Result =
8469 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8470
8471 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8472}
8473
8474ScalarEvolution::BackedgeTakenInfo &
8475ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8476 // Initially insert an invalid entry for this loop. If the insertion
8477 // succeeds, proceed to actually compute a backedge-taken count and
8478 // update the value. The temporary CouldNotCompute value tells SCEV
8479 // code elsewhere that it shouldn't attempt to request a new
8480 // backedge-taken count, which could result in infinite recursion.
8481 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8482 BackedgeTakenCounts.try_emplace(L);
8483 if (!Pair.second)
8484 return Pair.first->second;
8485
8486 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8487 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8488 // must be cleared in this scope.
8489 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8490
8491 // Now that we know more about the trip count for this loop, forget any
8492 // existing SCEV values for PHI nodes in this loop since they are only
8493 // conservative estimates made without the benefit of trip count
8494 // information. This invalidation is not necessary for correctness, and is
8495 // only done to produce more precise results.
8496 if (Result.hasAnyInfo()) {
8497 // Invalidate any expression using an addrec in this loop.
8498 SmallVector<const SCEV *, 8> ToForget;
8499 auto LoopUsersIt = LoopUsers.find(L);
8500 if (LoopUsersIt != LoopUsers.end())
8501 append_range(ToForget, LoopUsersIt->second);
8502 forgetMemoizedResults(ToForget);
8503
8504 // Invalidate constant-evolved loop header phis.
8505 for (PHINode &PN : L->getHeader()->phis())
8506 ConstantEvolutionLoopExitValue.erase(&PN);
8507 }
8508
8509 // Re-lookup the insert position, since the call to
8510 // computeBackedgeTakenCount above could result in a
8511 // recursive call to getBackedgeTakenInfo (on a different
8512 // loop), which would invalidate the iterator computed
8513 // earlier.
8514 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8515}
8516
8517 void ScalarEvolution::forgetAllLoops() {
8518 // This method is intended to forget all info about loops. It should
8519 // invalidate caches as if the following happened:
8520 // - The trip counts of all loops have changed arbitrarily
8521 // - Every llvm::Value has been updated in place to produce a different
8522 // result.
8523 BackedgeTakenCounts.clear();
8524 PredicatedBackedgeTakenCounts.clear();
8525 BECountUsers.clear();
8526 LoopPropertiesCache.clear();
8527 ConstantEvolutionLoopExitValue.clear();
8528 ValueExprMap.clear();
8529 ValuesAtScopes.clear();
8530 ValuesAtScopesUsers.clear();
8531 LoopDispositions.clear();
8532 BlockDispositions.clear();
8533 UnsignedRanges.clear();
8534 SignedRanges.clear();
8535 ExprValueMap.clear();
8536 HasRecMap.clear();
8537 ConstantMultipleCache.clear();
8538 PredicatedSCEVRewrites.clear();
8539 FoldCache.clear();
8540 FoldCacheUser.clear();
8541}
8542void ScalarEvolution::visitAndClearUsers(
8543 SmallVectorImpl<Instruction *> &Worklist,
8544 SmallPtrSetImpl<Instruction *> &Visited,
8545 SmallVectorImpl<const SCEV *> &ToForget) {
8546 while (!Worklist.empty()) {
8547 Instruction *I = Worklist.pop_back_val();
8548 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8549 continue;
8550
8551 ValueExprMapType::iterator It =
8552 ValueExprMap.find_as(static_cast<Value *>(I));
8553 if (It != ValueExprMap.end()) {
8554 eraseValueFromMap(It->first);
8555 ToForget.push_back(It->second);
8556 if (PHINode *PN = dyn_cast<PHINode>(I))
8557 ConstantEvolutionLoopExitValue.erase(PN);
8558 }
8559
8560 PushDefUseChildren(I, Worklist, Visited);
8561 }
8562}
8563
8564 void ScalarEvolution::forgetLoop(const Loop *L) {
8565 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8566 SmallVector<Instruction *, 32> Worklist;
8567 SmallPtrSet<Instruction *, 16> Visited;
8568 SmallVector<const SCEV *, 16> ToForget;
8569
8570 // Iterate over all the loops and sub-loops to drop SCEV information.
8571 while (!LoopWorklist.empty()) {
8572 auto *CurrL = LoopWorklist.pop_back_val();
8573
8574 // Drop any stored trip count value.
8575 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8576 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8577
8578 // Drop information about predicated SCEV rewrites for this loop.
8579 for (auto I = PredicatedSCEVRewrites.begin();
8580 I != PredicatedSCEVRewrites.end();) {
8581 std::pair<const SCEV *, const Loop *> Entry = I->first;
8582 if (Entry.second == CurrL)
8583 PredicatedSCEVRewrites.erase(I++);
8584 else
8585 ++I;
8586 }
8587
8588 auto LoopUsersItr = LoopUsers.find(CurrL);
8589 if (LoopUsersItr != LoopUsers.end())
8590 llvm::append_range(ToForget, LoopUsersItr->second);
8591
8592 // Drop information about expressions based on loop-header PHIs.
8593 PushLoopPHIs(CurrL, Worklist, Visited);
8594 visitAndClearUsers(Worklist, Visited, ToForget);
8595
8596 LoopPropertiesCache.erase(CurrL);
8597 // Forget all contained loops too, to avoid dangling entries in the
8598 // ValuesAtScopes map.
8599 LoopWorklist.append(CurrL->begin(), CurrL->end());
8600 }
8601 forgetMemoizedResults(ToForget);
8602}
8603
8604 void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
8605 forgetLoop(L->getOutermostLoop());
8606}
8607
8608 void ScalarEvolution::forgetValue(Value *V) {
8609 Instruction *I = dyn_cast<Instruction>(V);
8610 if (!I) return;
8611
8612 // Drop information about expressions based on loop-header PHIs.
8613 SmallVector<Instruction *, 16> Worklist;
8614 SmallPtrSet<Instruction *, 8> Visited;
8615 SmallVector<const SCEV *, 8> ToForget;
8616 Worklist.push_back(I);
8617 Visited.insert(I);
8618 visitAndClearUsers(Worklist, Visited, ToForget);
8619
8620 forgetMemoizedResults(ToForget);
8621}
8622
8623 void ScalarEvolution::forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V) {
8624 if (!isSCEVable(V->getType()))
8625 return;
8626
8627 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8628 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8629 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8630 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8631 if (const SCEV *S = getExistingSCEV(V)) {
8632 struct InvalidationRootCollector {
8633 Loop *L;
8634 SmallVector<const SCEV *, 8> Roots;
8635
8636 InvalidationRootCollector(Loop *L) : L(L) {}
8637
8638 bool follow(const SCEV *S) {
8639 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8640 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8641 if (L->contains(I))
8642 Roots.push_back(S);
8643 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8644 if (L->contains(AddRec->getLoop()))
8645 Roots.push_back(S);
8646 }
8647 return true;
8648 }
8649 bool isDone() const { return false; }
8650 };
8651
8652 InvalidationRootCollector C(L);
8653 visitAll(S, C);
8654 forgetMemoizedResults(C.Roots);
8655 }
8656
8657 // Also perform the normal invalidation.
8658 forgetValue(V);
8659}
8660
8661void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8662
8662
8663 void ScalarEvolution::forgetBlockAndLoopDispositions(Value *V) {
8664 // Unless a specific value is passed to invalidation, completely clear both
8665 // caches.
8666 if (!V) {
8667 BlockDispositions.clear();
8668 LoopDispositions.clear();
8669 return;
8670 }
8671
8672 if (!isSCEVable(V->getType()))
8673 return;
8674
8675 const SCEV *S = getExistingSCEV(V);
8676 if (!S)
8677 return;
8678
8679 // Invalidate the block and loop dispositions cached for S. Dispositions of
8680 // S's users may change if S's disposition changes (i.e. a user may change to
8681 // loop-invariant, if S changes to loop invariant), so also invalidate
8682 // dispositions of S's users recursively.
8683 SmallVector<const SCEV *, 8> Worklist = {S};
8684 SmallPtrSet<const SCEV *, 8> Seen;
8685 while (!Worklist.empty()) {
8686 const SCEV *Curr = Worklist.pop_back_val();
8687 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8688 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8689 if (!LoopDispoRemoved && !BlockDispoRemoved)
8690 continue;
8691 auto Users = SCEVUsers.find(Curr);
8692 if (Users != SCEVUsers.end())
8693 for (const auto *User : Users->second)
8694 if (Seen.insert(User).second)
8695 Worklist.push_back(User);
8696 }
8697}
8698
8699/// Get the exact loop backedge taken count considering all loop exits. A
8700/// computable result can only be returned for loops with all exiting blocks
8701/// dominating the latch. howFarToZero assumes that the limit of each loop test
8702/// is never skipped. This is a valid assumption as long as the loop exits via
8703/// that test. For precise results, it is the caller's responsibility to specify
8704/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8705const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8706 const Loop *L, ScalarEvolution *SE,
8707 SmallVectorImpl<const SCEVPredicate *> *Preds) const {
8708 // If any exits were not computable, the loop is not computable.
8709 if (!isComplete() || ExitNotTaken.empty())
8710 return SE->getCouldNotCompute();
8711
8712 const BasicBlock *Latch = L->getLoopLatch();
8713 // All exiting blocks we have collected must dominate the only backedge.
8714 if (!Latch)
8715 return SE->getCouldNotCompute();
8716
8717 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8718 // count is simply a minimum out of all these calculated exit counts.
8719 SmallVector<const SCEV *, 2> Ops;
8720 for (const auto &ENT : ExitNotTaken) {
8721 const SCEV *BECount = ENT.ExactNotTaken;
8722 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8723 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8724 "We should only have known counts for exiting blocks that dominate "
8725 "latch!");
8726
8727 Ops.push_back(BECount);
8728
8729 if (Preds)
8730 append_range(*Preds, ENT.Predicates);
8731
8732 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8733 "Predicate should be always true!");
8734 }
8735
8736 // If an earlier exit exits on the first iteration (exit count zero), then
8737 // a later poison exit count should not propagate into the result. These are
8738 // exactly the semantics provided by umin_seq.
8739 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8740}
8741
8742const ScalarEvolution::ExitNotTakenInfo *
8743ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8744 const BasicBlock *ExitingBlock,
8745 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8746 for (const auto &ENT : ExitNotTaken)
8747 if (ENT.ExitingBlock == ExitingBlock) {
8748 if (ENT.hasAlwaysTruePredicate())
8749 return &ENT;
8750 else if (Predicates) {
8751 append_range(*Predicates, ENT.Predicates);
8752 return &ENT;
8753 }
8754 }
8755
8756 return nullptr;
8757}
8758
8759/// getConstantMax - Get the constant max backedge taken count for the loop.
8760const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8761 ScalarEvolution *SE,
8762 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8763 if (!getConstantMax())
8764 return SE->getCouldNotCompute();
8765
8766 for (const auto &ENT : ExitNotTaken)
8767 if (!ENT.hasAlwaysTruePredicate()) {
8768 if (!Predicates)
8769 return SE->getCouldNotCompute();
8770 append_range(*Predicates, ENT.Predicates);
8771 }
8772
8773 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8774 isa<SCEVConstant>(getConstantMax())) &&
8775 "No point in having a non-constant max backedge taken count!");
8776 return getConstantMax();
8777}
8778
8779const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8780 const Loop *L, ScalarEvolution *SE,
8781 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8782 if (!SymbolicMax) {
8783 // Form an expression for the maximum exit count possible for this loop. We
8784 // merge the max and exact information to approximate a version of
8785 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8786 // constants.
8787 SmallVector<const SCEV *, 4> ExitCounts;
8788
8789 for (const auto &ENT : ExitNotTaken) {
8790 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8791 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8792 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8793 "We should only have known counts for exiting blocks that "
8794 "dominate latch!");
8795 ExitCounts.push_back(ExitCount);
8796 if (Predicates)
8797 append_range(*Predicates, ENT.Predicates);
8798
8799 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8800 "Predicate should be always true!");
8801 }
8802 }
8803 if (ExitCounts.empty())
8804 SymbolicMax = SE->getCouldNotCompute();
8805 else
8806 SymbolicMax =
8807 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8808 }
8809 return SymbolicMax;
8810}
8811
8812bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
8813 ScalarEvolution *SE) const {
8814 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
8815 return !ENT.hasAlwaysTruePredicate();
8816 };
8817 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
8818}
8819
8822
8824 const SCEV *E, const SCEV *ConstantMaxNotTaken,
8825 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
8829 // If we prove the max count is zero, so is the symbolic bound. This happens
8830 // in practice due to differences in a) how context sensitive we've chosen
8831 // to be and b) how we reason about bounds implied by UB.
8832 if (ConstantMaxNotTaken->isZero()) {
8833 this->ExactNotTaken = E = ConstantMaxNotTaken;
8834 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
8835 }
8836
8839 "Exact is not allowed to be less precise than Constant Max");
8842 "Exact is not allowed to be less precise than Symbolic Max");
8845 "Symbolic Max is not allowed to be less precise than Constant Max");
8848 "No point in having a non-constant max backedge taken count!");
8849 SmallPtrSet<const SCEVPredicate *, 4> SeenPreds;
8850 for (const auto PredList : PredLists)
8851 for (const auto *P : PredList) {
8852 if (SeenPreds.contains(P))
8853 continue;
8854 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
8855 SeenPreds.insert(P);
8856 Predicates.push_back(P);
8857 }
8858 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
8859 "Backedge count should be int");
8861 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
8862 "Max backedge count should be int");
8863}
8864
8872
8873/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
8874/// computable exit into a persistent ExitNotTakenInfo array.
8875ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
8877 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
8878 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
8879 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8880
8881 ExitNotTaken.reserve(ExitCounts.size());
8882 std::transform(ExitCounts.begin(), ExitCounts.end(),
8883 std::back_inserter(ExitNotTaken),
8884 [&](const EdgeExitInfo &EEI) {
8885 BasicBlock *ExitBB = EEI.first;
8886 const ExitLimit &EL = EEI.second;
8887 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
8888 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
8889 EL.Predicates);
8890 });
8891 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
8892 isa<SCEVConstant>(ConstantMax)) &&
8893 "No point in having a non-constant max backedge taken count!");
8894}
8895
8896/// Compute the number of times the backedge of the specified loop will execute.
8897ScalarEvolution::BackedgeTakenInfo
8898ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
8899 bool AllowPredicates) {
8900 SmallVector<BasicBlock *, 8> ExitingBlocks;
8901 L->getExitingBlocks(ExitingBlocks);
8902
8903 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
8904
8905 SmallVector<EdgeExitInfo, 4> ExitCounts;
8906 bool CouldComputeBECount = true;
8907 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
8908 const SCEV *MustExitMaxBECount = nullptr;
8909 const SCEV *MayExitMaxBECount = nullptr;
8910 bool MustExitMaxOrZero = false;
8911 bool IsOnlyExit = ExitingBlocks.size() == 1;
8912
8913 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
8914 // and compute maxBECount.
8915 // Do a union of all the predicates here.
8916 for (BasicBlock *ExitBB : ExitingBlocks) {
8917 // We canonicalize untaken exits to br (constant), ignore them so that
8918 // proving an exit untaken doesn't negatively impact our ability to reason
8919 // about the loop as a whole.
8920 if (auto *BI = dyn_cast<BranchInst>(ExitBB->getTerminator()))
8921 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
8922 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
8923 if (ExitIfTrue == CI->isZero())
8924 continue;
8925 }
8926
8927 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
8928
8929 assert((AllowPredicates || EL.Predicates.empty()) &&
8930 "Predicated exit limit when predicates are not allowed!");
8931
8932 // 1. For each exit that can be computed, add an entry to ExitCounts.
8933 // CouldComputeBECount is true only if all exits can be computed.
8934 if (EL.ExactNotTaken != getCouldNotCompute())
8935 ++NumExitCountsComputed;
8936 else
8937 // We couldn't compute an exact value for this exit, so
8938 // we won't be able to compute an exact value for the loop.
8939 CouldComputeBECount = false;
8940 // Remember exit count if either exact or symbolic is known. Because
8941 // Exact always implies symbolic, only check symbolic.
8942 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
8943 ExitCounts.emplace_back(ExitBB, EL);
8944 else {
8945 assert(EL.ExactNotTaken == getCouldNotCompute() &&
8946 "Exact is known but symbolic isn't?");
8947 ++NumExitCountsNotComputed;
8948 }
8949
8950 // 2. Derive the loop's MaxBECount from each exit's max number of
8951 // non-exiting iterations. Partition the loop exits into two kinds:
8952 // LoopMustExits and LoopMayExits.
8953 //
8954 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
8955 // is a LoopMayExit. If any computable LoopMustExit is found, then
8956 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
8957 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
8958 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
8959 // any
8960 // computable EL.ConstantMaxNotTaken.
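// For example, if a must-exit bounds the backedge count by 10 and a
// may-exit bounds it by 100, the loop's MaxBECount is 10; if the only
// bounds come from may-exits and any of them is CouldNotCompute, the
// result stays CouldNotCompute.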
8961 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
8962 DT.dominates(ExitBB, Latch)) {
8963 if (!MustExitMaxBECount) {
8964 MustExitMaxBECount = EL.ConstantMaxNotTaken;
8965 MustExitMaxOrZero = EL.MaxOrZero;
8966 } else {
8967 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
8968 EL.ConstantMaxNotTaken);
8969 }
8970 } else if (MayExitMaxBECount != getCouldNotCompute()) {
8971 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
8972 MayExitMaxBECount = EL.ConstantMaxNotTaken;
8973 else {
8974 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
8975 EL.ConstantMaxNotTaken);
8976 }
8977 }
8978 }
8979 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
8980 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
8981 // The loop backedge will be taken the maximum or zero times if there's
8982 // a single exit that must be taken the maximum or zero times.
8983 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
8984
8985 // Remember which SCEVs are used in exit limits for invalidation purposes.
8986 // We only care about non-constant SCEVs here, so we can ignore
8987 // EL.ConstantMaxNotTaken
8988 // and MaxBECount, which must be SCEVConstant.
8989 for (const auto &Pair : ExitCounts) {
8990 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
8991 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
8992 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
8993 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
8994 {L, AllowPredicates});
8995 }
8996 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
8997 MaxBECount, MaxOrZero);
8998}
8999
9000ScalarEvolution::ExitLimit
9001ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9002 bool IsOnlyExit, bool AllowPredicates) {
9003 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9004 // If our exiting block does not dominate the latch, then its connection with
9005 // loop's exit limit may be far from trivial.
9006 const BasicBlock *Latch = L->getLoopLatch();
9007 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9008 return getCouldNotCompute();
9009
9010 Instruction *Term = ExitingBlock->getTerminator();
9011 if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
9012 assert(BI->isConditional() && "If unconditional, it can't be in loop!");
9013 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9014 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9015 "It should have one successor in loop and one exit block!");
9016 // Proceed to the next level to examine the exit condition expression.
9017 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9018 /*ControlsOnlyExit=*/IsOnlyExit,
9019 AllowPredicates);
9020 }
9021
9022 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9023 // For switch, make sure that there is a single exit from the loop.
9024 BasicBlock *Exit = nullptr;
9025 for (auto *SBB : successors(ExitingBlock))
9026 if (!L->contains(SBB)) {
9027 if (Exit) // Multiple exit successors.
9028 return getCouldNotCompute();
9029 Exit = SBB;
9030 }
9031 assert(Exit && "Exiting block must have at least one exit");
9032 return computeExitLimitFromSingleExitSwitch(
9033 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9034 }
9035
9036 return getCouldNotCompute();
9037}
9038
9039 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
9040 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9041 bool AllowPredicates) {
9042 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9043 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9044 ControlsOnlyExit, AllowPredicates);
9045}
9046
9047std::optional<ScalarEvolution::ExitLimit>
9048ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9049 bool ExitIfTrue, bool ControlsOnlyExit,
9050 bool AllowPredicates) {
9051 (void)this->L;
9052 (void)this->ExitIfTrue;
9053 (void)this->AllowPredicates;
9054
9055 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9056 this->AllowPredicates == AllowPredicates &&
9057 "Variance in assumed invariant key components!");
9058 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9059 if (Itr == TripCountMap.end())
9060 return std::nullopt;
9061 return Itr->second;
9062}
9063
9064void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9065 bool ExitIfTrue,
9066 bool ControlsOnlyExit,
9067 bool AllowPredicates,
9068 const ExitLimit &EL) {
9069 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9070 this->AllowPredicates == AllowPredicates &&
9071 "Variance in assumed invariant key components!");
9072
9073 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9074 assert(InsertResult.second && "Expected successful insertion!");
9075 (void)InsertResult;
9076 (void)ExitIfTrue;
9077}
9078
9079ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9080 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9081 bool ControlsOnlyExit, bool AllowPredicates) {
9082
9083 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9084 AllowPredicates))
9085 return *MaybeEL;
9086
9087 ExitLimit EL = computeExitLimitFromCondImpl(
9088 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9089 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9090 return EL;
9091}
9092
9093ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9094 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9095 bool ControlsOnlyExit, bool AllowPredicates) {
9096 // Handle BinOp conditions (And, Or).
9097 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9098 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9099 return *LimitFromBinOp;
9100
9101 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9102 // Proceed to the next level to examine the icmp.
9103 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9104 ExitLimit EL =
9105 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9106 if (EL.hasFullInfo() || !AllowPredicates)
9107 return EL;
9108
9109 // Try again, but use SCEV predicates this time.
9110 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9111 ControlsOnlyExit,
9112 /*AllowPredicates=*/true);
9113 }
9114
9115 // Check for a constant condition. These are normally stripped out by
9116 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9117 // preserve the CFG and is temporarily leaving constant conditions
9118 // in place.
9119 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9120 if (ExitIfTrue == !CI->getZExtValue())
9121 // The backedge is always taken.
9122 return getCouldNotCompute();
9123 // The backedge is never taken.
9124 return getZero(CI->getType());
9125 }
9126
9127 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9128 // with a constant step, we can form an equivalent icmp predicate and figure
9129 // out how many iterations will be taken before we exit.
9130 const WithOverflowInst *WO;
9131 const APInt *C;
9132 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9133 match(WO->getRHS(), m_APInt(C))) {
9134 ConstantRange NWR =
9135 ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
9136 WO->getNoWrapKind());
9137 CmpInst::Predicate Pred;
9138 APInt NewRHSC, Offset;
9139 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9140 if (!ExitIfTrue)
9141 Pred = ICmpInst::getInversePredicate(Pred);
9142 auto *LHS = getSCEV(WO->getLHS());
9143 if (Offset != 0)
9144 LHS = getAddExpr(LHS, getConstant(Offset));
9145 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9146 ControlsOnlyExit, AllowPredicates);
9147 if (EL.hasAnyInfo())
9148 return EL;
9149 }
9150
9151 // If it's not an integer or pointer comparison then compute it the hard way.
9152 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9153}
9154
9155std::optional<ScalarEvolution::ExitLimit>
9156ScalarEvolution::computeExitLimitFromCondFromBinOp(
9157 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9158 bool ControlsOnlyExit, bool AllowPredicates) {
9159 // Check if the controlling expression for this loop is an And or Or.
9160 Value *Op0, *Op1;
9161 bool IsAnd = false;
9162 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9163 IsAnd = true;
9164 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9165 IsAnd = false;
9166 else
9167 return std::nullopt;
9168
9169 // EitherMayExit is true in these two cases:
9170 // br (and Op0 Op1), loop, exit
9171 // br (or Op0 Op1), exit, loop
9172 bool EitherMayExit = IsAnd ^ ExitIfTrue;
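// For example, "br i1 (and %c0, %c1), label %loop, label %exit" has
// IsAnd == true and ExitIfTrue == false, so EitherMayExit is true and the
// exit count computed below is the umin of the two sub-counts.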
9173 ExitLimit EL0 = computeExitLimitFromCondCached(
9174 Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9175 AllowPredicates);
9176 ExitLimit EL1 = computeExitLimitFromCondCached(
9177 Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
9178 AllowPredicates);
9179
9180 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
9181 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9182 if (isa<ConstantInt>(Op1))
9183 return Op1 == NeutralElement ? EL0 : EL1;
9184 if (isa<ConstantInt>(Op0))
9185 return Op0 == NeutralElement ? EL1 : EL0;
9186
9187 const SCEV *BECount = getCouldNotCompute();
9188 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9189 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9190 if (EitherMayExit) {
9191 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9192 // Both conditions must be the same for the loop to continue executing.
9193 // Choose the less conservative count.
9194 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9195 EL1.ExactNotTaken != getCouldNotCompute()) {
9196 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9197 UseSequentialUMin);
9198 }
9199 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9200 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9201 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9202 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9203 else
9204 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9205 EL1.ConstantMaxNotTaken);
9206 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9207 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9208 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9209 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9210 else
9211 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9212 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9213 } else {
9214 // Both conditions must be the same at the same time for the loop to exit.
9215 // For now, be conservative.
9216 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9217 BECount = EL0.ExactNotTaken;
9218 }
9219
9220 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9221 // to be more aggressive when computing BECount than when computing
9222 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9223 // and
9224 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9225 // EL1.ConstantMaxNotTaken to not.
9226 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9227 !isa<SCEVCouldNotCompute>(BECount))
9228 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9229 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9230 SymbolicMaxBECount =
9231 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9232 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9233 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9234}
9235
9236ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9237 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9238 bool AllowPredicates) {
9239 // If the condition was exit on true, convert the condition to exit on false
9240 CmpPredicate Pred;
9241 if (!ExitIfTrue)
9242 Pred = ExitCond->getCmpPredicate();
9243 else
9244 Pred = ExitCond->getInverseCmpPredicate();
9245 const ICmpInst::Predicate OriginalPred = Pred;
9246
9247 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9248 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9249
9250 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9251 AllowPredicates);
9252 if (EL.hasAnyInfo())
9253 return EL;
9254
9255 auto *ExhaustiveCount =
9256 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9257
9258 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9259 return ExhaustiveCount;
9260
9261 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9262 ExitCond->getOperand(1), L, OriginalPred);
9263}
9264ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9265 const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS,
9266 bool ControlsOnlyExit, bool AllowPredicates) {
9267
9268 // Try to evaluate any dependencies out of the loop.
9269 LHS = getSCEVAtScope(LHS, L);
9270 RHS = getSCEVAtScope(RHS, L);
9271
9272 // At this point, we would like to compute how many iterations of the
9273 // loop the predicate will return true for these inputs.
9274 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9275 // If there is a loop-invariant, force it into the RHS.
9276 std::swap(LHS, RHS);
9277 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
9278 }
9279
9280 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9281 loopIsFiniteByAssumption(L);
9282 // Simplify the operands before analyzing them.
9283 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9284
9285 // If we have a comparison of a chrec against a constant, try to use value
9286 // ranges to answer this query.
9287 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9288 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9289 if (AddRec->getLoop() == L) {
9290 // Form the constant range.
9291 ConstantRange CompRange =
9292 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9293
9294 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9295 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9296 }
9297
9298 // If this loop must exit based on this condition (or execute undefined
9299 // behaviour), see if we can improve wrap flags. This is essentially
9300 // a must execute style proof.
9301 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9302 // If we can prove the test sequence produced must repeat the same values
9303 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9304 // because if it did, we'd have an infinite (undefined) loop.
9305 // TODO: We can peel off any functions which are invertible *in L*. Loop
9306 // invariant terms are effectively constants for our purposes here.
9307 auto *InnerLHS = LHS;
9308 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9309 InnerLHS = ZExt->getOperand();
9310 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9311 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9312 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9313 /*OrNegative=*/true)) {
9314 auto Flags = AR->getNoWrapFlags();
9315 Flags = setFlags(Flags, SCEV::FlagNW);
9316 SmallVector<const SCEV *> Operands{AR->operands()};
9317 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9318 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9319 }
9320
9321 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9322 // From no-self-wrap, this follows trivially from the fact that every
9323 // (un)signed-wrapped, but not self-wrapped value must be less than the
9324 // last value before (un)signed wrap. Since we know that last value
9325 // didn't exit, nor will any smaller one.
9326 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9327 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9328 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9329 AR && AR->getLoop() == L && AR->isAffine() &&
9330 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9331 isKnownPositive(AR->getStepRecurrence(*this))) {
9332 auto Flags = AR->getNoWrapFlags();
9333 Flags = setFlags(Flags, WrapType);
9334 SmallVector<const SCEV*> Operands{AR->operands()};
9335 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9336 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9337 }
9338 }
9339 }
9340
9341 switch (Pred) {
9342 case ICmpInst::ICMP_NE: { // while (X != Y)
9343 // Convert to: while (X-Y != 0)
9344 if (LHS->getType()->isPointerTy()) {
9345 LHS = getLosslessPtrToIntExpr(LHS);
9346 if (isa<SCEVCouldNotCompute>(LHS))
9347 return LHS;
9348 }
9349 if (RHS->getType()->isPointerTy()) {
9350 RHS = getLosslessPtrToIntExpr(RHS);
9351 if (isa<SCEVCouldNotCompute>(RHS))
9352 return RHS;
9353 }
9354 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9355 AllowPredicates);
9356 if (EL.hasAnyInfo())
9357 return EL;
9358 break;
9359 }
9360 case ICmpInst::ICMP_EQ: { // while (X == Y)
9361 // Convert to: while (X-Y == 0)
9362 if (LHS->getType()->isPointerTy()) {
9363 LHS = getLosslessPtrToIntExpr(LHS);
9364 if (isa<SCEVCouldNotCompute>(LHS))
9365 return LHS;
9366 }
9367 if (RHS->getType()->isPointerTy()) {
9368 RHS = getLosslessPtrToIntExpr(RHS);
9369 if (isa<SCEVCouldNotCompute>(RHS))
9370 return RHS;
9371 }
9372 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9373 if (EL.hasAnyInfo()) return EL;
9374 break;
9375 }
9376 case ICmpInst::ICMP_SLE:
9377 case ICmpInst::ICMP_ULE:
9378 // Since the loop is finite, an invariant RHS cannot include the boundary
9379 // value, otherwise it would loop forever.
9380 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9381 !isLoopInvariant(RHS, L)) {
9382 // Otherwise, perform the addition in a wider type, to avoid overflow.
9383 // If the LHS is an addrec with the appropriate nowrap flag, the
9384 // extension will be sunk into it and the exit count can be analyzed.
9385 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9386 if (!OldType)
9387 break;
9388 // Prefer doubling the bitwidth over adding a single bit to make it more
9389 // likely that we use a legal type.
9390 auto *NewType =
9391 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9392 if (ICmpInst::isSigned(Pred)) {
9393 LHS = getSignExtendExpr(LHS, NewType);
9394 RHS = getSignExtendExpr(RHS, NewType);
9395 } else {
9396 LHS = getZeroExtendExpr(LHS, NewType);
9397 RHS = getZeroExtendExpr(RHS, NewType);
9398 }
9399 }
9401 [[fallthrough]];
9402 case ICmpInst::ICMP_SLT:
9403 case ICmpInst::ICMP_ULT: { // while (X < Y)
9404 bool IsSigned = ICmpInst::isSigned(Pred);
9405 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9406 AllowPredicates);
9407 if (EL.hasAnyInfo())
9408 return EL;
9409 break;
9410 }
9411 case ICmpInst::ICMP_SGE:
9412 case ICmpInst::ICMP_UGE:
9413 // Since the loop is finite, an invariant RHS cannot include the boundary
9414 // value, otherwise it would loop forever.
9415 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9416 !isLoopInvariant(RHS, L))
9417 break;
9419 [[fallthrough]];
9420 case ICmpInst::ICMP_SGT:
9421 case ICmpInst::ICMP_UGT: { // while (X > Y)
9422 bool IsSigned = ICmpInst::isSigned(Pred);
9423 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9424 AllowPredicates);
9425 if (EL.hasAnyInfo())
9426 return EL;
9427 break;
9428 }
9429 default:
9430 break;
9431 }
9432
9433 return getCouldNotCompute();
9434}
9435
9436ScalarEvolution::ExitLimit
9437ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9438 SwitchInst *Switch,
9439 BasicBlock *ExitingBlock,
9440 bool ControlsOnlyExit) {
9441 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9442
9443 // Give up if the exit is the default dest of a switch.
9444 if (Switch->getDefaultDest() == ExitingBlock)
9445 return getCouldNotCompute();
9446
9447 assert(L->contains(Switch->getDefaultDest()) &&
9448 "Default case must not exit the loop!");
9449 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9450 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9451
9452 // while (X != Y) --> while (X-Y != 0)
9453 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9454 if (EL.hasAnyInfo())
9455 return EL;
9456
9457 return getCouldNotCompute();
9458}
9459
9460static ConstantInt *
9461 EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
9462 ScalarEvolution &SE) {
9463 const SCEV *InVal = SE.getConstant(C);
9464 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9465 assert(isa<SCEVConstant>(Val) &&
9466 "Evaluation of SCEV at constant didn't fold correctly?");
9467 return cast<SCEVConstant>(Val)->getValue();
9468}
9469
9470ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9471 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9472 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9473 if (!RHS)
9474 return getCouldNotCompute();
9475
9476 const BasicBlock *Latch = L->getLoopLatch();
9477 if (!Latch)
9478 return getCouldNotCompute();
9479
9480 const BasicBlock *Predecessor = L->getLoopPredecessor();
9481 if (!Predecessor)
9482 return getCouldNotCompute();
9483
9484 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9485 // Return LHS in OutLHS and shift_op in OutOpCode.
9486 auto MatchPositiveShift =
9487 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9488
9489 using namespace PatternMatch;
9490
9491 ConstantInt *ShiftAmt;
9492 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9493 OutOpCode = Instruction::LShr;
9494 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9495 OutOpCode = Instruction::AShr;
9496 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9497 OutOpCode = Instruction::Shl;
9498 else
9499 return false;
9500
9501 return ShiftAmt->getValue().isStrictlyPositive();
9502 };
9503
9504 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9505 //
9506 // loop:
9507 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9508 // %iv.shifted = lshr i32 %iv, <positive constant>
9509 //
9510 // Return true on a successful match. Return the corresponding PHI node (%iv
9511 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9512 auto MatchShiftRecurrence =
9513 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9514 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9515
9516 {
9517       Instruction::BinaryOps OpC;
9518       Value *V;
9519
9520 // If we encounter a shift instruction, "peel off" the shift operation,
9521 // and remember that we did so. Later when we inspect %iv's backedge
9522 // value, we will make sure that the backedge value uses the same
9523 // operation.
9524 //
9525 // Note: the peeled shift operation does not have to be the same
9526 // instruction as the one feeding into the PHI's backedge value. We only
9527 // really care about it being the same *kind* of shift instruction --
9528 // that's all that is required for our later inferences to hold.
9529 if (MatchPositiveShift(LHS, V, OpC)) {
9530 PostShiftOpCode = OpC;
9531 LHS = V;
9532 }
9533 }
9534
9535 PNOut = dyn_cast<PHINode>(LHS);
9536 if (!PNOut || PNOut->getParent() != L->getHeader())
9537 return false;
9538
9539 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9540 Value *OpLHS;
9541
9542 return
9543 // The backedge value for the PHI node must be a shift by a positive
9544 // amount
9545 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9546
9547 // of the PHI node itself
9548 OpLHS == PNOut &&
9549
9550       // and the kind of shift must match the kind of shift we peeled
9551 // off, if any.
9552 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9553 };
9554
9555 PHINode *PN;
9556   Instruction::BinaryOps OpCode;
9557   if (!MatchShiftRecurrence(LHS, PN, OpCode))
9558 return getCouldNotCompute();
9559
9560 const DataLayout &DL = getDataLayout();
9561
9562 // The key rationale for this optimization is that for some kinds of shift
9563 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9564 // within a finite number of iterations. If the condition guarding the
9565 // backedge (in the sense that the backedge is taken if the condition is true)
9566 // is false for the value the shift recurrence stabilizes to, then we know
9567 // that the backedge is taken only a finite number of times.
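  // For example, {8,lshr,1} produces 8, 4, 2, 1, 0, 0, ... and {-1,ashr,1}
  // stays at -1 forever, so once the recurrence reaches its stable value a
  // backedge condition that is false for that value can no longer be taken.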
9568
9569 ConstantInt *StableValue = nullptr;
9570 switch (OpCode) {
9571 default:
9572 llvm_unreachable("Impossible case!");
9573
9574 case Instruction::AShr: {
9575 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9576 // bitwidth(K) iterations.
9577 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9578 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9579 Predecessor->getTerminator(), &DT);
9580 auto *Ty = cast<IntegerType>(RHS->getType());
9581 if (Known.isNonNegative())
9582 StableValue = ConstantInt::get(Ty, 0);
9583 else if (Known.isNegative())
9584 StableValue = ConstantInt::get(Ty, -1, true);
9585 else
9586 return getCouldNotCompute();
9587
9588 break;
9589 }
9590 case Instruction::LShr:
9591 case Instruction::Shl:
9592 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9593 // stabilize to 0 in at most bitwidth(K) iterations.
9594 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9595 break;
9596 }
9597
9598 auto *Result =
9599 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9600 assert(Result->getType()->isIntegerTy(1) &&
9601 "Otherwise cannot be an operand to a branch instruction");
9602
9603 if (Result->isZeroValue()) {
9604 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9605 const SCEV *UpperBound =
9606         getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth);
9607     return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9608 }
9609
9610 return getCouldNotCompute();
9611}
9612
9613/// Return true if we can constant fold an instruction of the specified type,
9614/// assuming that all operands were constants.
9615static bool CanConstantFold(const Instruction *I) {
9616   if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) ||
9617       isa<CastInst>(I) || isa<GetElementPtrInst>(I) || isa<LoadInst>(I) ||
9618       isa<ExtractValueInst>(I))
9619     return true;
9620
9621 if (const CallInst *CI = dyn_cast<CallInst>(I))
9622 if (const Function *F = CI->getCalledFunction())
9623 return canConstantFoldCallTo(CI, F);
9624 return false;
9625}
9626
9627/// Determine whether this instruction can constant evolve within this loop
9628/// assuming its operands can all constant evolve.
9629static bool canConstantEvolve(Instruction *I, const Loop *L) {
9630 // An instruction outside of the loop can't be derived from a loop PHI.
9631 if (!L->contains(I)) return false;
9632
9633 if (isa<PHINode>(I)) {
9634 // We don't currently keep track of the control flow needed to evaluate
9635 // PHIs, so we cannot handle PHIs inside of loops.
9636 return L->getHeader() == I->getParent();
9637 }
9638
9639 // If we won't be able to constant fold this expression even if the operands
9640 // are constants, bail early.
9641 return CanConstantFold(I);
9642}
9643
9644/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9645/// recursing through each instruction operand until reaching a loop header phi.
9646static PHINode *
9647 getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L,
9648                                DenseMap<Instruction *, PHINode *> &PHIMap,
9649                                unsigned Depth) {
9650   if (Depth > MaxConstantEvolvingDepth)
9651     return nullptr;
9652
9653 // Otherwise, we can evaluate this instruction if all of its operands are
9654 // constant or derived from a PHI node themselves.
9655 PHINode *PHI = nullptr;
9656 for (Value *Op : UseInst->operands()) {
9657 if (isa<Constant>(Op)) continue;
9658
9659     Instruction *OpInst = dyn_cast<Instruction>(Op);
9660     if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9661
9662 PHINode *P = dyn_cast<PHINode>(OpInst);
9663 if (!P)
9664 // If this operand is already visited, reuse the prior result.
9665 // We may have P != PHI if this is the deepest point at which the
9666 // inconsistent paths meet.
9667 P = PHIMap.lookup(OpInst);
9668 if (!P) {
9669 // Recurse and memoize the results, whether a phi is found or not.
9670 // This recursive call invalidates pointers into PHIMap.
9671 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9672 PHIMap[OpInst] = P;
9673 }
9674 if (!P)
9675 return nullptr; // Not evolving from PHI
9676 if (PHI && PHI != P)
9677 return nullptr; // Evolving from multiple different PHIs.
9678 PHI = P;
9679 }
9680   // This is an expression evolving from a constant PHI!
9681 return PHI;
9682}
9683
9684/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9685/// in the loop that V is derived from. We allow arbitrary operations along the
9686/// way, but the operands of an operation must either be constants or a value
9687/// derived from a constant PHI. If this expression does not fit with these
9688/// constraints, return null.
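/// For example, in a loop where "%x = phi i32 [ 1, %entry ], [ %x.next, %loop ]"
/// is the header PHI and "%v = mul i32 %x, 3", the value %v is derived only
/// from constants and the single PHI %x, so %x is returned for %v.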
9689 static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
9690   Instruction *I = dyn_cast<Instruction>(V);
9691   if (!I || !canConstantEvolve(I, L)) return nullptr;
9692
9693 if (PHINode *PN = dyn_cast<PHINode>(I))
9694 return PN;
9695
9696 // Record non-constant instructions contained by the loop.
9697   DenseMap<Instruction *, PHINode *> PHIMap;
9698   return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9699}
9700
9701/// EvaluateExpression - Given an expression that passes the
9702/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9703/// in the loop has the value PHIVal. If we can't fold this expression for some
9704/// reason, return null.
9705 static Constant *EvaluateExpression(Value *V, const Loop *L,
9706                                     DenseMap<Instruction *, Constant *> &Vals,
9707                                     const DataLayout &DL,
9708 const TargetLibraryInfo *TLI) {
9709 // Convenient constant check, but redundant for recursive calls.
9710 if (Constant *C = dyn_cast<Constant>(V)) return C;
9711   Instruction *I = dyn_cast<Instruction>(V);
9712   if (!I) return nullptr;
9713
9714 if (Constant *C = Vals.lookup(I)) return C;
9715
9716 // An instruction inside the loop depends on a value outside the loop that we
9717 // weren't given a mapping for, or a value such as a call inside the loop.
9718 if (!canConstantEvolve(I, L)) return nullptr;
9719
9720 // An unmapped PHI can be due to a branch or another loop inside this loop,
9721 // or due to this not being the initial iteration through a loop where we
9722 // couldn't compute the evolution of this particular PHI last time.
9723 if (isa<PHINode>(I)) return nullptr;
9724
9725 std::vector<Constant*> Operands(I->getNumOperands());
9726
9727 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9728 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9729 if (!Operand) {
9730 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9731 if (!Operands[i]) return nullptr;
9732 continue;
9733 }
9734 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9735 Vals[Operand] = C;
9736 if (!C) return nullptr;
9737 Operands[i] = C;
9738 }
9739
9740 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9741 /*AllowNonDeterministic=*/false);
9742}
9743
9744
9745// If every incoming value to PN except the one for BB is a specific Constant,
9746// return that, else return nullptr.
9747 static Constant *getOtherIncomingValue(PHINode *PN, BasicBlock *BB) {
9748   Constant *IncomingVal = nullptr;
9749
9750 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9751 if (PN->getIncomingBlock(i) == BB)
9752 continue;
9753
9754 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9755 if (!CurrentVal)
9756 return nullptr;
9757
9758 if (IncomingVal != CurrentVal) {
9759 if (IncomingVal)
9760 return nullptr;
9761 IncomingVal = CurrentVal;
9762 }
9763 }
9764
9765 return IncomingVal;
9766}
9767
9768/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9769/// in the header of its containing loop, we know the loop executes a
9770/// constant number of times, and the PHI node is just a recurrence
9771/// involving constants, fold it.
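/// For example, for "%iv = phi i32 [ 0, %entry ], [ %iv.next, %latch ]" with
/// "%iv.next = add i32 %iv, 3" and a backedge-taken count of 4, the PHI's
/// loop-exit value folds to the constant 12.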
9772Constant *
9773ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9774 const APInt &BEs,
9775 const Loop *L) {
9776 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9777 if (!Inserted)
9778 return I->second;
9779
9780   if (BEs.ugt(MaxBruteForceIterations))
9781     return nullptr; // Not going to evaluate it.
9782
9783 Constant *&RetVal = I->second;
9784
9785 DenseMap<Instruction *, Constant *> CurrentIterVals;
9786 BasicBlock *Header = L->getHeader();
9787 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9788
9789 BasicBlock *Latch = L->getLoopLatch();
9790 if (!Latch)
9791 return nullptr;
9792
9793 for (PHINode &PHI : Header->phis()) {
9794 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9795 CurrentIterVals[&PHI] = StartCST;
9796 }
9797 if (!CurrentIterVals.count(PN))
9798 return RetVal = nullptr;
9799
9800 Value *BEValue = PN->getIncomingValueForBlock(Latch);
9801
9802 // Execute the loop symbolically to determine the exit value.
9803 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
9804 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
9805
9806 unsigned NumIterations = BEs.getZExtValue(); // must be in range
9807 unsigned IterationNum = 0;
9808 const DataLayout &DL = getDataLayout();
9809 for (; ; ++IterationNum) {
9810 if (IterationNum == NumIterations)
9811 return RetVal = CurrentIterVals[PN]; // Got exit value!
9812
9813 // Compute the value of the PHIs for the next iteration.
9814 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
9815 DenseMap<Instruction *, Constant *> NextIterVals;
9816 Constant *NextPHI =
9817 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9818 if (!NextPHI)
9819 return nullptr; // Couldn't evaluate!
9820 NextIterVals[PN] = NextPHI;
9821
9822 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
9823
9824 // Also evaluate the other PHI nodes. However, we don't get to stop if we
9825 // cease to be able to evaluate one of them or if they stop evolving,
9826 // because that doesn't necessarily prevent us from computing PN.
9827     SmallVector<std::pair<PHINode *, Constant *>, 8> PHIsToCompute;
9828     for (const auto &I : CurrentIterVals) {
9829 PHINode *PHI = dyn_cast<PHINode>(I.first);
9830 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
9831 PHIsToCompute.emplace_back(PHI, I.second);
9832 }
9833 // We use two distinct loops because EvaluateExpression may invalidate any
9834 // iterators into CurrentIterVals.
9835 for (const auto &I : PHIsToCompute) {
9836 PHINode *PHI = I.first;
9837 Constant *&NextPHI = NextIterVals[PHI];
9838 if (!NextPHI) { // Not already computed.
9839 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9840 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9841 }
9842 if (NextPHI != I.second)
9843 StoppedEvolving = false;
9844 }
9845
9846 // If all entries in CurrentIterVals == NextIterVals then we can stop
9847 // iterating, the loop can't continue to change.
9848 if (StoppedEvolving)
9849 return RetVal = CurrentIterVals[PN];
9850
9851 CurrentIterVals.swap(NextIterVals);
9852 }
9853}
9854
9855const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
9856 Value *Cond,
9857 bool ExitWhen) {
9858 PHINode *PN = getConstantEvolvingPHI(Cond, L);
9859 if (!PN) return getCouldNotCompute();
9860
9861 // If the loop is canonicalized, the PHI will have exactly two entries.
9862 // That's the only form we support here.
9863 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
9864
9865 DenseMap<Instruction *, Constant *> CurrentIterVals;
9866 BasicBlock *Header = L->getHeader();
9867 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9868
9869 BasicBlock *Latch = L->getLoopLatch();
9870 assert(Latch && "Should follow from NumIncomingValues == 2!");
9871
9872 for (PHINode &PHI : Header->phis()) {
9873 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9874 CurrentIterVals[&PHI] = StartCST;
9875 }
9876 if (!CurrentIterVals.count(PN))
9877 return getCouldNotCompute();
9878
9879   // Okay, we found a PHI node that defines the trip count of this loop.  Execute
9880 // the loop symbolically to determine when the condition gets a value of
9881 // "ExitWhen".
9882 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
9883 const DataLayout &DL = getDataLayout();
9884 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
9885 auto *CondVal = dyn_cast_or_null<ConstantInt>(
9886 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
9887
9888 // Couldn't symbolically evaluate.
9889 if (!CondVal) return getCouldNotCompute();
9890
9891 if (CondVal->getValue() == uint64_t(ExitWhen)) {
9892 ++NumBruteForceTripCountsComputed;
9893 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
9894 }
9895
9896 // Update all the PHI nodes for the next iteration.
9897 DenseMap<Instruction *, Constant *> NextIterVals;
9898
9899 // Create a list of which PHIs we need to compute. We want to do this before
9900 // calling EvaluateExpression on them because that may invalidate iterators
9901 // into CurrentIterVals.
9902 SmallVector<PHINode *, 8> PHIsToCompute;
9903 for (const auto &I : CurrentIterVals) {
9904 PHINode *PHI = dyn_cast<PHINode>(I.first);
9905 if (!PHI || PHI->getParent() != Header) continue;
9906 PHIsToCompute.push_back(PHI);
9907 }
9908 for (PHINode *PHI : PHIsToCompute) {
9909 Constant *&NextPHI = NextIterVals[PHI];
9910 if (NextPHI) continue; // Already computed!
9911
9912 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
9913 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
9914 }
9915 CurrentIterVals.swap(NextIterVals);
9916 }
9917
9918 // Too many iterations were needed to evaluate.
9919 return getCouldNotCompute();
9920}
9921
9922const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
9923   SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values =
9924       ValuesAtScopes[V];
9925 // Check to see if we've folded this expression at this loop before.
9926 for (auto &LS : Values)
9927 if (LS.first == L)
9928 return LS.second ? LS.second : V;
9929
9930 Values.emplace_back(L, nullptr);
9931
9932 // Otherwise compute it.
9933 const SCEV *C = computeSCEVAtScope(V, L);
9934 for (auto &LS : reverse(ValuesAtScopes[V]))
9935 if (LS.first == L) {
9936 LS.second = C;
9937 if (!isa<SCEVConstant>(C))
9938 ValuesAtScopesUsers[C].push_back({L, V});
9939 break;
9940 }
9941 return C;
9942}
9943
9944/// This builds up a Constant using the ConstantExpr interface. That way, we
9945/// will return Constants for objects which aren't represented by a
9946/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
9947/// Returns NULL if the SCEV isn't representable as a Constant.
9948 static Constant *BuildConstantFromSCEV(const SCEV *V) {
9949   switch (V->getSCEVType()) {
9950 case scCouldNotCompute:
9951 case scAddRecExpr:
9952 case scVScale:
9953 return nullptr;
9954 case scConstant:
9955 return cast<SCEVConstant>(V)->getValue();
9956 case scUnknown:
9957 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
9958 case scPtrToInt: {
9959     const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
9960     if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
9961 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
9962
9963 return nullptr;
9964 }
9965 case scTruncate: {
9966     const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
9967     if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
9968 return ConstantExpr::getTrunc(CastOp, ST->getType());
9969 return nullptr;
9970 }
9971 case scAddExpr: {
9972 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
9973 Constant *C = nullptr;
9974 for (const SCEV *Op : SA->operands()) {
9975       Constant *OpC = BuildConstantFromSCEV(Op);
9976       if (!OpC)
9977 return nullptr;
9978 if (!C) {
9979 C = OpC;
9980 continue;
9981 }
9982 assert(!C->getType()->isPointerTy() &&
9983 "Can only have one pointer, and it must be last");
9984 if (OpC->getType()->isPointerTy()) {
9985 // The offsets have been converted to bytes. We can add bytes using
9986 // an i8 GEP.
9987         C = ConstantExpr::getGetElementPtr(Type::getInt8Ty(C->getContext()),
9988                                            OpC, C);
9989 } else {
9990 C = ConstantExpr::getAdd(C, OpC);
9991 }
9992 }
9993 return C;
9994 }
9995 case scMulExpr:
9996 case scSignExtend:
9997 case scZeroExtend:
9998 case scUDivExpr:
9999 case scSMaxExpr:
10000 case scUMaxExpr:
10001 case scSMinExpr:
10002 case scUMinExpr:
10003   case scSequentialUMinExpr:
10004     return nullptr;
10005 }
10006 llvm_unreachable("Unknown SCEV kind!");
10007}
10008
10009const SCEV *
10010ScalarEvolution::getWithOperands(const SCEV *S,
10011 SmallVectorImpl<const SCEV *> &NewOps) {
10012 switch (S->getSCEVType()) {
10013 case scTruncate:
10014 case scZeroExtend:
10015 case scSignExtend:
10016 case scPtrToInt:
10017 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10018 case scAddRecExpr: {
10019 auto *AddRec = cast<SCEVAddRecExpr>(S);
10020 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10021 }
10022 case scAddExpr:
10023 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10024 case scMulExpr:
10025 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10026 case scUDivExpr:
10027 return getUDivExpr(NewOps[0], NewOps[1]);
10028 case scUMaxExpr:
10029 case scSMaxExpr:
10030 case scUMinExpr:
10031 case scSMinExpr:
10032 return getMinMaxExpr(S->getSCEVType(), NewOps);
10033   case scSequentialUMinExpr:
10034     return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10035 case scConstant:
10036 case scVScale:
10037 case scUnknown:
10038 return S;
10039 case scCouldNotCompute:
10040 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10041 }
10042 llvm_unreachable("Unknown SCEV kind!");
10043}
10044
10045const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10046 switch (V->getSCEVType()) {
10047 case scConstant:
10048 case scVScale:
10049 return V;
10050 case scAddRecExpr: {
10051 // If this is a loop recurrence for a loop that does not contain L, then we
10052 // are dealing with the final value computed by the loop.
10053 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10054 // First, attempt to evaluate each operand.
10055 // Avoid performing the look-up in the common case where the specified
10056 // expression has no loop-variant portions.
10057 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10058 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10059 if (OpAtScope == AddRec->getOperand(i))
10060 continue;
10061
10062 // Okay, at least one of these operands is loop variant but might be
10063 // foldable. Build a new instance of the folded commutative expression.
10064       SmallVector<const SCEV *, 8> NewOps;
10065       NewOps.reserve(AddRec->getNumOperands());
10066 append_range(NewOps, AddRec->operands().take_front(i));
10067 NewOps.push_back(OpAtScope);
10068 for (++i; i != e; ++i)
10069 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10070
10071 const SCEV *FoldedRec = getAddRecExpr(
10072 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10073 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10074 // The addrec may be folded to a nonrecurrence, for example, if the
10075 // induction variable is multiplied by zero after constant folding. Go
10076 // ahead and return the folded value.
10077 if (!AddRec)
10078 return FoldedRec;
10079 break;
10080 }
10081
10082 // If the scope is outside the addrec's loop, evaluate it by using the
10083 // loop exit value of the addrec.
10084 if (!AddRec->getLoop()->contains(L)) {
10085 // To evaluate this recurrence, we need to know how many times the AddRec
10086 // loop iterates. Compute this now.
10087 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10088 if (BackedgeTakenCount == getCouldNotCompute())
10089 return AddRec;
10090
10091 // Then, evaluate the AddRec.
10092 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10093 }
10094
10095 return AddRec;
10096 }
10097 case scTruncate:
10098 case scZeroExtend:
10099 case scSignExtend:
10100 case scPtrToInt:
10101 case scAddExpr:
10102 case scMulExpr:
10103 case scUDivExpr:
10104 case scUMaxExpr:
10105 case scSMaxExpr:
10106 case scUMinExpr:
10107 case scSMinExpr:
10108 case scSequentialUMinExpr: {
10109 ArrayRef<const SCEV *> Ops = V->operands();
10110 // Avoid performing the look-up in the common case where the specified
10111 // expression has no loop-variant portions.
10112 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10113 const SCEV *OpAtScope = getSCEVAtScope(Ops[i], L);
10114 if (OpAtScope != Ops[i]) {
10115 // Okay, at least one of these operands is loop variant but might be
10116 // foldable. Build a new instance of the folded commutative expression.
10117         SmallVector<const SCEV *, 8> NewOps;
10118         NewOps.reserve(Ops.size());
10119 append_range(NewOps, Ops.take_front(i));
10120 NewOps.push_back(OpAtScope);
10121
10122 for (++i; i != e; ++i) {
10123 OpAtScope = getSCEVAtScope(Ops[i], L);
10124 NewOps.push_back(OpAtScope);
10125 }
10126
10127 return getWithOperands(V, NewOps);
10128 }
10129 }
10130 // If we got here, all operands are loop invariant.
10131 return V;
10132 }
10133 case scUnknown: {
10134 // If this instruction is evolved from a constant-evolving PHI, compute the
10135 // exit value from the loop without using SCEVs.
10136 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10137     Instruction *I = dyn_cast<Instruction>(SU->getValue());
10138     if (!I)
10139 return V; // This is some other type of SCEVUnknown, just return it.
10140
10141 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10142 const Loop *CurrLoop = this->LI[I->getParent()];
10143 // Looking for loop exit value.
10144 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10145 PN->getParent() == CurrLoop->getHeader()) {
10146 // Okay, there is no closed form solution for the PHI node. Check
10147 // to see if the loop that contains it has a known backedge-taken
10148 // count. If so, we may be able to force computation of the exit
10149 // value.
10150 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10151 // This trivial case can show up in some degenerate cases where
10152 // the incoming IR has not yet been fully simplified.
10153 if (BackedgeTakenCount->isZero()) {
10154 Value *InitValue = nullptr;
10155 bool MultipleInitValues = false;
10156 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10157 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10158 if (!InitValue)
10159 InitValue = PN->getIncomingValue(i);
10160 else if (InitValue != PN->getIncomingValue(i)) {
10161 MultipleInitValues = true;
10162 break;
10163 }
10164 }
10165 }
10166 if (!MultipleInitValues && InitValue)
10167 return getSCEV(InitValue);
10168 }
10169 // Do we have a loop invariant value flowing around the backedge
10170 // for a loop which must execute the backedge?
10171 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10172 isKnownNonZero(BackedgeTakenCount) &&
10173 PN->getNumIncomingValues() == 2) {
10174
10175 unsigned InLoopPred =
10176 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10177 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10178 if (CurrLoop->isLoopInvariant(BackedgeVal))
10179 return getSCEV(BackedgeVal);
10180 }
10181 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10182 // Okay, we know how many times the containing loop executes. If
10183 // this is a constant evolving PHI node, get the final value at
10184 // the specified iteration number.
10185 Constant *RV =
10186 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10187 if (RV)
10188 return getSCEV(RV);
10189 }
10190 }
10191 }
10192
10193 // Okay, this is an expression that we cannot symbolically evaluate
10194 // into a SCEV. Check to see if it's possible to symbolically evaluate
10195 // the arguments into constants, and if so, try to constant propagate the
10196 // result. This is particularly useful for computing loop exit values.
10197 if (!CanConstantFold(I))
10198 return V; // This is some other type of SCEVUnknown, just return it.
10199
10200 SmallVector<Constant *, 4> Operands;
10201 Operands.reserve(I->getNumOperands());
10202 bool MadeImprovement = false;
10203 for (Value *Op : I->operands()) {
10204 if (Constant *C = dyn_cast<Constant>(Op)) {
10205 Operands.push_back(C);
10206 continue;
10207 }
10208
10209 // If any of the operands is non-constant and if they are
10210 // non-integer and non-pointer, don't even try to analyze them
10211 // with scev techniques.
10212 if (!isSCEVable(Op->getType()))
10213 return V;
10214
10215 const SCEV *OrigV = getSCEV(Op);
10216 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10217 MadeImprovement |= OrigV != OpV;
10218
10219       Constant *C = BuildConstantFromSCEV(OpV);
10220       if (!C)
10221 return V;
10222 assert(C->getType() == Op->getType() && "Type mismatch");
10223 Operands.push_back(C);
10224 }
10225
10226 // Check to see if getSCEVAtScope actually made an improvement.
10227 if (!MadeImprovement)
10228 return V; // This is some other type of SCEVUnknown, just return it.
10229
10230 Constant *C = nullptr;
10231 const DataLayout &DL = getDataLayout();
10232 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10233 /*AllowNonDeterministic=*/false);
10234 if (!C)
10235 return V;
10236 return getSCEV(C);
10237 }
10238 case scCouldNotCompute:
10239 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10240 }
10241 llvm_unreachable("Unknown SCEV type!");
10242}
10243
10244 const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
10245   return getSCEVAtScope(getSCEV(V), L);
10246}
10247
10248const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10249   if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
10250     return stripInjectiveFunctions(ZExt->getOperand());
10251   if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
10252     return stripInjectiveFunctions(SExt->getOperand());
10253 return S;
10254}
10255
10256/// Finds the minimum unsigned root of the following equation:
10257///
10258/// A * X = B (mod N)
10259///
10260/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10261/// A and B isn't important.
10262///
10263/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
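///
/// For example, with BW = 4 (N = 16), A = 4 and B = 8: D = gcd(A, N) = 4,
/// A/D = 1 with multiplicative inverse I = 1, and the minimum root is
/// (I * B mod N) / D = 8 / 4 = 2; indeed 4 * 2 == 8 (mod 16).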
10264static const SCEV *
10265 SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
10266                              SmallVectorImpl<const SCEVPredicate *> *Predicates,
10267                              ScalarEvolution &SE, const Loop *L) {
10268 uint32_t BW = A.getBitWidth();
10269 assert(BW == SE.getTypeSizeInBits(B->getType()));
10270 assert(A != 0 && "A must be non-zero.");
10271
10272 // 1. D = gcd(A, N)
10273 //
10274 // The gcd of A and N may have only one prime factor: 2. The number of
10275 // trailing zeros in A is its multiplicity
10276 uint32_t Mult2 = A.countr_zero();
10277 // D = 2^Mult2
10278
10279 // 2. Check if B is divisible by D.
10280 //
10281 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10282 // is not less than multiplicity of this prime factor for D.
10283 unsigned MinTZ = SE.getMinTrailingZeros(B);
10284 // Try again with the terminator of the loop predecessor for context-specific
10285   // result, if MinTZ is too small.
10286 if (MinTZ < Mult2 && L->getLoopPredecessor())
10287 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10288 if (MinTZ < Mult2) {
10289 // Check if we can prove there's no remainder using URem.
10290 const SCEV *URem =
10291 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10292 const SCEV *Zero = SE.getZero(B->getType());
10293 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10294 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10295 if (!Predicates)
10296 return SE.getCouldNotCompute();
10297
10298 // Avoid adding a predicate that is known to be false.
10299 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10300 return SE.getCouldNotCompute();
10301 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10302 }
10303 }
10304
10305 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10306 // modulo (N / D).
10307 //
10308 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10309 // (N / D) in general. The inverse itself always fits into BW bits, though,
10310 // so we immediately truncate it.
10311 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10312 APInt I = AD.multiplicativeInverse().zext(BW);
10313
10314 // 4. Compute the minimum unsigned root of the equation:
10315 // I * (B / D) mod (N / D)
10316 // To simplify the computation, we factor out the divide by D:
10317 // (I * B mod N) / D
10318 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10319 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10320}
10321
10322/// For a given quadratic addrec, generate coefficients of the corresponding
10323/// quadratic equation, multiplied by a common value to ensure that they are
10324/// integers.
10325/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10326/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10327/// were multiplied by, and BitWidth is the bit width of the original addrec
10328/// coefficients.
10329/// This function returns std::nullopt if the addrec coefficients are not
10330 /// compile-time constants.
10331static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10332 GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
10333   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10334 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10335 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10336 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10337 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10338 << *AddRec << '\n');
10339
10340 // We currently can only solve this if the coefficients are constants.
10341 if (!LC || !MC || !NC) {
10342 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10343 return std::nullopt;
10344 }
10345
10346 APInt L = LC->getAPInt();
10347 APInt M = MC->getAPInt();
10348 APInt N = NC->getAPInt();
10349 assert(!N.isZero() && "This is not a quadratic addrec");
10350
10351 unsigned BitWidth = LC->getAPInt().getBitWidth();
10352 unsigned NewWidth = BitWidth + 1;
10353 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10354 << BitWidth << '\n');
10355 // The sign-extension (as opposed to a zero-extension) here matches the
10356 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10357 N = N.sext(NewWidth);
10358 M = M.sext(NewWidth);
10359 L = L.sext(NewWidth);
10360
10361 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10362 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10363 // L+M, L+2M+N, L+3M+3N, ...
10364 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10365 //
10366 // The equation Acc = 0 is then
10367 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10368 // In a quadratic form it becomes:
10369 // N n^2 + (2M-N) n + 2L = 0.
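  //
  // For example, for {0,+,1,+,2} (L=0, M=1, N=2) the accumulated value after
  // n iterations is n + n(n-1) = n^2, and the form above becomes
  // 2n^2 + 0*n + 0 = 0, i.e. twice the accumulated value.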
10370
10371 APInt A = N;
10372 APInt B = 2 * M - A;
10373 APInt C = 2 * L;
10374 APInt T = APInt(NewWidth, 2);
10375 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10376 << "x + " << C << ", coeff bw: " << NewWidth
10377 << ", multiplied by " << T << '\n');
10378 return std::make_tuple(A, B, C, T, BitWidth);
10379}
10380
10381/// Helper function to compare optional APInts:
10382/// (a) if X and Y both exist, return min(X, Y),
10383/// (b) if neither X nor Y exist, return std::nullopt,
10384/// (c) if exactly one of X and Y exists, return that value.
10385static std::optional<APInt> MinOptional(std::optional<APInt> X,
10386 std::optional<APInt> Y) {
10387 if (X && Y) {
10388 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10389 APInt XW = X->sext(W);
10390 APInt YW = Y->sext(W);
10391 return XW.slt(YW) ? *X : *Y;
10392 }
10393 if (!X && !Y)
10394 return std::nullopt;
10395 return X ? *X : *Y;
10396}
10397
10398/// Helper function to truncate an optional APInt to a given BitWidth.
10399/// When solving addrec-related equations, it is preferable to return a value
10400/// that has the same bit width as the original addrec's coefficients. If the
10401/// solution fits in the original bit width, truncate it (except for i1).
10402/// Returning a value of a different bit width may inhibit some optimizations.
10403///
10404/// In general, a solution to a quadratic equation generated from an addrec
10405/// may require BW+1 bits, where BW is the bit width of the addrec's
10406/// coefficients. The reason is that the coefficients of the quadratic
10407/// equation are BW+1 bits wide (to avoid truncation when converting from
10408/// the addrec to the equation).
10409static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10410 unsigned BitWidth) {
10411 if (!X)
10412 return std::nullopt;
10413 unsigned W = X->getBitWidth();
10414   if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
10415     return X->trunc(BitWidth);
10416 return X;
10417}
10418
10419/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10420/// iterations. The values L, M, N are assumed to be signed, and they
10421/// should all have the same bit widths.
10422/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10423/// where BW is the bit width of the addrec's coefficients.
10424/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10425/// returned as such, otherwise the bit width of the returned value may
10426/// be greater than BW.
10427///
10428/// This function returns std::nullopt if
10429/// (a) the addrec coefficients are not constant, or
10430/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10431/// like x^2 = 5, no integer solutions exist, in other cases an integer
10432/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
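///
/// For example, the chrec {-8,+,3,+,2} takes the values -8, -5, 0, ..., so
/// the least n with c(n) = 0 is n = 2.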
10433static std::optional<APInt>
10434 SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
10435   APInt A, B, C, M;
10436 unsigned BitWidth;
10437 auto T = GetQuadraticEquation(AddRec);
10438 if (!T)
10439 return std::nullopt;
10440
10441 std::tie(A, B, C, M, BitWidth) = *T;
10442 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10443 std::optional<APInt> X =
10444       APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth + 1);
10445   if (!X)
10446 return std::nullopt;
10447
10448 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10449 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10450 if (!V->isZero())
10451 return std::nullopt;
10452
10453 return TruncIfPossible(X, BitWidth);
10454}
10455
10456/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10457/// iterations. The values M, N are assumed to be signed, and they
10458/// should all have the same bit widths.
10459/// Find the least n such that c(n) does not belong to the given range,
10460/// while c(n-1) does.
10461///
10462/// This function returns std::nullopt if
10463/// (a) the addrec coefficients are not constant, or
10464/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10465/// bounds of the range.
10466static std::optional<APInt>
10467 SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
10468                           const ConstantRange &Range, ScalarEvolution &SE) {
10469 assert(AddRec->getOperand(0)->isZero() &&
10470 "Starting value of addrec should be 0");
10471 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10472 << Range << ", addrec " << *AddRec << '\n');
10473 // This case is handled in getNumIterationsInRange. Here we can assume that
10474 // we start in the range.
10475 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10476 "Addrec's initial value should be in range");
10477
10478 APInt A, B, C, M;
10479 unsigned BitWidth;
10480 auto T = GetQuadraticEquation(AddRec);
10481 if (!T)
10482 return std::nullopt;
10483
10484 // Be careful about the return value: there can be two reasons for not
10485 // returning an actual number. First, if no solutions to the equations
10486 // were found, and second, if the solutions don't leave the given range.
10487 // The first case means that the actual solution is "unknown", the second
10488 // means that it's known, but not valid. If the solution is unknown, we
10489 // cannot make any conclusions.
10490 // Return a pair: the optional solution and a flag indicating if the
10491 // solution was found.
10492 auto SolveForBoundary =
10493 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10494 // Solve for signed overflow and unsigned overflow, pick the lower
10495 // solution.
10496 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10497 << Bound << " (before multiplying by " << M << ")\n");
10498 Bound *= M; // The quadratic equation multiplier.
10499
10500 std::optional<APInt> SO;
10501 if (BitWidth > 1) {
10502 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10503 "signed overflow\n");
10504       SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
10505     }
10506 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10507 "unsigned overflow\n");
10508 std::optional<APInt> UO =
10509       APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth + 1);
10510
10511 auto LeavesRange = [&] (const APInt &X) {
10512 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10513 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10514 if (Range.contains(V0->getValue()))
10515 return false;
10516 // X should be at least 1, so X-1 is non-negative.
10517 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10518 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10519 if (Range.contains(V1->getValue()))
10520 return true;
10521 return false;
10522 };
10523
10524 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10525 // can be a solution, but the function failed to find it. We cannot treat it
10526 // as "no solution".
10527 if (!SO || !UO)
10528 return {std::nullopt, false};
10529
10530 // Check the smaller value first to see if it leaves the range.
10531 // At this point, both SO and UO must have values.
10532 std::optional<APInt> Min = MinOptional(SO, UO);
10533 if (LeavesRange(*Min))
10534 return { Min, true };
10535 std::optional<APInt> Max = Min == SO ? UO : SO;
10536 if (LeavesRange(*Max))
10537 return { Max, true };
10538
10539 // Solutions were found, but were eliminated, hence the "true".
10540 return {std::nullopt, true};
10541 };
10542
10543 std::tie(A, B, C, M, BitWidth) = *T;
10544 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10545 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10546 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10547 auto SL = SolveForBoundary(Lower);
10548 auto SU = SolveForBoundary(Upper);
10549   // If any of the solutions was unknown, no meaningful conclusions can
10550 // be made.
10551 if (!SL.second || !SU.second)
10552 return std::nullopt;
10553
10554 // Claim: The correct solution is not some value between Min and Max.
10555 //
10556 // Justification: Assuming that Min and Max are different values, one of
10557 // them is when the first signed overflow happens, the other is when the
10558 // first unsigned overflow happens. Crossing the range boundary is only
10559 // possible via an overflow (treating 0 as a special case of it, modeling
10560 // an overflow as crossing k*2^W for some k).
10561 //
10562 // The interesting case here is when Min was eliminated as an invalid
10563 // solution, but Max was not. The argument is that if there was another
10564 // overflow between Min and Max, it would also have been eliminated if
10565 // it was considered.
10566 //
10567 // For a given boundary, it is possible to have two overflows of the same
10568 // type (signed/unsigned) without having the other type in between: this
10569 // can happen when the vertex of the parabola is between the iterations
10570 // corresponding to the overflows. This is only possible when the two
10571 // overflows cross k*2^W for the same k. In such case, if the second one
10572 // left the range (and was the first one to do so), the first overflow
10573 // would have to enter the range, which would mean that either we had left
10574 // the range before or that we started outside of it. Both of these cases
10575 // are contradictions.
10576 //
10577 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10578 // solution is not some value between the Max for this boundary and the
10579 // Min of the other boundary.
10580 //
10581 // Justification: Assume that we had such Max_A and Min_B corresponding
10582 // to range boundaries A and B and such that Max_A < Min_B. If there was
10583 // a solution between Max_A and Min_B, it would have to be caused by an
10584 // overflow corresponding to either A or B. It cannot correspond to B,
10585 // since Min_B is the first occurrence of such an overflow. If it
10586 // corresponded to A, it would have to be either a signed or an unsigned
10587 // overflow that is larger than both eliminated overflows for A. But
10588 // between the eliminated overflows and this overflow, the values would
10589 // cover the entire value space, thus crossing the other boundary, which
10590 // is a contradiction.
10591
10592 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10593}
10594
10595ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10596 const Loop *L,
10597 bool ControlsOnlyExit,
10598 bool AllowPredicates) {
10599
10600 // This is only used for loops with a "x != y" exit test. The exit condition
10601 // is now expressed as a single expression, V = x-y. So the exit test is
10602   // effectively V != 0.  We know and take advantage of the fact that this
10603   // expression is only used in a comparison-with-zero context.
10604
10605   SmallVector<const SCEVPredicate *> Predicates;
10606   // If the value is a constant
10607 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10608 // If the value is already zero, the branch will execute zero times.
10609 if (C->getValue()->isZero()) return C;
10610 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10611 }
10612
10613 const SCEVAddRecExpr *AddRec =
10614 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10615
10616 if (!AddRec && AllowPredicates)
10617 // Try to make this an AddRec using runtime tests, in the first X
10618 // iterations of this loop, where X is the SCEV expression found by the
10619 // algorithm below.
10620 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10621
10622 if (!AddRec || AddRec->getLoop() != L)
10623 return getCouldNotCompute();
10624
10625 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10626 // the quadratic equation to solve it.
10627 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10628 // We can only use this value if the chrec ends up with an exact zero
10629 // value at this index. When solving for "X*X != 5", for example, we
10630 // should not accept a root of 2.
10631 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10632 const auto *R = cast<SCEVConstant>(getConstant(*S));
10633 return ExitLimit(R, R, R, false, Predicates);
10634 }
10635 return getCouldNotCompute();
10636 }
10637
10638 // Otherwise we can only handle this if it is affine.
10639 if (!AddRec->isAffine())
10640 return getCouldNotCompute();
10641
10642 // If this is an affine expression, the execution count of this branch is
10643 // the minimum unsigned root of the following equation:
10644 //
10645 // Start + Step*N = 0 (mod 2^BW)
10646 //
10647 // equivalent to:
10648 //
10649 // Step*N = -Start (mod 2^BW)
10650 //
10651 // where BW is the common bit width of Start and Step.
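  //
  // For example, with BW = 8, Start = 6 and Step = 2, the smallest N with
  // 6 + 2*N == 0 (mod 256) is N = 125.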
10652
10653 // Get the initial value for the loop.
10654 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10655 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10656
10657 if (!isLoopInvariant(Step, L))
10658 return getCouldNotCompute();
10659
10660 LoopGuards Guards = LoopGuards::collect(L, *this);
10661 // Specialize step for this loop so we get context sensitive facts below.
10662 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10663
10664 // For positive steps (counting up until unsigned overflow):
10665 // N = -Start/Step (as unsigned)
10666 // For negative steps (counting down to zero):
10667 // N = Start/-Step
10668 // First compute the unsigned distance from zero in the direction of Step.
10669 bool CountDown = isKnownNegative(StepWLG);
10670 if (!CountDown && !isKnownNonNegative(StepWLG))
10671 return getCouldNotCompute();
10672
10673 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10674 // Handle unitary steps, which cannot wraparound.
10675 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10676 // N = Distance (as unsigned)
10677
10678 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10679 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10680 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10681
10682 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10683 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10684 // case, and see if we can improve the bound.
10685 //
10686 // Explicitly handling this here is necessary because getUnsignedRange
10687 // isn't context-sensitive; it doesn't know that we only care about the
10688 // range inside the loop.
10689 const SCEV *Zero = getZero(Distance->getType());
10690 const SCEV *One = getOne(Distance->getType());
10691 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10692 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10693 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10694 // as "unsigned_max(Distance + 1) - 1".
10695 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10696 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10697 }
10698 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10699 Predicates);
10700 }
10701
10702 // If the condition controls loop exit (the loop exits only if the expression
10703 // is true) and the addition is no-wrap we can use unsigned divide to
10704 // compute the backedge count. In this case, the step may not divide the
10705 // distance, but we don't care because if the condition is "missed" the loop
10706 // will have undefined behavior due to wrapping.
10707 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10708 loopHasNoAbnormalExits(AddRec->getLoop())) {
10709
10710 // If the stride is zero and the start is non-zero, the loop must be
10711 // infinite. In C++, most loops are finite by assumption, in which case the
10712 // step being zero implies UB must execute if the loop is entered.
10713 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10714 !isKnownNonZero(StepWLG))
10715 return getCouldNotCompute();
10716
10717 const SCEV *Exact =
10718 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10719 const SCEV *ConstantMax = getCouldNotCompute();
10720 if (Exact != getCouldNotCompute()) {
10721 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10722 ConstantMax =
10723           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
10724     }
10725 const SCEV *SymbolicMax =
10726 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10727 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10728 }
10729
10730 // Solve the general equation.
10731 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10732 if (!StepC || StepC->getValue()->isZero())
10733 return getCouldNotCompute();
10734 const SCEV *E = SolveLinEquationWithOverflow(
10735 StepC->getAPInt(), getNegativeSCEV(Start),
10736 AllowPredicates ? &Predicates : nullptr, *this, L);
10737
10738 const SCEV *M = E;
10739 if (E != getCouldNotCompute()) {
10740 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10741 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10742 }
10743 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10744 return ExitLimit(E, M, S, false, Predicates);
10745}
10746
10747ScalarEvolution::ExitLimit
10748ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10749 // Loops that look like: while (X == 0) are very strange indeed. We don't
10750 // handle them yet except for the trivial case. This could be expanded in the
10751 // future as needed.
10752
10753 // If the value is a constant, check to see if it is known to be non-zero
10754 // already. If so, the backedge will execute zero times.
10755 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10756 if (!C->getValue()->isZero())
10757 return getZero(C->getType());
10758 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10759 }
10760
10761 // We could implement others, but I really doubt anyone writes loops like
10762 // this, and if they did, they would already be constant folded.
10763 return getCouldNotCompute();
10764}
10765
10766std::pair<const BasicBlock *, const BasicBlock *>
10767ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10768 const {
10769 // If the block has a unique predecessor, then there is no path from the
10770 // predecessor to the block that does not go through the direct edge
10771 // from the predecessor to the block.
10772 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10773 return {Pred, BB};
10774
10775 // A loop's header is defined to be a block that dominates the loop.
10776 // If the header has a unique predecessor outside the loop, it must be
10777 // a block that has exactly one successor that can reach the loop.
10778 if (const Loop *L = LI.getLoopFor(BB))
10779 return {L->getLoopPredecessor(), L->getHeader()};
10780
10781 return {nullptr, BB};
10782}
10783
10784/// SCEV structural equivalence is usually sufficient for testing whether two
10785/// expressions are equal, however for the purposes of looking for a condition
10786/// guarding a loop, it can be useful to be a little more general, since a
10787/// front-end may have replicated the controlling expression.
10788static bool HasSameValue(const SCEV *A, const SCEV *B) {
10789 // Quick check to see if they are the same SCEV.
10790 if (A == B) return true;
10791
10792 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10793 // Not all instructions that are "identical" compute the same value. For
10794 // instance, two distinct alloca instructions allocating the same type are
10795     // identical and do not read memory, but compute distinct values.
10796 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10797 };
10798
10799 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10800 // two different instructions with the same value. Check for this case.
10801 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10802 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10803 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
10804 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
10805 if (ComputesEqualValues(AI, BI))
10806 return true;
10807
10808 // Otherwise assume they may have a different value.
10809 return false;
10810}
10811
10812static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
10813 const SCEV *Op0, *Op1;
10814 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
10815 return false;
10816 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10817 LHS = Op1;
10818 return true;
10819 }
10820 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
10821 LHS = Op0;
10822 return true;
10823 }
10824 return false;
10825}
10826
10827 bool ScalarEvolution::SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS,
10828                                            const SCEV *&RHS, unsigned Depth) {
10829 bool Changed = false;
10830 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
10831 // '0 != 0'.
10832 auto TrivialCase = [&](bool TriviallyTrue) {
10833     LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
10834     Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
10835 return true;
10836 };
10837 // If we hit the max recursion limit bail out.
10838 if (Depth >= 3)
10839 return false;
10840
10841 const SCEV *NewLHS, *NewRHS;
10842 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
10843 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
10844 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
10845 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
10846
10847 // (X * vscale) pred (Y * vscale) ==> X pred Y
10848 // when both multiples are NSW.
10849 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
10850 // when both multiples are NUW.
10851 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
10852 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
10853 !ICmpInst::isSigned(Pred))) {
10854 LHS = NewLHS;
10855 RHS = NewRHS;
10856 Changed = true;
10857 }
10858 }
10859
10860 // Canonicalize a constant to the right side.
10861 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
10862 // Check for both operands constant.
10863 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
10864 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
10865 return TrivialCase(false);
10866 return TrivialCase(true);
10867 }
10868 // Otherwise swap the operands to put the constant on the right.
10869 std::swap(LHS, RHS);
10870     Pred = ICmpInst::getSwappedPredicate(Pred);
10871     Changed = true;
10872 }
10873
10874 // If we're comparing an addrec with a value which is loop-invariant in the
10875 // addrec's loop, put the addrec on the left. Also make a dominance check,
10876 // as both operands could be addrecs loop-invariant in each other's loop.
10877 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
10878 const Loop *L = AR->getLoop();
10879 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
10880 std::swap(LHS, RHS);
10881       Pred = ICmpInst::getSwappedPredicate(Pred);
10882       Changed = true;
10883 }
10884 }
10885
10886 // If there's a constant operand, canonicalize comparisons with boundary
10887 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
10888 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
10889 const APInt &RA = RC->getAPInt();
10890
10891 bool SimplifiedByConstantRange = false;
10892
10893 if (!ICmpInst::isEquality(Pred)) {
10894       ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
10895       if (ExactCR.isFullSet())
10896 return TrivialCase(true);
10897 if (ExactCR.isEmptySet())
10898 return TrivialCase(false);
10899
10900 APInt NewRHS;
10901 CmpInst::Predicate NewPred;
10902 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
10903 ICmpInst::isEquality(NewPred)) {
10904 // We were able to convert an inequality to an equality.
10905 Pred = NewPred;
10906 RHS = getConstant(NewRHS);
10907 Changed = SimplifiedByConstantRange = true;
10908 }
10909 }
10910
10911 if (!SimplifiedByConstantRange) {
10912 switch (Pred) {
10913 default:
10914 break;
10915 case ICmpInst::ICMP_EQ:
10916 case ICmpInst::ICMP_NE:
10917 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
10918 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
10919 Changed = true;
10920 break;
10921
10922 // The "Should have been caught earlier!" messages refer to the fact
10923 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
10924 // should have fired on the corresponding cases, and canonicalized the
10925 // check to trivial case.
10926
10927 case ICmpInst::ICMP_UGE:
10928 assert(!RA.isMinValue() && "Should have been caught earlier!");
10929 Pred = ICmpInst::ICMP_UGT;
10930 RHS = getConstant(RA - 1);
10931 Changed = true;
10932 break;
10933 case ICmpInst::ICMP_ULE:
10934 assert(!RA.isMaxValue() && "Should have been caught earlier!");
10935 Pred = ICmpInst::ICMP_ULT;
10936 RHS = getConstant(RA + 1);
10937 Changed = true;
10938 break;
10939 case ICmpInst::ICMP_SGE:
10940 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
10941 Pred = ICmpInst::ICMP_SGT;
10942 RHS = getConstant(RA - 1);
10943 Changed = true;
10944 break;
10945 case ICmpInst::ICMP_SLE:
10946 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
10947 Pred = ICmpInst::ICMP_SLT;
10948 RHS = getConstant(RA + 1);
10949 Changed = true;
10950 break;
10951 }
10952 }
10953 }
10954
10955 // Check for obvious equality.
10956 if (HasSameValue(LHS, RHS)) {
10957 if (ICmpInst::isTrueWhenEqual(Pred))
10958 return TrivialCase(true);
10959 if (ICmpInst::isFalseWhenEqual(Pred))
10960 return TrivialCase(false);
10961 }
10962
10963 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
10964 // adding or subtracting 1 from one of the operands.
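// For example, when the signed maximum of RHS is known to be below SINT_MAX,
// "LHS s<= RHS" becomes "LHS s< RHS + 1"; otherwise, when the signed minimum
// of LHS is above SINT_MIN, it becomes "LHS - 1 s< RHS".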
10965 switch (Pred) {
10966 case ICmpInst::ICMP_SLE:
10967 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
10968 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10969 SCEV::FlagNSW);
10970 Pred = ICmpInst::ICMP_SLT;
10971 Changed = true;
10972 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
10973 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
10974 SCEV::FlagNSW);
10975 Pred = ICmpInst::ICMP_SLT;
10976 Changed = true;
10977 }
10978 break;
10979 case ICmpInst::ICMP_SGE:
10980 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
10981 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
10982 SCEV::FlagNSW);
10983 Pred = ICmpInst::ICMP_SGT;
10984 Changed = true;
10985 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
10986 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
10987 SCEV::FlagNSW);
10988 Pred = ICmpInst::ICMP_SGT;
10989 Changed = true;
10990 }
10991 break;
10992 case ICmpInst::ICMP_ULE:
10993 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
10994 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
10995 SCEV::FlagNUW);
10996 Pred = ICmpInst::ICMP_ULT;
10997 Changed = true;
10998 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
10999 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11000 Pred = ICmpInst::ICMP_ULT;
11001 Changed = true;
11002 }
11003 break;
11004 case ICmpInst::ICMP_UGE:
11005 // If RHS is an op we can fold the -1, try that first.
11006 // Otherwise prefer LHS to preserve the nuw flag.
11007 if ((isa<SCEVConstant>(RHS) ||
11008 (isa<SCEVAddExpr, SCEVAddRecExpr>(RHS) &&
11009 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11010 !getUnsignedRangeMin(RHS).isMinValue()) {
11011 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11012 Pred = ICmpInst::ICMP_UGT;
11013 Changed = true;
11014 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11015 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11016 SCEV::FlagNUW);
11017 Pred = ICmpInst::ICMP_UGT;
11018 Changed = true;
11019 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11020 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11021 Pred = ICmpInst::ICMP_UGT;
11022 Changed = true;
11023 }
11024 break;
11025 default:
11026 break;
11027 }
11028
11029 // TODO: More simplifications are possible here.
11030
11031 // Recursively simplify until we either hit a recursion limit or nothing
11032 // changes.
11033 if (Changed)
11034 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11035
11036 return Changed;
11037}
11038
11039bool ScalarEvolution::isKnownNegative(const SCEV *S) {
11040 return getSignedRangeMax(S).isNegative();
11041}
11042
11043bool ScalarEvolution::isKnownPositive(const SCEV *S) {
11044 return getSignedRangeMin(S).isStrictlyPositive();
11045}
11046
11047bool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
11048 return !getSignedRangeMin(S).isNegative();
11049}
11050
11051bool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
11052 return !getSignedRangeMax(S).isStrictlyPositive();
11053}
11054
11055bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
11056 // Query push down for cases where the unsigned range is
11057 // less than sufficient.
11058 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11059 return isKnownNonZero(SExt->getOperand(0));
11060 return getUnsignedRangeMin(S) != 0;
11061}
11062
11063bool ScalarEvolution::isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero,
11064 bool OrNegative) {
11065 auto NonRecursive = [this, OrNegative](const SCEV *S) {
11066 if (auto *C = dyn_cast<SCEVConstant>(S))
11067 return C->getAPInt().isPowerOf2() ||
11068 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11069
11070 // The vscale_range indicates vscale is a power-of-two.
11071 return isa<SCEVVScale>(S) && F.hasFnAttribute(Attribute::VScaleRange);
11072 };
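// For example, with a vscale_range attribute both operands of (8 * vscale)
// satisfy NonRecursive, so the product is a power of two provided it is also
// known non-zero (or OrZero is set).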
11073
11074 if (NonRecursive(S))
11075 return true;
11076
11077 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11078 if (!Mul)
11079 return false;
11080 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11081}
11082
11083bool ScalarEvolution::isKnownMultipleOf(
11084 const SCEV *S, uint64_t M,
11085 SmallVectorImpl<const SCEVPredicate *> &Assumptions) {
11086 if (M == 0)
11087 return false;
11088 if (M == 1)
11089 return true;
11090
11091 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11092 // starts with a multiple of M and at every iteration step S only adds
11093 // multiples of M.
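// For example, {0,+,8} is a known multiple of 4: both its start (0) and its
// step (8) are multiples of 4, whereas {2,+,8} fails on its start value.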
11094 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11095 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11096 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11097
11098 // For a constant, check that "S % M == 0".
11099 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11100 APInt C = Cst->getAPInt();
11101 return C.urem(M) == 0;
11102 }
11103
11104 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11105
11106 // Basic tests have failed.
11107 // Check "S % M == 0" at compile time and record runtime Assumptions.
11108 auto *STy = dyn_cast<IntegerType>(S->getType());
11109 const SCEV *SmodM =
11110 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11111 const SCEV *Zero = getZero(STy);
11112
11113 // Check whether "S % M == 0" is known at compile time.
11114 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11115 return true;
11116
11117 // Check whether "S % M != 0" is known at compile time.
11118 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11119 return false;
11120
11121 const SCEVPredicate *P = getComparePredicate(ICmpInst::ICMP_EQ, SmodM, Zero);
11122
11123 // Detect redundant predicates.
11124 for (auto *A : Assumptions)
11125 if (A->implies(P, *this))
11126 return true;
11127
11128 // Only record non-redundant predicates.
11129 Assumptions.push_back(P);
11130 return true;
11131}
11132
11133bool ScalarEvolution::haveSameSign(const SCEV *S1, const SCEV *S2) {
11134 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11135 (isKnownNegative(S1) && isKnownNegative(S2)));
11136}
11137
11138std::pair<const SCEV *, const SCEV *>
11139ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
11140 // Compute SCEV on entry of loop L.
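// For example, for S = {%a,+,1}<L> the entry value is %a and the
// post-increment value computed below is {%a+1,+,1}<L>.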
11141 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11142 if (Start == getCouldNotCompute())
11143 return { Start, Start };
11144 // Compute post increment SCEV for loop L.
11145 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11146 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11147 return { Start, PostInc };
11148}
11149
11150bool ScalarEvolution::isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS,
11151 const SCEV *RHS) {
11152 // First collect all loops.
11153 SmallPtrSet<const Loop *, 8> LoopsUsed;
11154 getUsedLoops(LHS, LoopsUsed);
11155 getUsedLoops(RHS, LoopsUsed);
11156
11157 if (LoopsUsed.empty())
11158 return false;
11159
11160 // Domination relationship must be a linear order on collected loops.
11161#ifndef NDEBUG
11162 for (const auto *L1 : LoopsUsed)
11163 for (const auto *L2 : LoopsUsed)
11164 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11165 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11166 "Domination relationship is not a linear order");
11167#endif
11168
11169 const Loop *MDL =
11170 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11171 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11172 });
11173
11174 // Get init and post increment value for LHS.
11175 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11176 // If LHS contains an unknown non-invariant SCEV, bail out.
11177 if (SplitLHS.first == getCouldNotCompute())
11178 return false;
11179 assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11180 // Get init and post increment value for RHS.
11181 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11183 // If RHS contains an unknown non-invariant SCEV, bail out.
11183 if (SplitRHS.first == getCouldNotCompute())
11184 return false;
11185 assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11186 // It is possible that init SCEV contains an invariant load but it does
11187 // not dominate MDL and is not available at MDL loop entry, so we should
11188 // check it here.
11189 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11190 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11191 return false;
11192
11193 // The backedge guard check seems to be faster than the entry guard check,
11194 // so in some cases it can short-circuit and speed up the whole estimation.
11195 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11196 SplitRHS.second) &&
11197 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11198}
11199
11200bool ScalarEvolution::isKnownPredicate(CmpPredicate Pred, const SCEV *LHS,
11201 const SCEV *RHS) {
11202 // Canonicalize the inputs first.
11203 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11204
11205 if (isKnownViaInduction(Pred, LHS, RHS))
11206 return true;
11207
11208 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11209 return true;
11210
11211 // Otherwise see what can be done with some simple reasoning.
11212 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11213}
11214
11215std::optional<bool> ScalarEvolution::evaluatePredicate(CmpPredicate Pred,
11216 const SCEV *LHS,
11217 const SCEV *RHS) {
11218 if (isKnownPredicate(Pred, LHS, RHS))
11219 return true;
11220 if (isKnownPredicate(ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11221 return false;
11222 return std::nullopt;
11223}
11224
11225bool ScalarEvolution::isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS,
11226 const SCEV *RHS,
11227 const Instruction *CtxI) {
11228 // TODO: Analyze guards and assumes from Context's block.
11229 return isKnownPredicate(Pred, LHS, RHS) ||
11230 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11231}
11232
11233std::optional<bool>
11234ScalarEvolution::evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS,
11235 const SCEV *RHS, const Instruction *CtxI) {
11236 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11237 if (KnownWithoutContext)
11238 return KnownWithoutContext;
11239
11240 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11241 return true;
11242 if (isBasicBlockEntryGuardedByCond(
11243 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11244 return false;
11245 return std::nullopt;
11246}
11247
11248bool ScalarEvolution::isKnownOnEveryIteration(CmpPredicate Pred,
11249 const SCEVAddRecExpr *LHS,
11250 const SCEV *RHS) {
11251 const Loop *L = LHS->getLoop();
11252 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11253 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11254}
11255
11256std::optional<ScalarEvolution::MonotonicPredicateType>
11257ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
11258 ICmpInst::Predicate Pred) {
11259 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11260
11261#ifndef NDEBUG
11262 // Verify an invariant: inverting the predicate should turn a monotonically
11263 // increasing change to a monotonically decreasing one, and vice versa.
11264 if (Result) {
11265 auto ResultSwapped =
11266 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11267
11268 assert(*ResultSwapped != *Result &&
11269 "monotonicity should flip as we flip the predicate");
11270 }
11271#endif
11272
11273 return Result;
11274}
11275
11276std::optional<ScalarEvolution::MonotonicPredicateType>
11277ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11278 ICmpInst::Predicate Pred) {
11279 // A zero step value for LHS means the induction variable is essentially a
11280 // loop invariant value. We don't really depend on the predicate actually
11281 // flipping from false to true (for increasing predicates, and the other way
11282 // around for decreasing predicates), all we care about is that *if* the
11283 // predicate changes then it only changes from false to true.
11284 //
11285 // A zero step value in itself is not very useful, but there may be places
11286 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11287 // as general as possible.
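// For example, for LHS = {0,+,1}<nuw> the predicate "LHS u> N" is
// MonotonicallyIncreasing (it can only switch from false to true as the loop
// iterates), while "LHS u< N" is MonotonicallyDecreasing.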
11288
11289 // Only handle LE/LT/GE/GT predicates.
11290 if (!ICmpInst::isRelational(Pred))
11291 return std::nullopt;
11292
11293 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11294 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11295 "Should be greater or less!");
11296
11297 // Check that AR does not wrap.
11298 if (ICmpInst::isUnsigned(Pred)) {
11299 if (!LHS->hasNoUnsignedWrap())
11300 return std::nullopt;
11301 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11302 }
11303 assert(ICmpInst::isSigned(Pred) &&
11304 "Relational predicate is either signed or unsigned!");
11305 if (!LHS->hasNoSignedWrap())
11306 return std::nullopt;
11307
11308 const SCEV *Step = LHS->getStepRecurrence(*this);
11309
11310 if (isKnownNonNegative(Step))
11311 return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11312
11313 if (isKnownNonPositive(Step))
11314 return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
11315
11316 return std::nullopt;
11317}
11318
11319std::optional<ScalarEvolution::LoopInvariantPredicate>
11320ScalarEvolution::getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS,
11321 const SCEV *RHS, const Loop *L,
11322 const Instruction *CtxI) {
11323 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11324 if (!isLoopInvariant(RHS, L)) {
11325 if (!isLoopInvariant(LHS, L))
11326 return std::nullopt;
11327
11328 std::swap(LHS, RHS);
11329 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11330 }
11331
11332 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11333 if (!ArLHS || ArLHS->getLoop() != L)
11334 return std::nullopt;
11335
11336 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11337 if (!MonotonicType)
11338 return std::nullopt;
11339 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11340 // true as the loop iterates, and the backedge is control dependent on
11341 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11342 //
11343 // * if the predicate was false in the first iteration then the predicate
11344 // is never evaluated again, since the loop exits without taking the
11345 // backedge.
11346 // * if the predicate was true in the first iteration then it will
11347 // continue to be true for all future iterations since it is
11348 // monotonically increasing.
11349 //
11350 // For both the above possibilities, we can replace the loop varying
11351 // predicate with its value on the first iteration of the loop (which is
11352 // loop invariant).
11353 //
11354 // A similar reasoning applies for a monotonically decreasing predicate, by
11355 // replacing true with false and false with true in the above two bullets.
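// For example, if ArLHS = {%a,+,1}<nuw>, Pred is ICMP_UGE and the backedge is
// guarded by "ArLHS u>= %n", the loop-variant predicate can be replaced by the
// loop-invariant "%a u>= %n".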
11356 bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
11357 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11358
11359 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11360 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11361 RHS);
11362
11363 if (!CtxI)
11364 return std::nullopt;
11365 // Try to prove via context.
11366 // TODO: Support other cases.
11367 switch (Pred) {
11368 default:
11369 break;
11370 case ICmpInst::ICMP_ULE:
11371 case ICmpInst::ICMP_ULT: {
11372 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11373 // Given preconditions
11374 // (1) ArLHS does not cross the border of positive and negative parts of
11375 // range because of:
11376 // - Positive step; (TODO: lift this limitation)
11377 // - nuw - does not cross zero boundary;
11378 // - nsw - does not cross SINT_MAX boundary;
11379 // (2) ArLHS <s RHS
11380 // (3) RHS >=s 0
11381 // we can replace the loop variant ArLHS <u RHS condition with loop
11382 // invariant Start(ArLHS) <u RHS.
11383 //
11384 // Because of (1) there are two options:
11385 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11386 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11387 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11388 // Because of (2) ArLHS <u RHS is trivially true.
11389 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11390 // We can strengthen this to Start(ArLHS) <u RHS.
11391 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11392 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11393 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11394 isKnownNonNegative(RHS) &&
11395 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11396 return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(),
11397 RHS);
11398 }
11399 }
11400
11401 return std::nullopt;
11402}
11403
11404std::optional<ScalarEvolution::LoopInvariantPredicate>
11405ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
11406 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11407 const Instruction *CtxI, const SCEV *MaxIter) {
11408 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11409 Pred, LHS, RHS, L, CtxI, MaxIter))
11410 return LIP;
11411 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11412 // Number of iterations expressed as UMIN isn't always great for expressing
11413 // the value on the last iteration. If the straightforward approach didn't
11414 // work, try the following trick: if a predicate is invariant for X, it
11415 // is also invariant for umin(X, ...). So try to find something that works
11416 // among subexpressions of MaxIter expressed as umin.
11417 for (auto *Op : UMin->operands())
11418 if (auto LIP = getLoopInvariantExitCondDuringFirstIterationsImpl(
11419 Pred, LHS, RHS, L, CtxI, Op))
11420 return LIP;
11421 return std::nullopt;
11422}
11423
11424std::optional<ScalarEvolution::LoopInvariantPredicate>
11425ScalarEvolution::getLoopInvariantExitCondDuringFirstIterationsImpl(
11426 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11427 const Instruction *CtxI, const SCEV *MaxIter) {
11428 // Try to prove the following set of facts:
11429 // - The predicate is monotonic in the iteration space.
11430 // - If the check does not fail on the 1st iteration:
11431 // - No overflow will happen during first MaxIter iterations;
11432 // - It will not fail on the MaxIter'th iteration.
11433 // If the check does fail on the 1st iteration, we leave the loop and no
11434 // other checks matter.
11435
11436 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11437 if (!isLoopInvariant(RHS, L)) {
11438 if (!isLoopInvariant(LHS, L))
11439 return std::nullopt;
11440
11441 std::swap(LHS, RHS);
11442 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
11443 }
11444
11445 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11446 if (!AR || AR->getLoop() != L)
11447 return std::nullopt;
11448
11449 // The predicate must be relational (i.e. <, <=, >=, >).
11450 if (!ICmpInst::isRelational(Pred))
11451 return std::nullopt;
11452
11453 // TODO: Support steps other than +/- 1.
11454 const SCEV *Step = AR->getStepRecurrence(*this);
11455 auto *One = getOne(Step->getType());
11456 auto *MinusOne = getNegativeSCEV(One);
11457 if (Step != One && Step != MinusOne)
11458 return std::nullopt;
11459
11460 // Type mismatch here means that MaxIter is potentially larger than max
11461 // unsigned value in the start type, which means we cannot prove no wrap for the
11462 // indvar.
11463 if (AR->getType() != MaxIter->getType())
11464 return std::nullopt;
11465
11466 // Value of IV on suggested last iteration.
11467 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11468 // Does it still meet the requirement?
11469 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11470 return std::nullopt;
11471 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11472 // not exceed max unsigned value of this type), this effectively proves
11473 // that there is no wrap during the iteration. To prove that there is no
11474 // signed/unsigned wrap, we need to check that
11475 // Start <= Last for step = 1 or Start >= Last for step = -1.
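// For example, say AR = {%a,+,1}, MaxIter = %n (same type as %a) and
// Pred = ICMP_ULT with RHS = %len. Then Last = %a + %n; if the backedge is
// guarded by "%a + %n u< %len" and "%a u<= %a + %n" is known at CtxI, the
// loop-variant check "AR u< %len" can be replaced by the invariant
// "%a u< %len".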
11476 ICmpInst::Predicate NoOverflowPred =
11477 CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
11478 if (Step == MinusOne)
11479 NoOverflowPred = ICmpInst::getSwappedCmpPredicate(NoOverflowPred);
11480 const SCEV *Start = AR->getStart();
11481 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11482 return std::nullopt;
11483
11484 // Everything is fine.
11485 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11486}
11487
11488bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11489 const SCEV *LHS,
11490 const SCEV *RHS) {
11491 if (HasSameValue(LHS, RHS))
11492 return ICmpInst::isTrueWhenEqual(Pred);
11493
11494 auto CheckRange = [&](bool IsSigned) {
11495 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11496 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11497 return RangeLHS.icmp(Pred, RangeRHS);
11498 };
11499
11500 // The check at the top of the function catches the case where the values are
11501 // known to be equal.
11502 if (Pred == CmpInst::ICMP_EQ)
11503 return false;
11504
11505 if (Pred == CmpInst::ICMP_NE) {
11506 if (CheckRange(true) || CheckRange(false))
11507 return true;
11508 auto *Diff = getMinusSCEV(LHS, RHS);
11509 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11510 }
11511
11512 return CheckRange(CmpInst::isSigned(Pred));
11513}
11514
11515bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11516 const SCEV *LHS,
11517 const SCEV *RHS) {
11518 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11519 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11520 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11521 // OutC1 and OutC2.
11522 auto MatchBinaryAddToConst = [this](const SCEV *X, const SCEV *Y,
11523 APInt &OutC1, APInt &OutC2,
11524 SCEV::NoWrapFlags ExpectedFlags) {
11525 const SCEV *XNonConstOp, *XConstOp;
11526 const SCEV *YNonConstOp, *YConstOp;
11527 SCEV::NoWrapFlags XFlagsPresent;
11528 SCEV::NoWrapFlags YFlagsPresent;
11529
11530 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11531 XConstOp = getZero(X->getType());
11532 XNonConstOp = X;
11533 XFlagsPresent = ExpectedFlags;
11534 }
11535 if (!isa<SCEVConstant>(XConstOp))
11536 return false;
11537
11538 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11539 YConstOp = getZero(Y->getType());
11540 YNonConstOp = Y;
11541 YFlagsPresent = ExpectedFlags;
11542 }
11543
11544 if (YNonConstOp != XNonConstOp)
11545 return false;
11546
11547 if (!isa<SCEVConstant>(YConstOp))
11548 return false;
11549
11550 // When matching ADDs with NUW flags (and unsigned predicates), only the
11551 // second ADD (with the larger constant) requires NUW.
11552 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11553 return false;
11554 if (ExpectedFlags != SCEV::FlagNUW &&
11555 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11556 return false;
11557 }
11558
11559 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11560 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11561
11562 return true;
11563 };
11564
11565 APInt C1;
11566 APInt C2;
11567
11568 switch (Pred) {
11569 default:
11570 break;
11571
11572 case ICmpInst::ICMP_SGE:
11573 std::swap(LHS, RHS);
11574 [[fallthrough]];
11575 case ICmpInst::ICMP_SLE:
11576 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11577 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11578 return true;
11579
11580 break;
11581
11582 case ICmpInst::ICMP_SGT:
11583 std::swap(LHS, RHS);
11584 [[fallthrough]];
11585 case ICmpInst::ICMP_SLT:
11586 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11587 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11588 return true;
11589
11590 break;
11591
11592 case ICmpInst::ICMP_UGE:
11593 std::swap(LHS, RHS);
11594 [[fallthrough]];
11595 case ICmpInst::ICMP_ULE:
11596 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11597 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11598 return true;
11599
11600 break;
11601
11602 case ICmpInst::ICMP_UGT:
11603 std::swap(LHS, RHS);
11604 [[fallthrough]];
11605 case ICmpInst::ICMP_ULT:
11606 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
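// For example, "(%x + 1) u< (%x + 3)<nuw>" matches with C1 = 1 u< C2 = 3;
// only the add with the larger constant needs the nuw flag.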
11607 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11608 return true;
11609 break;
11610 }
11611
11612 return false;
11613}
11614
11615bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11616 const SCEV *LHS,
11617 const SCEV *RHS) {
11618 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11619 return false;
11620
11621 // Allowing an arbitrary number of activations of isKnownPredicateViaSplitting on
11622 // the stack can result in exponential time complexity.
11623 SaveAndRestore Restore(ProvingSplitPredicate, true);
11624
11625 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11626 //
11627 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11628 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11629 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11630 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11631 // use isKnownPredicate later if needed.
11632 return isKnownNonNegative(RHS) &&
11633 isKnownPredicate(CmpInst::ICMP_SGE, LHS, getZero(LHS->getType())) &&
11634 isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
11635}
11636
11637bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11638 const SCEV *LHS, const SCEV *RHS) {
11639 // No need to even try if we know the module has no guards.
11640 if (!HasGuards)
11641 return false;
11642
11643 return any_of(*BB, [&](const Instruction &I) {
11644 using namespace llvm::PatternMatch;
11645
11646 Value *Condition;
11647 return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
11648 m_Value(Condition))) &&
11649 isImpliedCond(Pred, LHS, RHS, Condition, false);
11650 });
11651}
11652
11653/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11654/// protected by a conditional between LHS and RHS. This is used to
11655/// eliminate casts.
11656bool ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
11657 CmpPredicate Pred,
11658 const SCEV *LHS,
11659 const SCEV *RHS) {
11660 // Interpret a null as meaning no loop, where there is obviously no guard
11661 // (interprocedural conditions notwithstanding). Do not bother about
11662 // unreachable loops.
11663 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11664 return true;
11665
11666 if (VerifyIR)
11667 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11668 "This cannot be done on broken IR!");
11669
11670
11671 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11672 return true;
11673
11674 BasicBlock *Latch = L->getLoopLatch();
11675 if (!Latch)
11676 return false;
11677
11678 BranchInst *LoopContinuePredicate =
11679 dyn_cast<BranchInst>(Latch->getTerminator());
11680 if (LoopContinuePredicate && LoopContinuePredicate->isConditional() &&
11681 isImpliedCond(Pred, LHS, RHS,
11682 LoopContinuePredicate->getCondition(),
11683 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11684 return true;
11685
11686 // We don't want more than one activation of the following loops on the stack
11687 // -- that can lead to O(n!) time complexity.
11688 if (WalkingBEDominatingConds)
11689 return false;
11690
11691 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11692
11693 // See if we can exploit a trip count to prove the predicate.
11694 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11695 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11696 if (LatchBECount != getCouldNotCompute()) {
11697 // We know that Latch branches back to the loop header exactly
11698 // LatchBECount times. This means the backedge condition at Latch is
11699 // equivalent to "{0,+,1} u< LatchBECount".
11700 Type *Ty = LatchBECount->getType();
11701 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11702 const SCEV *LoopCounter =
11703 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11704 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11705 LatchBECount))
11706 return true;
11707 }
11708
11709 // Check conditions due to any @llvm.assume intrinsics.
11710 for (auto &AssumeVH : AC.assumptions()) {
11711 if (!AssumeVH)
11712 continue;
11713 auto *CI = cast<CallInst>(AssumeVH);
11714 if (!DT.dominates(CI, Latch->getTerminator()))
11715 continue;
11716
11717 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11718 return true;
11719 }
11720
11721 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11722 return true;
11723
11724 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11725 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11726 assert(DTN && "should reach the loop header before reaching the root!");
11727
11728 BasicBlock *BB = DTN->getBlock();
11729 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11730 return true;
11731
11732 BasicBlock *PBB = BB->getSinglePredecessor();
11733 if (!PBB)
11734 continue;
11735
11736 BranchInst *ContinuePredicate = dyn_cast<BranchInst>(PBB->getTerminator());
11737 if (!ContinuePredicate || !ContinuePredicate->isConditional())
11738 continue;
11739
11740 Value *Condition = ContinuePredicate->getCondition();
11741
11742 // If we have an edge `E` within the loop body that dominates the only
11743 // latch, the condition guarding `E` also guards the backedge. This
11744 // reasoning works only for loops with a single latch.
11745
11746 BasicBlockEdge DominatingEdge(PBB, BB);
11747 if (DominatingEdge.isSingleEdge()) {
11748 // We're constructively (and conservatively) enumerating edges within the
11749 // loop body that dominate the latch. The dominator tree better agree
11750 // with us on this:
11751 assert(DT.dominates(DominatingEdge, Latch) && "should be!");
11752
11753 if (isImpliedCond(Pred, LHS, RHS, Condition,
11754 BB != ContinuePredicate->getSuccessor(0)))
11755 return true;
11756 }
11757 }
11758
11759 return false;
11760}
11761
11762bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
11763 CmpPredicate Pred,
11764 const SCEV *LHS,
11765 const SCEV *RHS) {
11766 // Do not bother proving facts for unreachable code.
11767 if (!DT.isReachableFromEntry(BB))
11768 return true;
11769 if (VerifyIR)
11770 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11771 "This cannot be done on broken IR!");
11772
11773 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11774 // the facts (a >= b && a != b) separately. A typical situation is when the
11775 // non-strict comparison is known from ranges and non-equality is known from
11776 // dominating predicates. If we are proving strict comparison, we always try
11777 // to prove non-equality and non-strict comparison separately.
11778 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11779 const bool ProvingStrictComparison =
11780 Pred != NonStrictPredicate.dropSameSign();
11781 bool ProvedNonStrictComparison = false;
11782 bool ProvedNonEquality = false;
11783
11784 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11785 if (!ProvedNonStrictComparison)
11786 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11787 if (!ProvedNonEquality)
11788 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11789 if (ProvedNonStrictComparison && ProvedNonEquality)
11790 return true;
11791 return false;
11792 };
11793
11794 if (ProvingStrictComparison) {
11795 auto ProofFn = [&](CmpPredicate P) {
11796 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11797 };
11798 if (SplitAndProve(ProofFn))
11799 return true;
11800 }
11801
11802 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11803 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11804 const Instruction *CtxI = &BB->front();
11805 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11806 return true;
11807 if (ProvingStrictComparison) {
11808 auto ProofFn = [&](CmpPredicate P) {
11809 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
11810 };
11811 if (SplitAndProve(ProofFn))
11812 return true;
11813 }
11814 return false;
11815 };
11816
11817 // Starting at the block's predecessor, climb up the predecessor chain, as long
11818 // as there are predecessors that can be found that have unique successors
11819 // leading to the original block.
11820 const Loop *ContainingLoop = LI.getLoopFor(BB);
11821 const BasicBlock *PredBB;
11822 if (ContainingLoop && ContainingLoop->getHeader() == BB)
11823 PredBB = ContainingLoop->getLoopPredecessor();
11824 else
11825 PredBB = BB->getSinglePredecessor();
11826 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
11827 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
11828 const BranchInst *BlockEntryPredicate =
11829 dyn_cast<BranchInst>(Pair.first->getTerminator());
11830 if (!BlockEntryPredicate || BlockEntryPredicate->isUnconditional())
11831 continue;
11832
11833 if (ProveViaCond(BlockEntryPredicate->getCondition(),
11834 BlockEntryPredicate->getSuccessor(0) != Pair.second))
11835 return true;
11836 }
11837
11838 // Check conditions due to any @llvm.assume intrinsics.
11839 for (auto &AssumeVH : AC.assumptions()) {
11840 if (!AssumeVH)
11841 continue;
11842 auto *CI = cast<CallInst>(AssumeVH);
11843 if (!DT.dominates(CI, BB))
11844 continue;
11845
11846 if (ProveViaCond(CI->getArgOperand(0), false))
11847 return true;
11848 }
11849
11850 // Check conditions due to any @llvm.experimental.guard intrinsics.
11851 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
11852 F.getParent(), Intrinsic::experimental_guard);
11853 if (GuardDecl)
11854 for (const auto *GU : GuardDecl->users())
11855 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
11856 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
11857 if (ProveViaCond(Guard->getArgOperand(0), false))
11858 return true;
11859 return false;
11860}
11861
11862bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred,
11863 const SCEV *LHS,
11864 const SCEV *RHS) {
11865 // Interpret a null as meaning no loop, where there is obviously no guard
11866 // (interprocedural conditions notwithstanding).
11867 if (!L)
11868 return false;
11869
11870 // Both LHS and RHS must be available at loop entry.
11872 "LHS is not available at Loop Entry");
11874 "RHS is not available at Loop Entry");
11875
11876 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11877 return true;
11878
11879 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
11880}
11881
11882bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11883 const SCEV *RHS,
11884 const Value *FoundCondValue, bool Inverse,
11885 const Instruction *CtxI) {
11886 // A false condition implies anything. Do not bother analyzing it further.
11887 if (FoundCondValue ==
11888 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
11889 return true;
11890
11891 if (!PendingLoopPredicates.insert(FoundCondValue).second)
11892 return false;
11893
11894 llvm::scope_exit ClearOnExit(
11895 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
11896
11897 // Recursively handle And and Or conditions.
11898 const Value *Op0, *Op1;
11899 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
11900 if (!Inverse)
11901 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11902 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11903 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
11904 if (Inverse)
11905 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
11906 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
11907 }
11908
11909 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
11910 if (!ICI) return false;
11911
11912 // Now that we have found a conditional branch that dominates the loop or
11913 // controls the loop latch, check whether it is the comparison we are looking for.
11914 CmpPredicate FoundPred;
11915 if (Inverse)
11916 FoundPred = ICI->getInverseCmpPredicate();
11917 else
11918 FoundPred = ICI->getCmpPredicate();
11919
11920 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
11921 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
11922
11923 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
11924}
11925
11926bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
11927 const SCEV *RHS, CmpPredicate FoundPred,
11928 const SCEV *FoundLHS, const SCEV *FoundRHS,
11929 const Instruction *CtxI) {
11930 // Balance the types.
11931 if (getTypeSizeInBits(LHS->getType()) <
11932 getTypeSizeInBits(FoundLHS->getType())) {
11933 // For unsigned and equality predicates, try to prove that both found
11934 // operands fit into narrow unsigned range. If so, try to prove facts in
11935 // narrow types.
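// For example, if LHS and RHS are i32 while FoundLHS and FoundRHS are i64
// values known to be u<= UINT32_MAX, the found condition can be truncated
// to i32 and the implication checked in the narrow type.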
11936 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
11937 !FoundRHS->getType()->isPointerTy()) {
11938 auto *NarrowType = LHS->getType();
11939 auto *WideType = FoundLHS->getType();
11940 auto BitWidth = getTypeSizeInBits(NarrowType);
11941 const SCEV *MaxValue = getZeroExtendExpr(
11942 getConstant(APInt::getMaxValue(BitWidth)), WideType);
11943 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
11944 MaxValue) &&
11945 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
11946 MaxValue)) {
11947 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
11948 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
11949 // We cannot preserve samesign after truncation.
11950 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
11951 TruncFoundLHS, TruncFoundRHS, CtxI))
11952 return true;
11953 }
11954 }
11955
11956 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
11957 return false;
11958 if (CmpInst::isSigned(Pred)) {
11959 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
11960 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
11961 } else {
11962 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
11963 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
11964 }
11965 } else if (getTypeSizeInBits(LHS->getType()) >
11966 getTypeSizeInBits(FoundLHS->getType())) {
11967 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
11968 return false;
11969 if (CmpInst::isSigned(FoundPred)) {
11970 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
11971 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
11972 } else {
11973 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
11974 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
11975 }
11976 }
11977 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
11978 FoundRHS, CtxI);
11979}
11980
11981bool ScalarEvolution::isImpliedCondBalancedTypes(
11982 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
11983 const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *CtxI) {
11984 assert(getTypeSizeInBits(LHS->getType()) ==
11985 getTypeSizeInBits(FoundLHS->getType()) &&
11986 "Types should be balanced!");
11987 // Canonicalize the query to match the way instcombine will have
11988 // canonicalized the comparison.
11989 if (SimplifyICmpOperands(Pred, LHS, RHS))
11990 if (LHS == RHS)
11991 return CmpInst::isTrueWhenEqual(Pred);
11992 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
11993 if (FoundLHS == FoundRHS)
11994 return CmpInst::isFalseWhenEqual(FoundPred);
11995
11996 // Check to see if we can make the LHS or RHS match.
11997 if (LHS == FoundRHS || RHS == FoundLHS) {
11998 if (isa<SCEVConstant>(RHS)) {
11999 std::swap(FoundLHS, FoundRHS);
12000 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12001 } else {
12002 std::swap(LHS, RHS);
12003 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12004 }
12005 }
12006
12007 // Check whether the found predicate is the same as the desired predicate.
12008 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12009 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12010
12011 // Check whether swapping the found predicate makes it the same as the
12012 // desired predicate.
12013 if (auto P = CmpPredicate::getMatching(
12014 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12015 // We can write the implication
12016 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12017 // using one of the following ways:
12018 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12019 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12020 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12021 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12022 // Forms 1. and 2. require swapping the operands of one condition. Don't
12023 // do this if it would break canonical constant/addrec ordering.
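// Forms 3. and 4. rely on ~X = -1 - X reversing both the signed and the
// unsigned order; e.g. "FoundLHS u> FoundRHS" holds iff
// "~FoundLHS u< ~FoundRHS" does.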
12024 if (!isa<SCEVConstant>(RHS) && !isa<SCEVAddRecExpr>(LHS))
12025 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12026 LHS, FoundLHS, FoundRHS, CtxI);
12027 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12028 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12029
12030 // There's no clear preference between forms 3. and 4., try both. Avoid
12031 // forming getNotSCEV of pointer values as the resulting subtract is
12032 // not legal.
12033 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12034 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12035 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12036 FoundRHS, CtxI))
12037 return true;
12038
12039 if (!FoundLHS->getType()->isPointerTy() &&
12040 !FoundRHS->getType()->isPointerTy() &&
12041 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12042 getNotSCEV(FoundRHS), CtxI))
12043 return true;
12044
12045 return false;
12046 }
12047
12048 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12049 CmpInst::Predicate P2) {
12050 assert(P1 != P2 && "Handled earlier!");
12052 return CmpInst::isRelational(P2) &&
 ICmpInst::getFlippedSignednessPredicate(P1) == P2;
12053 };
12054 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12055 // Unsigned comparison is the same as signed comparison when both the
12056 // operands are non-negative or negative.
12057 if (haveSameSign(FoundLHS, FoundRHS))
12058 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12059 // Create local copies that we can freely swap and canonicalize our
12060 // conditions to "le/lt".
12061 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12062 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12063 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12064 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12065 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12066 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12067 std::swap(CanonicalLHS, CanonicalRHS);
12068 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12069 }
12070 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12071 "Must be!");
12072 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12073 ICmpInst::isLE(CanonicalFoundPred)) &&
12074 "Must be!");
12075 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12076 // Use implication:
12077 // x <u y && y >=s 0 --> x <s y.
12078 // If we can prove the left part, the right part is also proven.
12079 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12080 CanonicalRHS, CanonicalFoundLHS,
12081 CanonicalFoundRHS);
12082 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12083 // Use implication:
12084 // x <s y && y <s 0 --> x <u y.
12085 // If we can prove the left part, the right part is also proven.
12086 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12087 CanonicalRHS, CanonicalFoundLHS,
12088 CanonicalFoundRHS);
12089 }
12090
12091 // Check if we can make progress by sharpening ranges.
12092 if (FoundPred == ICmpInst::ICMP_NE &&
12093 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12094
12095 const SCEVConstant *C = nullptr;
12096 const SCEV *V = nullptr;
12097
12098 if (isa<SCEVConstant>(FoundLHS)) {
12099 C = cast<SCEVConstant>(FoundLHS);
12100 V = FoundRHS;
12101 } else {
12102 C = cast<SCEVConstant>(FoundRHS);
12103 V = FoundLHS;
12104 }
12105
12106 // The guarding predicate tells us that C != V. If the known range
12107 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12108 // range we consider has to correspond to same signedness as the
12109 // predicate we're interested in folding.
12110
12111 APInt Min = ICmpInst::isSigned(Pred) ?
12112 getSignedRangeMin(V) : getUnsignedRangeMin(V);
12113
12114 if (Min == C->getAPInt()) {
12115 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12116 // This is true even if (Min + 1) wraps around -- in case of
12117 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12118
12119 APInt SharperMin = Min + 1;
12120
12121 switch (Pred) {
12122 case ICmpInst::ICMP_SGE:
12123 case ICmpInst::ICMP_UGE:
12124 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12125 // RHS, we're done.
12126 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12127 CtxI))
12128 return true;
12129 [[fallthrough]];
12130
12131 case ICmpInst::ICMP_SGT:
12132 case ICmpInst::ICMP_UGT:
12133 // We know from the range information that (V `Pred` Min ||
12134 // V == Min). We know from the guarding condition that !(V
12135 // == Min). This gives us
12136 //
12137 // V `Pred` Min || V == Min && !(V == Min)
12138 // => V `Pred` Min
12139 //
12140 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12141
12142 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12143 return true;
12144 break;
12145
12146 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12147 case ICmpInst::ICMP_SLE:
12148 case ICmpInst::ICMP_ULE:
12149 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12150 LHS, V, getConstant(SharperMin), CtxI))
12151 return true;
12152 [[fallthrough]];
12153
12154 case ICmpInst::ICMP_SLT:
12155 case ICmpInst::ICMP_ULT:
12156 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12157 LHS, V, getConstant(Min), CtxI))
12158 return true;
12159 break;
12160
12161 default:
12162 // No change
12163 break;
12164 }
12165 }
12166 }
12167
12168 // Check whether the actual condition is beyond sufficient.
12169 if (FoundPred == ICmpInst::ICMP_EQ)
12170 if (ICmpInst::isTrueWhenEqual(Pred))
12171 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12172 return true;
12173 if (Pred == ICmpInst::ICMP_NE)
12174 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12175 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12176 return true;
12177
12178 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12179 return true;
12180
12181 // Otherwise assume the worst.
12182 return false;
12183}
12184
12185bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,
12186 const SCEV *&L, const SCEV *&R,
12187 SCEV::NoWrapFlags &Flags) {
12188 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12189 return false;
12190
12191 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12192 return true;
12193}
12194
12195std::optional<APInt>
12196ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
12197 // We avoid subtracting expressions here because this function is usually
12198 // fairly deep in the call stack (i.e. is called many times).
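// For example, the difference of (%x + 7) and (%x + 3) is 4, and the same
// holds for {%x+7,+,2} and {%x+3,+,2}: identical steps reduce both addrecs
// to their start values.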
12199
12200 unsigned BW = getTypeSizeInBits(More->getType());
12201 APInt Diff(BW, 0);
12202 APInt DiffMul(BW, 1);
12203 // Try various simplifications to reduce the difference to a constant. Limit
12204 // the number of allowed simplifications to keep compile-time low.
12205 for (unsigned I = 0; I < 8; ++I) {
12206 if (More == Less)
12207 return Diff;
12208
12209 // Reduce addrecs with identical steps to their start value.
12210 if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
12211 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12212 const auto *MAR = cast<SCEVAddRecExpr>(More);
12213
12214 if (LAR->getLoop() != MAR->getLoop())
12215 return std::nullopt;
12216
12217 // We look at affine expressions only; not for correctness but to keep
12218 // getStepRecurrence cheap.
12219 if (!LAR->isAffine() || !MAR->isAffine())
12220 return std::nullopt;
12221
12222 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12223 return std::nullopt;
12224
12225 Less = LAR->getStart();
12226 More = MAR->getStart();
12227 continue;
12228 }
12229
12230 // Try to match a common constant multiply.
12231 auto MatchConstMul =
12232 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12233 const APInt *C;
12234 const SCEV *Op;
12235 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12236 return {{Op, *C}};
12237 return std::nullopt;
12238 };
12239 if (auto MatchedMore = MatchConstMul(More)) {
12240 if (auto MatchedLess = MatchConstMul(Less)) {
12241 if (MatchedMore->second == MatchedLess->second) {
12242 More = MatchedMore->first;
12243 Less = MatchedLess->first;
12244 DiffMul *= MatchedMore->second;
12245 continue;
12246 }
12247 }
12248 }
12249
12250 // Try to cancel out common factors in two add expressions.
12251 SmallDenseMap<const SCEV *, int, 8> Multiplicity;
12252 auto Add = [&](const SCEV *S, int Mul) {
12253 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12254 if (Mul == 1) {
12255 Diff += C->getAPInt() * DiffMul;
12256 } else {
12257 assert(Mul == -1);
12258 Diff -= C->getAPInt() * DiffMul;
12259 }
12260 } else
12261 Multiplicity[S] += Mul;
12262 };
12263 auto Decompose = [&](const SCEV *S, int Mul) {
12264 if (isa<SCEVAddExpr>(S)) {
12265 for (const SCEV *Op : S->operands())
12266 Add(Op, Mul);
12267 } else
12268 Add(S, Mul);
12269 };
12270 Decompose(More, 1);
12271 Decompose(Less, -1);
12272
12273 // Check whether all the non-constants cancel out, or reduce to new
12274 // More/Less values.
12275 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12276 for (const auto &[S, Mul] : Multiplicity) {
12277 if (Mul == 0)
12278 continue;
12279 if (Mul == 1) {
12280 if (NewMore)
12281 return std::nullopt;
12282 NewMore = S;
12283 } else if (Mul == -1) {
12284 if (NewLess)
12285 return std::nullopt;
12286 NewLess = S;
12287 } else
12288 return std::nullopt;
12289 }
12290
12291 // Values stayed the same, no point in trying further.
12292 if (NewMore == More || NewLess == Less)
12293 return std::nullopt;
12294
12295 More = NewMore;
12296 Less = NewLess;
12297
12298 // Reduced to constant.
12299 if (!More && !Less)
12300 return Diff;
12301
12302 // Left with variable on only one side, bail out.
12303 if (!More || !Less)
12304 return std::nullopt;
12305 }
12306
12307 // Did not reduce to constant.
12308 return std::nullopt;
12309}
12310
12311bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12312 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12313 const SCEV *FoundRHS, const Instruction *CtxI) {
12314 // Try to recognize the following pattern:
12315 //
12316 // FoundRHS = ...
12317 // ...
12318 // loop:
12319 // FoundLHS = {Start,+,W}
12320 // context_bb: // Basic block from the same loop
12321 // known(Pred, FoundLHS, FoundRHS)
12322 //
12323 // If some predicate is known in the context of a loop, it is also known on
12324 // each iteration of this loop, including the first iteration. Therefore, in
12325 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12326 // prove the original pred using this fact.
12327 if (!CtxI)
12328 return false;
12329 const BasicBlock *ContextBB = CtxI->getParent();
12330 // Make sure AR varies in the context block.
12331 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12332 const Loop *L = AR->getLoop();
12333 // Make sure that context belongs to the loop and executes on 1st iteration
12334 // (if it ever executes at all).
12335 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12336 return false;
12337 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12338 return false;
12339 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12340 }
12341
12342 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12343 const Loop *L = AR->getLoop();
12344 // Make sure that context belongs to the loop and executes on 1st iteration
12345 // (if it ever executes at all).
12346 if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
12347 return false;
12348 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12349 return false;
12350 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12351 }
12352
12353 return false;
12354}
12355
12356bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12357 const SCEV *LHS,
12358 const SCEV *RHS,
12359 const SCEV *FoundLHS,
12360 const SCEV *FoundRHS) {
12361 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12362 return false;
12363
12364 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12365 if (!AddRecLHS)
12366 return false;
12367
12368 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12369 if (!AddRecFoundLHS)
12370 return false;
12371
12372 // We'd like to let SCEV reason about control dependencies, so we constrain
12373 // both the inequalities to be about add recurrences on the same loop. This
12374 // way we can use isLoopEntryGuardedByCond later.
12375
12376 const Loop *L = AddRecFoundLHS->getLoop();
12377 if (L != AddRecLHS->getLoop())
12378 return false;
12379
12380 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12381 //
12382 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12383 // ... (2)
12384 //
12385 // Informal proof for (2), assuming (1) [*]:
12386 //
12387 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12388 //
12389 // Then
12390 //
12391 // FoundLHS s< FoundRHS s< INT_MIN - C
12392 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12393 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12394 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12395 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12396 // <=> FoundLHS + C s< FoundRHS + C
12397 //
12398 // [*]: (1) can be proved by ruling out overflow.
12399 //
12400 // [**]: This can be proved by analyzing all the four possibilities:
12401 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12402 // (A s>= 0, B s>= 0).
12403 //
12404 // Note:
12405 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12406 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12407 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12408 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12409 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12410 // C)".
12411
12412 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12413 if (!LDiff)
12414 return false;
12415 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12416 if (!RDiff || *LDiff != *RDiff)
12417 return false;
12418
12419 if (LDiff->isMinValue())
12420 return true;
12421
12422 APInt FoundRHSLimit;
12423
12424 if (Pred == CmpInst::ICMP_ULT) {
12425 FoundRHSLimit = -(*RDiff);
12426 } else {
12427 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12428 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12429 }
12430
12431 // Try to prove (1) or (2), as needed.
12432 return isAvailableAtLoopEntry(FoundRHS, L) &&
12433 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12434 getConstant(FoundRHSLimit));
12435}
12436
12437bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12438 const SCEV *RHS, const SCEV *FoundLHS,
12439 const SCEV *FoundRHS, unsigned Depth) {
12440 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12441
12442 llvm::scope_exit ClearOnExit([&]() {
12443 if (LPhi) {
12444 bool Erased = PendingMerges.erase(LPhi);
12445 assert(Erased && "Failed to erase LPhi!");
12446 (void)Erased;
12447 }
12448 if (RPhi) {
12449 bool Erased = PendingMerges.erase(RPhi);
12450 assert(Erased && "Failed to erase RPhi!");
12451 (void)Erased;
12452 }
12453 });
12454
12455 // Find the respective Phis and check that they are not already pending.
12456 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12457 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12458 if (!PendingMerges.insert(Phi).second)
12459 return false;
12460 LPhi = Phi;
12461 }
12462 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12463 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12464 // If we detect a loop of Phi nodes being processed by this method, for
12465 // example:
12466 //
12467 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12468 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12469 //
12470 // we don't want to deal with a case that complex, so return the conservative
12471 // answer false.
12472 if (!PendingMerges.insert(Phi).second)
12473 return false;
12474 RPhi = Phi;
12475 }
12476
12477 // If none of LHS, RHS is a Phi, nothing to do here.
12478 if (!LPhi && !RPhi)
12479 return false;
12480
12481 // If there is a SCEVUnknown Phi we are interested in, make it left.
12482 if (!LPhi) {
12483 std::swap(LHS, RHS);
12484 std::swap(FoundLHS, FoundRHS);
12485 std::swap(LPhi, RPhi);
12486 Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12487 }
12488
12489 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12490 const BasicBlock *LBB = LPhi->getParent();
12491 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12492
12493 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12494 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12495 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12496 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12497 };
12498
12499 if (RPhi && RPhi->getParent() == LBB) {
12500 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12501 // If we compare two Phis from the same block, and for each entry block
12502 // the predicate is true for incoming values from this block, then the
12503 // predicate is also true for the Phis.
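// For example:
//   %a = phi i32 [ 0, %bb1 ], [ 1, %bb2 ]
//   %b = phi i32 [ 5, %bb1 ], [ 9, %bb2 ]
// "%a u< %b" holds because 0 u< 5 (via %bb1) and 1 u< 9 (via %bb2).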
12504 for (const BasicBlock *IncBB : predecessors(LBB)) {
12505 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12506 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12507 if (!ProvedEasily(L, R))
12508 return false;
12509 }
12510 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12511 // Case two: RHS is also a Phi from the same basic block, and it is an
12512 // AddRec. It means that there is a loop which has both AddRec and Unknown
12513 // PHIs, for it we can compare incoming values of AddRec from above the loop
12514 // and latch with their respective incoming values of LPhi.
12515 // TODO: Generalize to handle loops with many inputs in a header.
12516 if (LPhi->getNumIncomingValues() != 2) return false;
12517
12518 auto *RLoop = RAR->getLoop();
12519 auto *Predecessor = RLoop->getLoopPredecessor();
12520 assert(Predecessor && "Loop with AddRec with no predecessor?");
12521 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12522 if (!ProvedEasily(L1, RAR->getStart()))
12523 return false;
12524 auto *Latch = RLoop->getLoopLatch();
12525 assert(Latch && "Loop with AddRec with no latch?");
12526 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12527 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12528 return false;
12529 } else {
12530 // In all other cases go over inputs of LHS and compare each of them to RHS,
12531 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12532 // At this point RHS is either a non-Phi, or it is a Phi from some block
12533 // different from LBB.
12534 for (const BasicBlock *IncBB : predecessors(LBB)) {
12535 // Check that RHS is available in this block.
12536 if (!dominates(RHS, IncBB))
12537 return false;
12538 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12539 // Make sure L does not refer to a value from a potentially previous
12540 // iteration of a loop.
12541 if (!properlyDominates(L, LBB))
12542 return false;
12543 // Addrecs are considered to properly dominate their loop, so are missed
12544 // by the previous check. Discard any values that have computable
12545 // evolution in this loop.
12546 if (auto *Loop = LI.getLoopFor(LBB))
12547 if (hasComputableLoopEvolution(L, Loop))
12548 return false;
12549 if (!ProvedEasily(L, RHS))
12550 return false;
12551 }
12552 }
12553 return true;
12554}
12555
12556bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12557 const SCEV *LHS,
12558 const SCEV *RHS,
12559 const SCEV *FoundLHS,
12560 const SCEV *FoundRHS) {
12561 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12562 // sure that we are dealing with same LHS.
12563 if (RHS == FoundRHS) {
12564 std::swap(LHS, RHS);
12565 std::swap(FoundLHS, FoundRHS);
12566    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12567  }
12568 if (LHS != FoundLHS)
12569 return false;
12570
12571 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12572 if (!SUFoundRHS)
12573 return false;
12574
12575 Value *Shiftee, *ShiftValue;
12576
12577 using namespace PatternMatch;
12578 if (match(SUFoundRHS->getValue(),
12579 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12580 auto *ShifteeS = getSCEV(Shiftee);
12581 // Prove one of the following:
12582 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12583 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12584 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12585 // ---> LHS <s RHS
12586 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12587 // ---> LHS <=s RHS
12588 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12589 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12590 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12591 if (isKnownNonNegative(ShifteeS))
12592 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12593 }
12594
12595 return false;
12596}
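// Worked illustration of the rules above (example values, not taken from the
// surrounding code): from "LHS <u (X >> 3)" together with a proof that
// "X <=u RHS" we may conclude "LHS <u RHS", because (X >> 3) <=u X always
// holds for a logical shift; the signed variants additionally require
// X >=s 0 so that the shiftee cannot be negative.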
12597
12598bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12599 const SCEV *RHS,
12600 const SCEV *FoundLHS,
12601 const SCEV *FoundRHS,
12602 const Instruction *CtxI) {
12603 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12604 FoundRHS) ||
12605 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12606 FoundRHS) ||
12607 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12608 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12609 CtxI) ||
12610 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12611}
12612
12613/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12614template <typename MinMaxExprType>
12615static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12616 const SCEV *Candidate) {
12617 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12618 if (!MinMaxExpr)
12619 return false;
12620
12621 return is_contained(MinMaxExpr->operands(), Candidate);
12622}
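// For instance (illustrative only), IsMinMaxConsistingOf<SCEVSMaxExpr> returns
// true when MaybeMinMaxExpr is smax(A, B, ...) and Candidate is one of A, B, ...,
// and false for any expression that is not an smax.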
12623
12624static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
12625                                           CmpPredicate Pred, const SCEV *LHS,
12626 const SCEV *RHS) {
12627 // If both sides are affine addrecs for the same loop, with equal
12628 // steps, and we know the recurrences don't wrap, then we only
12629 // need to check the predicate on the starting values.
12630
12631 if (!ICmpInst::isRelational(Pred))
12632 return false;
12633
12634 const SCEV *LStart, *RStart, *Step;
12635 const Loop *L;
12636 if (!match(LHS,
12637 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12638      !match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step),
12639                                      m_SpecificLoop(L))))
12640 return false;
12645 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12646 return false;
12647
12648 return SE.isKnownPredicate(Pred, LStart, RStart);
12649}
12650
12651/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12652/// expression?
12653static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred,
12654                                        const SCEV *LHS, const SCEV *RHS) {
12655 switch (Pred) {
12656 default:
12657 return false;
12658
12659 case ICmpInst::ICMP_SGE:
12660 std::swap(LHS, RHS);
12661 [[fallthrough]];
12662 case ICmpInst::ICMP_SLE:
12663 return
12664 // min(A, ...) <= A
12665        IsMinMaxConsistingOf<SCEVSMinExpr>(LHS, RHS) ||
12666        // A <= max(A, ...)
12667        IsMinMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
12668
12669 case ICmpInst::ICMP_UGE:
12670 std::swap(LHS, RHS);
12671 [[fallthrough]];
12672 case ICmpInst::ICMP_ULE:
12673 return
12674 // min(A, ...) <= A
12675 // FIXME: what about umin_seq?
12676        IsMinMaxConsistingOf<SCEVUMinExpr>(LHS, RHS) ||
12677        // A <= max(A, ...)
12678        IsMinMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
12679  }
12680
12681 llvm_unreachable("covered switch fell through?!");
12682}
12683
12684bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12685 const SCEV *RHS,
12686 const SCEV *FoundLHS,
12687 const SCEV *FoundRHS,
12688 unsigned Depth) {
12689  assert(getTypeSizeInBits(LHS->getType()) ==
12690             getTypeSizeInBits(RHS->getType()) &&
12691         "LHS and RHS have different sizes?");
12692 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12693 getTypeSizeInBits(FoundRHS->getType()) &&
12694 "FoundLHS and FoundRHS have different sizes?");
12695 // We want to avoid hurting the compile time with analysis of too big trees.
12696  if (Depth > MaxSCEVOperationsImplicationDepth)
12697    return false;
12698
12699 // We only want to work with GT comparison so far.
12700 if (ICmpInst::isLT(Pred)) {
12701    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
12702    std::swap(LHS, RHS);
12703 std::swap(FoundLHS, FoundRHS);
12704 }
12705
12707
12708 // For unsigned, try to reduce it to corresponding signed comparison.
12709 if (P == ICmpInst::ICMP_UGT)
12710 // We can replace unsigned predicate with its signed counterpart if all
12711 // involved values are non-negative.
12712 // TODO: We could have better support for unsigned.
12713 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12714 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12715 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12716 // use this fact to prove that LHS and RHS are non-negative.
12717 const SCEV *MinusOne = getMinusOne(LHS->getType());
12718 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12719 FoundRHS) &&
12720 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12721 FoundRHS))
12722        return true;
12723    }
12724
12725 if (P != ICmpInst::ICMP_SGT)
12726 return false;
12727
12728 auto GetOpFromSExt = [&](const SCEV *S) {
12729 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12730 return Ext->getOperand();
12731 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12732 // the constant in some cases.
12733 return S;
12734 };
12735
12736 // Acquire values from extensions.
12737 auto *OrigLHS = LHS;
12738 auto *OrigFoundLHS = FoundLHS;
12739 LHS = GetOpFromSExt(LHS);
12740 FoundLHS = GetOpFromSExt(FoundLHS);
12741
12742  // Check whether the SGT predicate can be proved trivially or using the found context.
12743 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12744 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12745 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12746 FoundRHS, Depth + 1);
12747 };
12748
12749 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12750 // We want to avoid creation of any new non-constant SCEV. Since we are
12751 // going to compare the operands to RHS, we should be certain that we don't
12752 // need any size extensions for this. So let's decline all cases when the
12753 // sizes of types of LHS and RHS do not match.
12754 // TODO: Maybe try to get RHS from sext to catch more cases?
12755    if (getTypeSizeInBits(LHS->getType()) != getTypeSizeInBits(RHS->getType()))
12756      return false;
12757
12758 // Should not overflow.
12759 if (!LHSAddExpr->hasNoSignedWrap())
12760 return false;
12761
12762 auto *LL = LHSAddExpr->getOperand(0);
12763 auto *LR = LHSAddExpr->getOperand(1);
12764 auto *MinusOne = getMinusOne(RHS->getType());
12765
12766 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12767 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12768 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12769 };
12770 // Try to prove the following rule:
12771 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12772 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12773 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12774 return true;
12775 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12776 Value *LL, *LR;
12777 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12778
12779 using namespace llvm::PatternMatch;
12780
12781 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12782 // Rules for division.
12783 // We are going to perform some comparisons with Denominator and its
12784      // derivative expressions. In the general case, creating a SCEV for it may
12785      // lead to a complex analysis of the entire graph, and in particular it
12786      // can request trip count recalculation for the same loop. This would be
12787      // cached as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12788 // this, we only want to create SCEVs that are constants in this section.
12789 // So we bail if Denominator is not a constant.
12790 if (!isa<ConstantInt>(LR))
12791 return false;
12792
12793 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12794
12795 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12796 // then a SCEV for the numerator already exists and matches with FoundLHS.
12797 auto *Numerator = getExistingSCEV(LL);
12798 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12799 return false;
12800
12801 // Make sure that the numerator matches with FoundLHS and the denominator
12802 // is positive.
12803 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12804 return false;
12805
12806 auto *DTy = Denominator->getType();
12807 auto *FRHSTy = FoundRHS->getType();
12808 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
12809 // One of types is a pointer and another one is not. We cannot extend
12810 // them properly to a wider type, so let us just reject this case.
12811 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
12812 // to avoid this check.
12813 return false;
12814
12815 // Given that:
12816 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
12817 auto *WTy = getWiderType(DTy, FRHSTy);
12818 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
12819 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
12820
12821 // Try to prove the following rule:
12822 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
12823 // For example, given that FoundLHS > 2. It means that FoundLHS is at
12824 // least 3. If we divide it by Denominator < 4, we will have at least 1.
12825 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
12826 if (isKnownNonPositive(RHS) &&
12827 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
12828 return true;
12829
12830 // Try to prove the following rule:
12831 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
12832 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
12833 // If we divide it by Denominator > 2, then:
12834 // 1. If FoundLHS is negative, then the result is 0.
12835 // 2. If FoundLHS is non-negative, then the result is non-negative.
12836      // Either way, the result is non-negative.
12837 auto *MinusOne = getMinusOne(WTy);
12838 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
12839 if (isKnownNegative(RHS) &&
12840 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
12841 return true;
12842 }
12843 }
12844
12845 // If our expression contained SCEVUnknown Phis, and we split it down and now
12846 // need to prove something for them, try to prove the predicate for every
12847  // possible incoming value of those Phis.
12848 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
12849 return true;
12850
12851 return false;
12852}
12853
12854bool ScalarEvolution::isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
12855                                                  const SCEV *RHS) {
12856 // zext x u<= sext x, sext x s<= zext x
12857 const SCEV *Op;
12858 switch (Pred) {
12859 case ICmpInst::ICMP_SGE:
12860 std::swap(LHS, RHS);
12861 [[fallthrough]];
12862 case ICmpInst::ICMP_SLE: {
12863 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
12864 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12865           match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
12866  }
12867 case ICmpInst::ICMP_UGE:
12868 std::swap(LHS, RHS);
12869 [[fallthrough]];
12870 case ICmpInst::ICMP_ULE: {
12871 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
12872 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12873           match(RHS, m_scev_SExt(m_scev_Specific(Op)));
12874  }
12875 default:
12876 return false;
12877 };
12878 llvm_unreachable("unhandled case");
12879}
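// Illustrative check with a hypothetical i8 value x = -1 (0xFF): sext to i16
// yields -1 (0xFFFF) while zext yields 255 (0x00FF), so both "sext x s<= zext x"
// (-1 <=s 255) and "zext x u<= sext x" (255 <=u 65535) hold, which is the
// idiom matched above.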
12880
12881bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
12882 const SCEV *LHS,
12883 const SCEV *RHS) {
12884 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
12885 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
12886 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
12887 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
12888 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
12889}
12890
12891bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
12892 const SCEV *LHS,
12893 const SCEV *RHS,
12894 const SCEV *FoundLHS,
12895 const SCEV *FoundRHS) {
12896 switch (Pred) {
12897 default:
12898 llvm_unreachable("Unexpected CmpPredicate value!");
12899 case ICmpInst::ICMP_EQ:
12900 case ICmpInst::ICMP_NE:
12901 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
12902 return true;
12903 break;
12904 case ICmpInst::ICMP_SLT:
12905 case ICmpInst::ICMP_SLE:
12906 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
12907 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
12908 return true;
12909 break;
12910 case ICmpInst::ICMP_SGT:
12911 case ICmpInst::ICMP_SGE:
12912 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
12913 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
12914 return true;
12915 break;
12916 case ICmpInst::ICMP_ULT:
12917 case ICmpInst::ICMP_ULE:
12918 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
12919 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
12920 return true;
12921 break;
12922 case ICmpInst::ICMP_UGT:
12923 case ICmpInst::ICMP_UGE:
12924 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
12925 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
12926 return true;
12927 break;
12928 }
12929
12930 // Maybe it can be proved via operations?
12931 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
12932 return true;
12933
12934 return false;
12935}
12936
12937bool ScalarEvolution::isImpliedCondOperandsViaRanges(
12938 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
12939 const SCEV *FoundLHS, const SCEV *FoundRHS) {
12940 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
12941    // The restriction on `FoundRHS` can be lifted easily -- it exists only to
12942 // reduce the compile time impact of this optimization.
12943 return false;
12944
12945 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
12946 if (!Addend)
12947 return false;
12948
12949 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
12950
12951 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
12952 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
12953 ConstantRange FoundLHSRange =
12954 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
12955
12956 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
12957 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
12958
12959 // We can also compute the range of values for `LHS` that satisfy the
12960 // consequent, "`LHS` `Pred` `RHS`":
12961 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
12962 // The antecedent implies the consequent if every value of `LHS` that
12963 // satisfies the antecedent also satisfies the consequent.
12964 return LHSRange.icmp(Pred, ConstRHS);
12965}
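// Worked example (values are illustrative, not from the code above): if the
// antecedent is "FoundLHS u< 10" and computeConstantDifference(LHS, FoundLHS)
// returns 5, then FoundLHSRange = [0, 10) and LHSRange = [5, 15). A consequent
// such as "LHS u< 20" is implied, since every value in [5, 15) satisfies it,
// whereas "LHS u< 12" would not be.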
12966
12967bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
12968 bool IsSigned) {
12969 assert(isKnownPositive(Stride) && "Positive stride expected!");
12970
12971 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12972 const SCEV *One = getOne(Stride->getType());
12973
12974 if (IsSigned) {
12975 APInt MaxRHS = getSignedRangeMax(RHS);
12976 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
12977 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
12978
12979 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
12980 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
12981 }
12982
12983 APInt MaxRHS = getUnsignedRangeMax(RHS);
12984 APInt MaxValue = APInt::getMaxValue(BitWidth);
12985 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
12986
12987 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
12988 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
12989}
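// Illustrative 8-bit unsigned instance (example numbers only): if RHS can be
// as large as 250 and the stride as large as 10, then
// UMaxValue - UMaxStrideMinusOne = 255 - 9 = 246 < 250, so the IV could step
// past RHS and wrap, and the function reports possible overflow.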
12990
12991bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
12992 bool IsSigned) {
12993
12994 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
12995 const SCEV *One = getOne(Stride->getType());
12996
12997 if (IsSigned) {
12998 APInt MinRHS = getSignedRangeMin(RHS);
12999 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13000 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13001
13002 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13003 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13004 }
13005
13006 APInt MinRHS = getUnsignedRangeMin(RHS);
13007 APInt MinValue = APInt::getMinValue(BitWidth);
13008 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13009
13010 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13011 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13012}
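// Mirror illustration for the decreasing case (example numbers only): with an
// 8-bit unsigned IV, if RHS can be as small as 5 and the stride as large as 10,
// then UMinValue + UMaxStrideMinusOne = 0 + 9 = 9 > 5, so stepping down past
// RHS could wrap below zero, and the function reports possible overflow.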
13013
13014const SCEV *ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) {
13015  // umin(N, 1) + floor((N - umin(N, 1)) / D)
13016 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13017 // expression fixes the case of N=0.
13018 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13019 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13020 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13021}
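// Sanity check of the formula with sample values (illustrative, not part of
// the original source): for N = 10, D = 4 it gives umin(10,1) + (10-1)/4 =
// 1 + 2 = 3, i.e. ceil(10/4); for N = 0 it gives 0 + 0/D = 0, which is why
// the umin term is used instead of the simpler "1 + floor((N-1)/D)".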
13022
13023const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13024 const SCEV *Stride,
13025 const SCEV *End,
13026 unsigned BitWidth,
13027 bool IsSigned) {
13028 // The logic in this function assumes we can represent a positive stride.
13029 // If we can't, the backedge-taken count must be zero.
13030 if (IsSigned && BitWidth == 1)
13031 return getZero(Stride->getType());
13032
13033  // The code below has only been closely audited for negative strides in the
13034  // unsigned comparison case; it may be correct for signed comparison, but
13035 // that needs to be established.
13036 if (IsSigned && isKnownNegative(Stride))
13037 return getCouldNotCompute();
13038
13039 // Calculate the maximum backedge count based on the range of values
13040 // permitted by Start, End, and Stride.
13041 APInt MinStart =
13042 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13043
13044 APInt MinStride =
13045 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13046
13047 // We assume either the stride is positive, or the backedge-taken count
13048 // is zero. So force StrideForMaxBECount to be at least one.
13049 APInt One(BitWidth, 1);
13050 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13051 : APIntOps::umax(One, MinStride);
13052
13053 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13054 : APInt::getMaxValue(BitWidth);
13055 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13056
13057 // Although End can be a MAX expression we estimate MaxEnd considering only
13058 // the case End = RHS of the loop termination condition. This is safe because
13059 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13060 // taken count.
13061 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13062 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13063
13064 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13065 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13066 : APIntOps::umax(MaxEnd, MinStart);
13067
13068 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13069 getConstant(StrideForMaxBECount) /* Step */);
13070}
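// Example instantiation (hypothetical i8 ranges, unsigned comparison):
// MinStart = 2, StrideForMaxBECount = 3, Limit = 255 - 2 = 253; if the
// unsigned range of End tops out at 100, then MaxEnd = umin(100, 253) = 100
// and the result is ceil((100 - 2) / 3) = 33 as a constant bound on the
// backedge-taken count.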
13071
13072ScalarEvolution::ExitLimit
13073ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13074 const Loop *L, bool IsSigned,
13075 bool ControlsOnlyExit, bool AllowPredicates) {
13076  SmallVector<const SCEVPredicate *> Predicates;
13077
13078  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13079 bool PredicatedIV = false;
13080 if (!IV) {
13081 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13082 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13083 if (AR && AR->getLoop() == L && AR->isAffine()) {
13084 auto canProveNUW = [&]() {
13085 // We can use the comparison to infer no-wrap flags only if it fully
13086 // controls the loop exit.
13087 if (!ControlsOnlyExit)
13088 return false;
13089
13090 if (!isLoopInvariant(RHS, L))
13091 return false;
13092
13093 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13094 // We need the sequence defined by AR to strictly increase in the
13095 // unsigned integer domain for the logic below to hold.
13096 return false;
13097
13098 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13099 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13100 // If RHS <=u Limit, then there must exist a value V in the sequence
13101 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13102 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13103 // overflow occurs. This limit also implies that a signed comparison
13104 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13105 // the high bits on both sides must be zero.
13106 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13107 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13108 Limit = Limit.zext(OuterBitWidth);
13109 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13110 };
13111 auto Flags = AR->getNoWrapFlags();
13112 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13113 Flags = setFlags(Flags, SCEV::FlagNUW);
13114
13115 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13116 if (AR->hasNoUnsignedWrap()) {
13117 // Emulate what getZeroExtendExpr would have done during construction
13118 // if we'd been able to infer the fact just above at that time.
13119 const SCEV *Step = AR->getStepRecurrence(*this);
13120 Type *Ty = ZExt->getType();
13121 auto *S = getAddRecExpr(
13122                getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, 0),
13123                getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13124            IV = dyn_cast<SCEVAddRecExpr>(S);
13125          }
13126 }
13127 }
13128 }
13129
13130
13131 if (!IV && AllowPredicates) {
13132 // Try to make this an AddRec using runtime tests, in the first X
13133 // iterations of this loop, where X is the SCEV expression found by the
13134 // algorithm below.
13135 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13136 PredicatedIV = true;
13137 }
13138
13139 // Avoid weird loops
13140 if (!IV || IV->getLoop() != L || !IV->isAffine())
13141 return getCouldNotCompute();
13142
13143 // A precondition of this method is that the condition being analyzed
13144 // reaches an exiting branch which dominates the latch. Given that, we can
13145 // assume that an increment which violates the nowrap specification and
13146 // produces poison must cause undefined behavior when the resulting poison
13147 // value is branched upon and thus we can conclude that the backedge is
13148 // taken no more often than would be required to produce that poison value.
13149 // Note that a well defined loop can exit on the iteration which violates
13150 // the nowrap specification if there is another exit (either explicit or
13151 // implicit/exceptional) which causes the loop to execute before the
13152 // exiting instruction we're analyzing would trigger UB.
13153 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13154 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13155  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
13156
13157 const SCEV *Stride = IV->getStepRecurrence(*this);
13158
13159 bool PositiveStride = isKnownPositive(Stride);
13160
13161 // Avoid negative or zero stride values.
13162 if (!PositiveStride) {
13163 // We can compute the correct backedge taken count for loops with unknown
13164 // strides if we can prove that the loop is not an infinite loop with side
13165 // effects. Here's the loop structure we are trying to handle -
13166 //
13167 // i = start
13168 // do {
13169 // A[i] = i;
13170 // i += s;
13171 // } while (i < end);
13172 //
13173 // The backedge taken count for such loops is evaluated as -
13174 // (max(end, start + stride) - start - 1) /u stride
13175 //
13176 // The additional preconditions that we need to check to prove correctness
13177  // of the above formula are as follows -
13178 //
13179 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13180 // NoWrap flag).
13181 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13182 // no side effects within the loop)
13183 // c) loop has a single static exit (with no abnormal exits)
13184 //
13185 // Precondition a) implies that if the stride is negative, this is a single
13186 // trip loop. The backedge taken count formula reduces to zero in this case.
13187 //
13188 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13189 // then a zero stride means the backedge can't be taken without executing
13190 // undefined behavior.
13191 //
13192 // The positive stride case is the same as isKnownPositive(Stride) returning
13193 // true (original behavior of the function).
13194 //
13195 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13196        !loopHasNoAbnormalExits(L))
13197      return getCouldNotCompute();
13198
13199 if (!isKnownNonZero(Stride)) {
13200 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13201 // if it might eventually be greater than start and if so, on which
13202 // iteration. We can't even produce a useful upper bound.
13203 if (!isLoopInvariant(RHS, L))
13204 return getCouldNotCompute();
13205
13206 // We allow a potentially zero stride, but we need to divide by stride
13207 // below. Since the loop can't be infinite and this check must control
13208 // the sole exit, we can infer the exit must be taken on the first
13209 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13210 // we know the numerator in the divides below must be zero, so we can
13211 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13212 // and produce the right result.
13213 // FIXME: Handle the case where Stride is poison?
13214 auto wouldZeroStrideBeUB = [&]() {
13215 // Proof by contradiction. Suppose the stride were zero. If we can
13216 // prove that the backedge *is* taken on the first iteration, then since
13217 // we know this condition controls the sole exit, we must have an
13218 // infinite loop. We can't have a (well defined) infinite loop per
13219 // check just above.
13220 // Note: The (Start - Stride) term is used to get the start' term from
13221 // (start' + stride,+,stride). Remember that we only care about the
13222 // result of this expression when stride == 0 at runtime.
13223 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13224 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13225 };
13226 if (!wouldZeroStrideBeUB()) {
13227 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13228 }
13229 }
13230 } else if (!NoWrap) {
13231 // Avoid proven overflow cases: this will ensure that the backedge taken
13232 // count will not generate any unsigned overflow.
13233 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13234 return getCouldNotCompute();
13235 }
13236
13237  // On all paths just preceding, we established the following invariant:
13238 // IV can be assumed not to overflow up to and including the exiting
13239 // iteration. We proved this in one of two ways:
13240 // 1) We can show overflow doesn't occur before the exiting iteration
13241 // 1a) canIVOverflowOnLT, and b) step of one
13242 // 2) We can show that if overflow occurs, the loop must execute UB
13243 // before any possible exit.
13244 // Note that we have not yet proved RHS invariant (in general).
13245
13246 const SCEV *Start = IV->getStart();
13247
13248 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13249 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13250 // Use integer-typed versions for actual computation; we can't subtract
13251 // pointers in general.
13252 const SCEV *OrigStart = Start;
13253 const SCEV *OrigRHS = RHS;
13254 if (Start->getType()->isPointerTy()) {
13255    Start = getLosslessPtrToIntExpr(Start);
13256    if (isa<SCEVCouldNotCompute>(Start))
13257 return Start;
13258 }
13259 if (RHS->getType()->isPointerTy()) {
13260    RHS = getLosslessPtrToIntExpr(RHS);
13261    if (isa<SCEVCouldNotCompute>(RHS))
13262      return RHS;
13263 }
13264
13265 const SCEV *End = nullptr, *BECount = nullptr,
13266 *BECountIfBackedgeTaken = nullptr;
13267 if (!isLoopInvariant(RHS, L)) {
13268 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13269 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13270 RHSAddRec->getNoWrapFlags()) {
13271 // The structure of loop we are trying to calculate backedge count of:
13272 //
13273 // left = left_start
13274 // right = right_start
13275 //
13276 // while(left < right){
13277 // ... do something here ...
13278 // left += s1; // stride of left is s1 (s1 > 0)
13279 // right += s2; // stride of right is s2 (s2 < 0)
13280 // }
13281 //
13282
13283 const SCEV *RHSStart = RHSAddRec->getStart();
13284 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13285
13286 // If Stride - RHSStride is positive and does not overflow, we can write
13287 // backedge count as ->
13288 // ceil((End - Start) /u (Stride - RHSStride))
13289 // Where, End = max(RHSStart, Start)
13290
13291 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13292 if (isKnownNegative(RHSStride) &&
13293 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13294 RHSStride)) {
13295
13296 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13297 if (isKnownPositive(Denominator)) {
13298 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13299 : getUMaxExpr(RHSStart, Start);
13300
13301 // We can do this because End >= Start, as End = max(RHSStart, Start)
13302 const SCEV *Delta = getMinusSCEV(End, Start);
13303
13304 BECount = getUDivCeilSCEV(Delta, Denominator);
13305 BECountIfBackedgeTaken =
13306 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13307 }
13308 }
13309 }
13310 if (BECount == nullptr) {
13311 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13312 // given the start, stride and max value for the end bound of the
13313 // loop (RHS), and the fact that IV does not overflow (which is
13314 // checked above).
13315 const SCEV *MaxBECount = computeMaxBECountForLT(
13316 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13317 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13318 MaxBECount, false /*MaxOrZero*/, Predicates);
13319 }
13320 } else {
13321 // We use the expression (max(End,Start)-Start)/Stride to describe the
13322 // backedge count, as if the backedge is taken at least once
13323 // max(End,Start) is End and so the result is as above, and if not
13324 // max(End,Start) is Start so we get a backedge count of zero.
13325 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13326 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13327 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13328 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13329    // Can we prove max(RHS,Start) > Start - Stride?
13330 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13331 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13332 // In this case, we can use a refined formula for computing backedge
13333 // taken count. The general formula remains:
13334 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13335 // We want to use the alternate formula:
13336 // "((End - 1) - (Start - Stride)) /u Stride"
13337 // Let's do a quick case analysis to show these are equivalent under
13338 // our precondition that max(RHS,Start) > Start - Stride.
13339 // * For RHS <= Start, the backedge-taken count must be zero.
13340 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13341    //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13342    //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13343    //   of Stride.  For 0 stride, we've used umax(Stride, 1) above,
13344 // reducing this to the stride of 1 case.
13345 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13346 // Stride".
13347 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13348 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13349 // "((RHS - (Start - Stride) - 1) /u Stride".
13350 // Our preconditions trivially imply no overflow in that form.
13351 const SCEV *MinusOne = getMinusOne(Stride->getType());
13352 const SCEV *Numerator =
13353 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13354 BECount = getUDivExpr(Numerator, Stride);
13355 }
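    // Numeric spot check of the two equivalent forms above (illustrative
    // values only): Start = 5, Stride = 2, RHS = 10 gives
    // ceil((10 - 5) / 2) = 3 and ((10 - 1) - (5 - 2)) /u 2 = 6 /u 2 = 3, so
    // both formulations agree whenever the guard max(RHS,Start) > Start - Stride
    // holds.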
13356
13357 if (!BECount) {
13358 auto canProveRHSGreaterThanEqualStart = [&]() {
13359 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13360 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13361 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13362
13363 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13364 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13365 return true;
13366
13367 // (RHS > Start - 1) implies RHS >= Start.
13368 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13369 // "Start - 1" doesn't overflow.
13370 // * For signed comparison, if Start - 1 does overflow, it's equal
13371 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13372 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13373 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13374 //
13375 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13376 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13377 auto *StartMinusOne =
13378 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13379 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13380 };
13381
13382 // If we know that RHS >= Start in the context of loop, then we know
13383 // that max(RHS, Start) = RHS at this point.
13384 if (canProveRHSGreaterThanEqualStart()) {
13385 End = RHS;
13386 } else {
13387 // If RHS < Start, the backedge will be taken zero times. So in
13388 // general, we can write the backedge-taken count as:
13389 //
13390 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13391 //
13392 // We convert it to the following to make it more convenient for SCEV:
13393 //
13394 // ceil(max(RHS, Start) - Start) / Stride
13395 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13396
13397 // See what would happen if we assume the backedge is taken. This is
13398 // used to compute MaxBECount.
13399 BECountIfBackedgeTaken =
13400 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13401 }
13402
13403 // At this point, we know:
13404 //
13405 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13406 // 2. The index variable doesn't overflow.
13407 //
13408 // Therefore, we know N exists such that
13409 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13410 // doesn't overflow.
13411 //
13412 // Using this information, try to prove whether the addition in
13413 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13414 const SCEV *One = getOne(Stride->getType());
13415 bool MayAddOverflow = [&] {
13416 if (isKnownToBeAPowerOfTwo(Stride)) {
13417 // Suppose Stride is a power of two, and Start/End are unsigned
13418 // integers. Let UMAX be the largest representable unsigned
13419 // integer.
13420 //
13421 // By the preconditions of this function, we know
13422 // "(Start + Stride * N) >= End", and this doesn't overflow.
13423 // As a formula:
13424 //
13425 // End <= (Start + Stride * N) <= UMAX
13426 //
13427 // Subtracting Start from all the terms:
13428 //
13429 // End - Start <= Stride * N <= UMAX - Start
13430 //
13431 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13432 //
13433 // End - Start <= Stride * N <= UMAX
13434 //
13435 // Stride * N is a multiple of Stride. Therefore,
13436 //
13437 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13438 //
13439 // Since Stride is a power of two, UMAX + 1 is divisible by
13440 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13441 // write:
13442 //
13443 // End - Start <= Stride * N <= UMAX - Stride - 1
13444 //
13445 // Dropping the middle term:
13446 //
13447 // End - Start <= UMAX - Stride - 1
13448 //
13449 // Adding Stride - 1 to both sides:
13450 //
13451 // (End - Start) + (Stride - 1) <= UMAX
13452 //
13453 // In other words, the addition doesn't have unsigned overflow.
13454 //
13455 // A similar proof works if we treat Start/End as signed values.
13456 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13457 // to use signed max instead of unsigned max. Note that we're
13458 // trying to prove a lack of unsigned overflow in either case.
13459 return false;
13460 }
13461 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13462 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13463 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13464 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13465 // 1 <s End.
13466 //
13467 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13468 // End.
13469 return false;
13470 }
13471 return true;
13472 }();
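      // Quick numeric check of the power-of-two argument above (sample values
      // only): in i8 with Stride = 4, UMAX mod Stride = 3, so
      // End - Start <= 255 - 3 = 252 and (End - Start) + (Stride - 1) is at
      // most 252 + 3 = 255, which indeed cannot wrap.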
13473
13474 const SCEV *Delta = getMinusSCEV(End, Start);
13475 if (!MayAddOverflow) {
13476 // floor((D + (S - 1)) / S)
13477 // We prefer this formulation if it's legal because it's fewer
13478 // operations.
13479 BECount =
13480 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13481 } else {
13482 BECount = getUDivCeilSCEV(Delta, Stride);
13483 }
13484 }
13485 }
13486
13487 const SCEV *ConstantMaxBECount;
13488 bool MaxOrZero = false;
13489 if (isa<SCEVConstant>(BECount)) {
13490 ConstantMaxBECount = BECount;
13491 } else if (BECountIfBackedgeTaken &&
13492 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13493 // If we know exactly how many times the backedge will be taken if it's
13494 // taken at least once, then the backedge count will either be that or
13495 // zero.
13496 ConstantMaxBECount = BECountIfBackedgeTaken;
13497 MaxOrZero = true;
13498 } else {
13499 ConstantMaxBECount = computeMaxBECountForLT(
13500 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13501 }
13502
13503 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13504 !isa<SCEVCouldNotCompute>(BECount))
13505 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13506
13507 const SCEV *SymbolicMaxBECount =
13508 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13509 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13510 Predicates);
13511}
13512
13513ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13514 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13515 bool ControlsOnlyExit, bool AllowPredicates) {
13516  SmallVector<const SCEVPredicate *> Predicates;
13517  // We handle only IV > Invariant
13518 if (!isLoopInvariant(RHS, L))
13519 return getCouldNotCompute();
13520
13521 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13522 if (!IV && AllowPredicates)
13523 // Try to make this an AddRec using runtime tests, in the first X
13524 // iterations of this loop, where X is the SCEV expression found by the
13525 // algorithm below.
13526 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13527
13528 // Avoid weird loops
13529 if (!IV || IV->getLoop() != L || !IV->isAffine())
13530 return getCouldNotCompute();
13531
13532 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13533 bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);
13534  ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13535
13536 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13537
13538 // Avoid negative or zero stride values
13539 if (!isKnownPositive(Stride))
13540 return getCouldNotCompute();
13541
13542 // Avoid proven overflow cases: this will ensure that the backedge taken count
13543 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13544 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13545 // behaviors like the case of C language.
13546 if (!Stride->isOne() && !NoWrap)
13547 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13548 return getCouldNotCompute();
13549
13550 const SCEV *Start = IV->getStart();
13551 const SCEV *End = RHS;
13552 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13553 // If we know that Start >= RHS in the context of loop, then we know that
13554 // min(RHS, Start) = RHS at this point.
13555    if (isLoopEntryGuardedByCond(
13556            L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13557 End = RHS;
13558 else
13559 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13560 }
13561
13562 if (Start->getType()->isPointerTy()) {
13563    Start = getLosslessPtrToIntExpr(Start);
13564    if (isa<SCEVCouldNotCompute>(Start))
13565 return Start;
13566 }
13567 if (End->getType()->isPointerTy()) {
13568 End = getLosslessPtrToIntExpr(End);
13569 if (isa<SCEVCouldNotCompute>(End))
13570 return End;
13571 }
13572
13573 // Compute ((Start - End) + (Stride - 1)) / Stride.
13574 // FIXME: This can overflow. Holding off on fixing this for now;
13575 // howManyGreaterThans will hopefully be gone soon.
13576 const SCEV *One = getOne(Stride->getType());
13577 const SCEV *BECount = getUDivExpr(
13578 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13579
13580 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13581                            : getUnsignedRangeMax(Start);
13582
13583 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13584 : getUnsignedRangeMin(Stride);
13585
13586 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13587 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13588 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13589
13590 // Although End can be a MIN expression we estimate MinEnd considering only
13591 // the case End = RHS. This is safe because in the other case (Start - End)
13592 // is zero, leading to a zero maximum backedge taken count.
13593 APInt MinEnd =
13594 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13595 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13596
13597 const SCEV *ConstantMaxBECount =
13598 isa<SCEVConstant>(BECount)
13599 ? BECount
13600 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13601 getConstant(MinStride));
13602
13603 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13604 ConstantMaxBECount = BECount;
13605 const SCEV *SymbolicMaxBECount =
13606 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13607
13608 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13609 Predicates);
13610}
13611
13612const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
13613                                                    ScalarEvolution &SE) const {
13614 if (Range.isFullSet()) // Infinite loop.
13615 return SE.getCouldNotCompute();
13616
13617 // If the start is a non-zero constant, shift the range to simplify things.
13618 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13619 if (!SC->getValue()->isZero()) {
13620      SmallVector<const SCEV *, 4> Operands(operands());
13621      Operands[0] = SE.getZero(SC->getType());
13622 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13623                                             FlagAnyWrap);
13624      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13625 return ShiftedAddRec->getNumIterationsInRange(
13626 Range.subtract(SC->getAPInt()), SE);
13627 // This is strange and shouldn't happen.
13628 return SE.getCouldNotCompute();
13629 }
13630
13631 // The only time we can solve this is when we have all constant indices.
13632 // Otherwise, we cannot determine the overflow conditions.
13633 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13634 return SE.getCouldNotCompute();
13635
13636 // Okay at this point we know that all elements of the chrec are constants and
13637 // that the start element is zero.
13638
13639 // First check to see if the range contains zero. If not, the first
13640 // iteration exits.
13641 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13642 if (!Range.contains(APInt(BitWidth, 0)))
13643 return SE.getZero(getType());
13644
13645 if (isAffine()) {
13646 // If this is an affine expression then we have this situation:
13647 // Solve {0,+,A} in Range === Ax in Range
13648
13649 // We know that zero is in the range. If A is positive then we know that
13650 // the upper value of the range must be the first possible exit value.
13651 // If A is negative then the lower of the range is the last possible loop
13652 // value. Also note that we already checked for a full range.
13653 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13654 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13655
13656 // The exit value should be (End+A)/A.
13657 APInt ExitVal = (End + A).udiv(A);
13658 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13659
13660 // Evaluate at the exit value. If we really did fall out of the valid
13661 // range, then we computed our trip count, otherwise wrap around or other
13662 // things must have happened.
13663 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13664 if (Range.contains(Val->getValue()))
13665 return SE.getCouldNotCompute(); // Something strange happened
13666
13667 // Ensure that the previous value is in the range.
13668 assert(Range.contains(
13669          EvaluateConstantChrecAtConstant(this,
13670              ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13671 "Linear scev computation is off in a bad way!");
13672 return SE.getConstant(ExitValue);
13673 }
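  // Affine example (hypothetical values): for {0,+,3} and Range = [0, 10),
  // End = 9 and ExitVal = (9 + 3) /u 3 = 4; the value at iteration 4 is 12,
  // outside the range, while the value at iteration 3 is 9, still inside, so
  // 4 is returned.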
13674
13675 if (isQuadratic()) {
13676 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13677 return SE.getConstant(*S);
13678 }
13679
13680 return SE.getCouldNotCompute();
13681}
13682
13683const SCEVAddRecExpr *
13684SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
13685  assert(getNumOperands() > 1 && "AddRec with zero step?");
13686 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13687 // but in this case we cannot guarantee that the value returned will be an
13688 // AddRec because SCEV does not have a fixed point where it stops
13689 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13690 // may happen if we reach arithmetic depth limit while simplifying. So we
13691 // construct the returned value explicitly.
13692  SmallVector<const SCEV *, 3> Ops;
13693  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13694 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13695 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13696 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13697 // We know that the last operand is not a constant zero (otherwise it would
13698 // have been popped out earlier). This guarantees us that if the result has
13699 // the same last operand, then it will also not be popped out, meaning that
13700 // the returned value will be an AddRec.
13701 const SCEV *Last = getOperand(getNumOperands() - 1);
13702 assert(!Last->isZero() && "Recurrency with zero step?");
13703 Ops.push_back(Last);
13704  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
13705                                               SCEV::FlagAnyWrap));
13706}
13707
13708// Return true when S contains at least an undef value.
13709bool ScalarEvolution::containsUndefs(const SCEV *S) const {
13710  return SCEVExprContains(
13711 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13712}
13713
13714// Return true when S contains a value that is a nullptr.
13715bool ScalarEvolution::containsErasedValue(const SCEV *S) const {
13716  return SCEVExprContains(S, [](const SCEV *S) {
13717 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13718 return SU->getValue() == nullptr;
13719 return false;
13720 });
13721}
13722
13723/// Return the size of an element read or written by Inst.
13724const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) {
13725  Type *Ty;
13726 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13727 Ty = Store->getValueOperand()->getType();
13728 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13729 Ty = Load->getType();
13730 else
13731 return nullptr;
13732
13734 return getSizeOfExpr(ETy, Ty);
13735}
13736
13737//===----------------------------------------------------------------------===//
13738// SCEVCallbackVH Class Implementation
13739//===----------------------------------------------------------------------===//
13740
13741void ScalarEvolution::SCEVCallbackVH::deleted() {
13742  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13743 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13744 SE->ConstantEvolutionLoopExitValue.erase(PN);
13745 SE->eraseValueFromMap(getValPtr());
13746 // this now dangles!
13747}
13748
13749void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13750 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13751
13752 // Forget all the expressions associated with users of the old value,
13753 // so that future queries will recompute the expressions using the new
13754 // value.
13755 SE->forgetValue(getValPtr());
13756 // this now dangles!
13757}
13758
13759ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13760 : CallbackVH(V), SE(se) {}
13761
13762//===----------------------------------------------------------------------===//
13763// ScalarEvolution Class Implementation
13764//===----------------------------------------------------------------------===//
13765
13766ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
13767                                 AssumptionCache &AC, DominatorTree &DT,
13768                                 LoopInfo &LI)
13769 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13770 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13771 LoopDispositions(64), BlockDispositions(64) {
13772 // To use guards for proving predicates, we need to scan every instruction in
13773 // relevant basic blocks, and not just terminators. Doing this is a waste of
13774 // time if the IR does not actually contain any calls to
13775 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13776 //
13777 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13778 // to _add_ guards to the module when there weren't any before, and wants
13779 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13780 // efficient in lieu of being smart in that rather obscure case.
13781
13782 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13783 F.getParent(), Intrinsic::experimental_guard);
13784 HasGuards = GuardDecl && !GuardDecl->use_empty();
13785}
13786
13787ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
13788    : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13789 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13790 ValueExprMap(std::move(Arg.ValueExprMap)),
13791 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13792 PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
13793 PendingMerges(std::move(Arg.PendingMerges)),
13794 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13795 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13796 PredicatedBackedgeTakenCounts(
13797 std::move(Arg.PredicatedBackedgeTakenCounts)),
13798 BECountUsers(std::move(Arg.BECountUsers)),
13799 ConstantEvolutionLoopExitValue(
13800 std::move(Arg.ConstantEvolutionLoopExitValue)),
13801 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13802 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13803 LoopDispositions(std::move(Arg.LoopDispositions)),
13804 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13805 BlockDispositions(std::move(Arg.BlockDispositions)),
13806 SCEVUsers(std::move(Arg.SCEVUsers)),
13807 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13808 SignedRanges(std::move(Arg.SignedRanges)),
13809 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
13810 UniquePreds(std::move(Arg.UniquePreds)),
13811 SCEVAllocator(std::move(Arg.SCEVAllocator)),
13812 LoopUsers(std::move(Arg.LoopUsers)),
13813 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
13814 FirstUnknown(Arg.FirstUnknown) {
13815 Arg.FirstUnknown = nullptr;
13816}
13817
13818ScalarEvolution::~ScalarEvolution() {
13819  // Iterate through all the SCEVUnknown instances and call their
13820 // destructors, so that they release their references to their values.
13821 for (SCEVUnknown *U = FirstUnknown; U;) {
13822 SCEVUnknown *Tmp = U;
13823 U = U->Next;
13824 Tmp->~SCEVUnknown();
13825 }
13826 FirstUnknown = nullptr;
13827
13828 ExprValueMap.clear();
13829 ValueExprMap.clear();
13830 HasRecMap.clear();
13831 BackedgeTakenCounts.clear();
13832 PredicatedBackedgeTakenCounts.clear();
13833
13834 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
13835 assert(PendingPhiRanges.empty() && "getRangeRef garbage");
13836 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
13837 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
13838 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
13839}
13840
13844
13845/// When printing a top-level SCEV for trip counts, it's helpful to include
13846/// a type for constants which are otherwise hard to disambiguate.
13847static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
13848 if (isa<SCEVConstant>(S))
13849 OS << *S->getType() << " ";
13850 OS << *S;
13851}
13852
13853static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
13854                          const Loop *L) {
13855 // Print all inner loops first
13856 for (Loop *I : *L)
13857 PrintLoopInfo(OS, SE, I);
13858
13859 OS << "Loop ";
13860 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13861 OS << ": ";
13862
13863 SmallVector<BasicBlock *, 8> ExitingBlocks;
13864 L->getExitingBlocks(ExitingBlocks);
13865 if (ExitingBlocks.size() != 1)
13866 OS << "<multiple exits> ";
13867
13868 auto *BTC = SE->getBackedgeTakenCount(L);
13869 if (!isa<SCEVCouldNotCompute>(BTC)) {
13870 OS << "backedge-taken count is ";
13871 PrintSCEVWithTypeHint(OS, BTC);
13872 } else
13873 OS << "Unpredictable backedge-taken count.";
13874 OS << "\n";
13875
13876 if (ExitingBlocks.size() > 1)
13877 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13878 OS << " exit count for " << ExitingBlock->getName() << ": ";
13879 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
13880 PrintSCEVWithTypeHint(OS, EC);
13881 if (isa<SCEVCouldNotCompute>(EC)) {
13882 // Retry with predicates.
13883      SmallVector<const SCEVPredicate *, 4> Predicates;
13884      EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
13885 if (!isa<SCEVCouldNotCompute>(EC)) {
13886 OS << "\n predicated exit count for " << ExitingBlock->getName()
13887 << ": ";
13888 PrintSCEVWithTypeHint(OS, EC);
13889 OS << "\n Predicates:\n";
13890 for (const auto *P : Predicates)
13891 P->print(OS, 4);
13892 }
13893 }
13894 OS << "\n";
13895 }
13896
13897 OS << "Loop ";
13898 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13899 OS << ": ";
13900
13901 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
13902 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
13903 OS << "constant max backedge-taken count is ";
13904 PrintSCEVWithTypeHint(OS, ConstantBTC);
13905    if (SE->isBackedgeTakenCountMaxOrZero(L))
13906      OS << ", actual taken count either this or zero.";
13907 } else {
13908 OS << "Unpredictable constant max backedge-taken count. ";
13909 }
13910
13911 OS << "\n"
13912 "Loop ";
13913 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13914 OS << ": ";
13915
13916 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
13917 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
13918 OS << "symbolic max backedge-taken count is ";
13919 PrintSCEVWithTypeHint(OS, SymbolicBTC);
13920    if (SE->isBackedgeTakenCountMaxOrZero(L))
13921      OS << ", actual taken count either this or zero.";
13922 } else {
13923 OS << "Unpredictable symbolic max backedge-taken count. ";
13924 }
13925 OS << "\n";
13926
13927 if (ExitingBlocks.size() > 1)
13928 for (BasicBlock *ExitingBlock : ExitingBlocks) {
13929 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
13930 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
13931                                       ScalarEvolution::SymbolicMaximum);
13932      PrintSCEVWithTypeHint(OS, ExitBTC);
13933 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
13934 // Retry with predicates.
13935       SmallVector<const SCEVPredicate *, 4> Predicates;
13936       ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
13937                                            ScalarEvolution::SymbolicMaximum);
13938       if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
13939 OS << "\n predicated symbolic max exit count for "
13940 << ExitingBlock->getName() << ": ";
13941 PrintSCEVWithTypeHint(OS, ExitBTC);
13942 OS << "\n Predicates:\n";
13943 for (const auto *P : Predicates)
13944 P->print(OS, 4);
13945 }
13946 }
13947 OS << "\n";
13948 }
13949
13950   SmallVector<const SCEVPredicate *, 4> Preds;
13951   auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
13952 if (PBT != BTC) {
13953 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
13954 OS << "Loop ";
13955 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13956 OS << ": ";
13957 if (!isa<SCEVCouldNotCompute>(PBT)) {
13958 OS << "Predicated backedge-taken count is ";
13959 PrintSCEVWithTypeHint(OS, PBT);
13960 } else
13961 OS << "Unpredictable predicated backedge-taken count.";
13962 OS << "\n";
13963 OS << " Predicates:\n";
13964 for (const auto *P : Preds)
13965 P->print(OS, 4);
13966 }
13967 Preds.clear();
13968
13969   auto *PredConstantMax =
13970       SE->getPredicatedConstantMaxBackedgeTakenCount(L, Preds);
13971 if (PredConstantMax != ConstantBTC) {
13972 assert(!Preds.empty() &&
13973 "different predicated constant max BTC but no predicates");
13974 OS << "Loop ";
13975 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13976 OS << ": ";
13977 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
13978 OS << "Predicated constant max backedge-taken count is ";
13979 PrintSCEVWithTypeHint(OS, PredConstantMax);
13980 } else
13981 OS << "Unpredictable predicated constant max backedge-taken count.";
13982 OS << "\n";
13983 OS << " Predicates:\n";
13984 for (const auto *P : Preds)
13985 P->print(OS, 4);
13986 }
13987 Preds.clear();
13988
13989   auto *PredSymbolicMax =
13990       SE->getPredicatedSymbolicMaxBackedgeTakenCount(L, Preds);
13991 if (SymbolicBTC != PredSymbolicMax) {
13992 assert(!Preds.empty() &&
13993 "Different predicated symbolic max BTC, but no predicates");
13994 OS << "Loop ";
13995 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
13996 OS << ": ";
13997 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
13998 OS << "Predicated symbolic max backedge-taken count is ";
13999 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14000 } else
14001 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14002 OS << "\n";
14003 OS << " Predicates:\n";
14004 for (const auto *P : Preds)
14005 P->print(OS, 4);
14006 }
14007
14008   if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
14009     OS << "Loop ";
14010 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14011 OS << ": ";
14012 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14013 }
14014}
14015
14016namespace llvm {
14017// Note: these overloaded operators need to be in the llvm namespace for them
14018// to be resolved correctly. If we put them outside the llvm namespace, the
14019//
14020// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14021//
14022 // code below "breaks" and starts printing raw enum values as opposed to the
14023// string values.
14024 raw_ostream &operator<<(raw_ostream &OS,
14025                         ScalarEvolution::LoopDisposition LD) {
14026   switch (LD) {
14027   case ScalarEvolution::LoopVariant:
14028     OS << "Variant";
14029     break;
14030   case ScalarEvolution::LoopInvariant:
14031     OS << "Invariant";
14032     break;
14033   case ScalarEvolution::LoopComputable:
14034     OS << "Computable";
14035     break;
14036   }
14037 return OS;
14038}
14039
14040 raw_ostream &operator<<(raw_ostream &OS,
14041                         ScalarEvolution::BlockDisposition BD) {
14042   switch (BD) {
14043   case ScalarEvolution::DoesNotDominateBlock:
14044     OS << "DoesNotDominate";
14045     break;
14046   case ScalarEvolution::DominatesBlock:
14047     OS << "Dominates";
14048     break;
14049   case ScalarEvolution::ProperlyDominatesBlock:
14050     OS << "ProperlyDominates";
14051     break;
14052   }
14053 return OS;
14054}
14055} // namespace llvm
14056
14057 void ScalarEvolution::print(raw_ostream &OS) const {
14058   // ScalarEvolution's implementation of the print method is to print
14059 // out SCEV values of all instructions that are interesting. Doing
14060 // this potentially causes it to create new SCEV objects though,
14061 // which technically conflicts with the const qualifier. This isn't
14062 // observable from outside the class though, so casting away the
14063 // const isn't dangerous.
14064 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14065
14066 if (ClassifyExpressions) {
14067 OS << "Classifying expressions for: ";
14068 F.printAsOperand(OS, /*PrintType=*/false);
14069 OS << "\n";
14070 for (Instruction &I : instructions(F))
14071 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14072 OS << I << '\n';
14073 OS << " --> ";
14074 const SCEV *SV = SE.getSCEV(&I);
14075 SV->print(OS);
14076 if (!isa<SCEVCouldNotCompute>(SV)) {
14077 OS << " U: ";
14078 SE.getUnsignedRange(SV).print(OS);
14079 OS << " S: ";
14080 SE.getSignedRange(SV).print(OS);
14081 }
14082
14083 const Loop *L = LI.getLoopFor(I.getParent());
14084
14085 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14086 if (AtUse != SV) {
14087 OS << " --> ";
14088 AtUse->print(OS);
14089 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14090 OS << " U: ";
14091 SE.getUnsignedRange(AtUse).print(OS);
14092 OS << " S: ";
14093 SE.getSignedRange(AtUse).print(OS);
14094 }
14095 }
14096
14097 if (L) {
14098 OS << "\t\t" "Exits: ";
14099 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14100 if (!SE.isLoopInvariant(ExitValue, L)) {
14101 OS << "<<Unknown>>";
14102 } else {
14103 OS << *ExitValue;
14104 }
14105
14106 bool First = true;
14107 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14108 if (First) {
14109 OS << "\t\t" "LoopDispositions: { ";
14110 First = false;
14111 } else {
14112 OS << ", ";
14113 }
14114
14115 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14116 OS << ": " << SE.getLoopDisposition(SV, Iter);
14117 }
14118
14119 for (const auto *InnerL : depth_first(L)) {
14120 if (InnerL == L)
14121 continue;
14122 if (First) {
14123 OS << "\t\t" "LoopDispositions: { ";
14124 First = false;
14125 } else {
14126 OS << ", ";
14127 }
14128
14129 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14130 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14131 }
14132
14133 OS << " }";
14134 }
14135
14136 OS << "\n";
14137 }
14138 }
14139
14140 OS << "Determining loop execution counts for: ";
14141 F.printAsOperand(OS, /*PrintType=*/false);
14142 OS << "\n";
14143 for (Loop *I : LI)
14144 PrintLoopInfo(OS, &SE, I);
14145}
14146
14147 ScalarEvolution::LoopDisposition
14148 ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
14149   auto &Values = LoopDispositions[S];
14150 for (auto &V : Values) {
14151 if (V.getPointer() == L)
14152 return V.getInt();
14153 }
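  // Not cached yet: record a conservative LoopVariant placeholder first, then
  // compute the real disposition and patch the cached entry below.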
14154 Values.emplace_back(L, LoopVariant);
14155 LoopDisposition D = computeLoopDisposition(S, L);
14156 auto &Values2 = LoopDispositions[S];
14157 for (auto &V : llvm::reverse(Values2)) {
14158 if (V.getPointer() == L) {
14159 V.setInt(D);
14160 break;
14161 }
14162 }
14163 return D;
14164}
14165
14166 ScalarEvolution::LoopDisposition
14167 ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14168 switch (S->getSCEVType()) {
14169 case scConstant:
14170 case scVScale:
14171 return LoopInvariant;
14172 case scAddRecExpr: {
14173 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14174
14175 // If L is the addrec's loop, it's computable.
14176 if (AR->getLoop() == L)
14177 return LoopComputable;
14178
14179 // Add recurrences are never invariant in the function-body (null loop).
14180 if (!L)
14181 return LoopVariant;
14182
14183 // Everything that is not defined at loop entry is variant.
14184 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14185 return LoopVariant;
14186 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14187 " dominate the contained loop's header?");
14188
14189 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14190 if (AR->getLoop()->contains(L))
14191 return LoopInvariant;
14192
14193 // This recurrence is variant w.r.t. L if any of its operands
14194 // are variant.
14195 for (const auto *Op : AR->operands())
14196 if (!isLoopInvariant(Op, L))
14197 return LoopVariant;
14198
14199 // Otherwise it's loop-invariant.
14200 return LoopInvariant;
14201 }
14202 case scTruncate:
14203 case scZeroExtend:
14204 case scSignExtend:
14205 case scPtrToInt:
14206 case scAddExpr:
14207 case scMulExpr:
14208 case scUDivExpr:
14209 case scUMaxExpr:
14210 case scSMaxExpr:
14211 case scUMinExpr:
14212 case scSMinExpr:
14213 case scSequentialUMinExpr: {
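    // An n-ary expression takes the disposition of its "most variant" operand:
    // any variant operand makes it variant; otherwise any computable operand
    // (e.g. an addrec for L) makes it computable; otherwise it is invariant.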
14214 bool HasVarying = false;
14215 for (const auto *Op : S->operands()) {
14216       LoopDisposition D = getLoopDisposition(Op, L);
14217       if (D == LoopVariant)
14218 return LoopVariant;
14219 if (D == LoopComputable)
14220 HasVarying = true;
14221 }
14222 return HasVarying ? LoopComputable : LoopInvariant;
14223 }
14224 case scUnknown:
14225 // All non-instruction values are loop invariant. All instructions are loop
14226 // invariant if they are not contained in the specified loop.
14227 // Instructions are never considered invariant in the function body
14228 // (null loop) because they are defined within the "loop".
14229 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14230 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14231 return LoopInvariant;
14232 case scCouldNotCompute:
14233 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14234 }
14235 llvm_unreachable("Unknown SCEV kind!");
14236}
14237
14238 bool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
14239   return getLoopDisposition(S, L) == LoopInvariant;
14240}
14241
14242 bool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
14243   return getLoopDisposition(S, L) == LoopComputable;
14244}
14245
14246 ScalarEvolution::BlockDisposition
14247 ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14248   auto &Values = BlockDispositions[S];
14249 for (auto &V : Values) {
14250 if (V.getPointer() == BB)
14251 return V.getInt();
14252 }
14253 Values.emplace_back(BB, DoesNotDominateBlock);
14254 BlockDisposition D = computeBlockDisposition(S, BB);
14255 auto &Values2 = BlockDispositions[S];
14256 for (auto &V : llvm::reverse(Values2)) {
14257 if (V.getPointer() == BB) {
14258 V.setInt(D);
14259 break;
14260 }
14261 }
14262 return D;
14263}
14264
14265 ScalarEvolution::BlockDisposition
14266 ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14267 switch (S->getSCEVType()) {
14268 case scConstant:
14269   case scVScale:
14270     return ProperlyDominatesBlock;
14271   case scAddRecExpr: {
14272 // This uses a "dominates" query instead of "properly dominates" query
14273 // to test for proper dominance too, because the instruction which
14274 // produces the addrec's value is a PHI, and a PHI effectively properly
14275 // dominates its entire containing block.
14276 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14277 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14278 return DoesNotDominateBlock;
14279
14280 // Fall through into SCEVNAryExpr handling.
14281 [[fallthrough]];
14282 }
14283 case scTruncate:
14284 case scZeroExtend:
14285 case scSignExtend:
14286 case scPtrToInt:
14287 case scAddExpr:
14288 case scMulExpr:
14289 case scUDivExpr:
14290 case scUMaxExpr:
14291 case scSMaxExpr:
14292 case scUMinExpr:
14293 case scSMinExpr:
14294 case scSequentialUMinExpr: {
14295 bool Proper = true;
14296 for (const SCEV *NAryOp : S->operands()) {
14297       BlockDisposition D = getBlockDisposition(NAryOp, BB);
14298       if (D == DoesNotDominateBlock)
14299 return DoesNotDominateBlock;
14300 if (D == DominatesBlock)
14301 Proper = false;
14302 }
14303 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14304 }
14305 case scUnknown:
14306 if (Instruction *I =
14307 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14308 if (I->getParent() == BB)
14309 return DominatesBlock;
14310 if (DT.properlyDominates(I->getParent(), BB))
14311         return ProperlyDominatesBlock;
14312       return DoesNotDominateBlock;
14313 }
14314     return ProperlyDominatesBlock;
14315   case scCouldNotCompute:
14316 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14317 }
14318 llvm_unreachable("Unknown SCEV kind!");
14319}
14320
14321bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14322 return getBlockDisposition(S, BB) >= DominatesBlock;
14323}
14324
14325 bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
14326   return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
14327 }
14328
14329bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14330 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14331}
14332
14333void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14334 bool Predicated) {
14335 auto &BECounts =
14336 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14337 auto It = BECounts.find(L);
14338 if (It != BECounts.end()) {
14339 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14340 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14341 if (!isa<SCEVConstant>(S)) {
14342 auto UserIt = BECountUsers.find(S);
14343 assert(UserIt != BECountUsers.end());
14344 UserIt->second.erase({L, Predicated});
14345 }
14346 }
14347 }
14348 BECounts.erase(It);
14349 }
14350}
14351
14352void ScalarEvolution::forgetMemoizedResults(ArrayRef<const SCEV *> SCEVs) {
14353 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14354 SmallVector<const SCEV *, 8> Worklist(ToForget.begin(), ToForget.end());
14355
14356 while (!Worklist.empty()) {
14357 const SCEV *Curr = Worklist.pop_back_val();
14358 auto Users = SCEVUsers.find(Curr);
14359 if (Users != SCEVUsers.end())
14360 for (const auto *User : Users->second)
14361 if (ToForget.insert(User).second)
14362 Worklist.push_back(User);
14363 }
14364
14365 for (const auto *S : ToForget)
14366 forgetMemoizedResultsImpl(S);
14367
14368 for (auto I = PredicatedSCEVRewrites.begin();
14369 I != PredicatedSCEVRewrites.end();) {
14370 std::pair<const SCEV *, const Loop *> Entry = I->first;
14371 if (ToForget.count(Entry.first))
14372 PredicatedSCEVRewrites.erase(I++);
14373 else
14374 ++I;
14375 }
14376}
14377
14378void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14379 LoopDispositions.erase(S);
14380 BlockDispositions.erase(S);
14381 UnsignedRanges.erase(S);
14382 SignedRanges.erase(S);
14383 HasRecMap.erase(S);
14384 ConstantMultipleCache.erase(S);
14385
14386 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14387 UnsignedWrapViaInductionTried.erase(AR);
14388 SignedWrapViaInductionTried.erase(AR);
14389 }
14390
14391 auto ExprIt = ExprValueMap.find(S);
14392 if (ExprIt != ExprValueMap.end()) {
14393 for (Value *V : ExprIt->second) {
14394 auto ValueIt = ValueExprMap.find_as(V);
14395 if (ValueIt != ValueExprMap.end())
14396 ValueExprMap.erase(ValueIt);
14397 }
14398 ExprValueMap.erase(ExprIt);
14399 }
14400
14401 auto ScopeIt = ValuesAtScopes.find(S);
14402 if (ScopeIt != ValuesAtScopes.end()) {
14403 for (const auto &Pair : ScopeIt->second)
14404 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14405 llvm::erase(ValuesAtScopesUsers[Pair.second],
14406 std::make_pair(Pair.first, S));
14407 ValuesAtScopes.erase(ScopeIt);
14408 }
14409
14410 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14411 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14412 for (const auto &Pair : ScopeUserIt->second)
14413 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14414 ValuesAtScopesUsers.erase(ScopeUserIt);
14415 }
14416
14417 auto BEUsersIt = BECountUsers.find(S);
14418 if (BEUsersIt != BECountUsers.end()) {
14419 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14420 auto Copy = BEUsersIt->second;
14421 for (const auto &Pair : Copy)
14422 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14423 BECountUsers.erase(BEUsersIt);
14424 }
14425
14426 auto FoldUser = FoldCacheUser.find(S);
14427 if (FoldUser != FoldCacheUser.end())
14428 for (auto &KV : FoldUser->second)
14429 FoldCache.erase(KV);
14430 FoldCacheUser.erase(S);
14431}
14432
14433void
14434ScalarEvolution::getUsedLoops(const SCEV *S,
14435 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14436 struct FindUsedLoops {
14437 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14438 : LoopsUsed(LoopsUsed) {}
14439 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14440 bool follow(const SCEV *S) {
14441 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14442 LoopsUsed.insert(AR->getLoop());
14443 return true;
14444 }
14445
14446 bool isDone() const { return false; }
14447 };
14448
14449 FindUsedLoops F(LoopsUsed);
14450 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14451}
14452
14453void ScalarEvolution::getReachableBlocks(
14454     SmallPtrSetImpl<BasicBlock *> &Reachable, Function &F) {
14455   SmallVector<BasicBlock *> Worklist;
14456   Worklist.push_back(&F.getEntryBlock());
14457 while (!Worklist.empty()) {
14458 BasicBlock *BB = Worklist.pop_back_val();
14459 if (!Reachable.insert(BB).second)
14460 continue;
14461
14462 Value *Cond;
14463 BasicBlock *TrueBB, *FalseBB;
14464 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14465 m_BasicBlock(FalseBB)))) {
14466 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14467 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14468 continue;
14469 }
14470
14471 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14472 const SCEV *L = getSCEV(Cmp->getOperand(0));
14473 const SCEV *R = getSCEV(Cmp->getOperand(1));
14474 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14475 Worklist.push_back(TrueBB);
14476 continue;
14477 }
14478 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14479 R)) {
14480 Worklist.push_back(FalseBB);
14481 continue;
14482 }
14483 }
14484 }
14485
14486 append_range(Worklist, successors(BB));
14487 }
14488}
14489
14490 void ScalarEvolution::verify() const {
14491   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14492 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14493
14494 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14495
14496   // Maps SCEV expressions from one ScalarEvolution "universe" to another.
14497 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14498 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14499
14500 const SCEV *visitConstant(const SCEVConstant *Constant) {
14501 return SE.getConstant(Constant->getAPInt());
14502 }
14503
14504 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14505 return SE.getUnknown(Expr->getValue());
14506 }
14507
14508 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14509 return SE.getCouldNotCompute();
14510 }
14511 };
14512
14513 SCEVMapper SCM(SE2);
14514 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14515 SE2.getReachableBlocks(ReachableBlocks, F);
14516
14517 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14518 if (containsUndefs(Old) || containsUndefs(New)) {
14519 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14520 // not propagate undef aggressively). This means we can (and do) fail
14521 // verification in cases where a transform makes a value go from "undef"
14522 // to "undef+1" (say). The transform is fine, since in both cases the
14523 // result is "undef", but SCEV thinks the value increased by 1.
14524 return nullptr;
14525 }
14526
14527 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14528 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14529 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14530 return nullptr;
14531
14532 return Delta;
14533 };
14534
14535 while (!LoopStack.empty()) {
14536 auto *L = LoopStack.pop_back_val();
14537 llvm::append_range(LoopStack, *L);
14538
14539 // Only verify BECounts in reachable loops. For an unreachable loop,
14540 // any BECount is legal.
14541 if (!ReachableBlocks.contains(L->getHeader()))
14542 continue;
14543
14544 // Only verify cached BECounts. Computing new BECounts may change the
14545 // results of subsequent SCEV uses.
14546 auto It = BackedgeTakenCounts.find(L);
14547 if (It == BackedgeTakenCounts.end())
14548 continue;
14549
14550 auto *CurBECount =
14551 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14552 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14553
14554 if (CurBECount == SE2.getCouldNotCompute() ||
14555 NewBECount == SE2.getCouldNotCompute()) {
14556     // NB! This situation is legal, but is very suspicious -- whatever pass
14557     // changed the loop to make a trip count go from could-not-compute to
14558 // computable or vice-versa *should have* invalidated SCEV. However, we
14559 // choose not to assert here (for now) since we don't want false
14560 // positives.
14561 continue;
14562 }
14563
14564 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14565 SE.getTypeSizeInBits(NewBECount->getType()))
14566 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14567 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14568 SE.getTypeSizeInBits(NewBECount->getType()))
14569 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14570
14571 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14572 if (Delta && !Delta->isZero()) {
14573 dbgs() << "Trip Count for " << *L << " Changed!\n";
14574 dbgs() << "Old: " << *CurBECount << "\n";
14575 dbgs() << "New: " << *NewBECount << "\n";
14576 dbgs() << "Delta: " << *Delta << "\n";
14577 std::abort();
14578 }
14579 }
14580
14581 // Collect all valid loops currently in LoopInfo.
14582 SmallPtrSet<Loop *, 32> ValidLoops;
14583 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14584 while (!Worklist.empty()) {
14585 Loop *L = Worklist.pop_back_val();
14586 if (ValidLoops.insert(L).second)
14587 Worklist.append(L->begin(), L->end());
14588 }
14589 for (const auto &KV : ValueExprMap) {
14590#ifndef NDEBUG
14591 // Check for SCEV expressions referencing invalid/deleted loops.
14592 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14593 assert(ValidLoops.contains(AR->getLoop()) &&
14594 "AddRec references invalid loop");
14595 }
14596#endif
14597
14598 // Check that the value is also part of the reverse map.
14599 auto It = ExprValueMap.find(KV.second);
14600 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14601 dbgs() << "Value " << *KV.first
14602 << " is in ValueExprMap but not in ExprValueMap\n";
14603 std::abort();
14604 }
14605
14606 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14607 if (!ReachableBlocks.contains(I->getParent()))
14608 continue;
14609 const SCEV *OldSCEV = SCM.visit(KV.second);
14610 const SCEV *NewSCEV = SE2.getSCEV(I);
14611 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14612 if (Delta && !Delta->isZero()) {
14613 dbgs() << "SCEV for value " << *I << " changed!\n"
14614 << "Old: " << *OldSCEV << "\n"
14615 << "New: " << *NewSCEV << "\n"
14616 << "Delta: " << *Delta << "\n";
14617 std::abort();
14618 }
14619 }
14620 }
14621
14622 for (const auto &KV : ExprValueMap) {
14623 for (Value *V : KV.second) {
14624 const SCEV *S = ValueExprMap.lookup(V);
14625 if (!S) {
14626 dbgs() << "Value " << *V
14627 << " is in ExprValueMap but not in ValueExprMap\n";
14628 std::abort();
14629 }
14630 if (S != KV.first) {
14631 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14632 << *KV.first << "\n";
14633 std::abort();
14634 }
14635 }
14636 }
14637
14638 // Verify integrity of SCEV users.
14639 for (const auto &S : UniqueSCEVs) {
14640 for (const auto *Op : S.operands()) {
14641 // We do not store dependencies of constants.
14642 if (isa<SCEVConstant>(Op))
14643 continue;
14644 auto It = SCEVUsers.find(Op);
14645 if (It != SCEVUsers.end() && It->second.count(&S))
14646 continue;
14647 dbgs() << "Use of operand " << *Op << " by user " << S
14648 << " is not being tracked!\n";
14649 std::abort();
14650 }
14651 }
14652
14653 // Verify integrity of ValuesAtScopes users.
14654 for (const auto &ValueAndVec : ValuesAtScopes) {
14655 const SCEV *Value = ValueAndVec.first;
14656 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14657 const Loop *L = LoopAndValueAtScope.first;
14658 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14659 if (!isa<SCEVConstant>(ValueAtScope)) {
14660 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14661 if (It != ValuesAtScopesUsers.end() &&
14662 is_contained(It->second, std::make_pair(L, Value)))
14663 continue;
14664 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14665 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14666 std::abort();
14667 }
14668 }
14669 }
14670
14671 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14672 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14673 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14674 const Loop *L = LoopAndValue.first;
14675 const SCEV *Value = LoopAndValue.second;
14677 auto It = ValuesAtScopes.find(Value);
14678 if (It != ValuesAtScopes.end() &&
14679 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14680 continue;
14681 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14682 << *ValueAtScope << " missing in ValuesAtScopes\n";
14683 std::abort();
14684 }
14685 }
14686
14687 // Verify integrity of BECountUsers.
14688 auto VerifyBECountUsers = [&](bool Predicated) {
14689 auto &BECounts =
14690 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14691 for (const auto &LoopAndBEInfo : BECounts) {
14692 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14693 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14694 if (!isa<SCEVConstant>(S)) {
14695 auto UserIt = BECountUsers.find(S);
14696 if (UserIt != BECountUsers.end() &&
14697 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14698 continue;
14699 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14700 << " missing from BECountUsers\n";
14701 std::abort();
14702 }
14703 }
14704 }
14705 }
14706 };
14707 VerifyBECountUsers(/* Predicated */ false);
14708 VerifyBECountUsers(/* Predicated */ true);
14709
14710   // Verify the integrity of the loop disposition cache.
14711 for (auto &[S, Values] : LoopDispositions) {
14712 for (auto [Loop, CachedDisposition] : Values) {
14713 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14714 if (CachedDisposition != RecomputedDisposition) {
14715 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14716 << " is incorrect: cached " << CachedDisposition << ", actual "
14717 << RecomputedDisposition << "\n";
14718 std::abort();
14719 }
14720 }
14721 }
14722
14723 // Verify integrity of the block disposition cache.
14724 for (auto &[S, Values] : BlockDispositions) {
14725 for (auto [BB, CachedDisposition] : Values) {
14726 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14727 if (CachedDisposition != RecomputedDisposition) {
14728 dbgs() << "Cached disposition of " << *S << " for block %"
14729 << BB->getName() << " is incorrect: cached " << CachedDisposition
14730 << ", actual " << RecomputedDisposition << "\n";
14731 std::abort();
14732 }
14733 }
14734 }
14735
14736 // Verify FoldCache/FoldCacheUser caches.
14737 for (auto [FoldID, Expr] : FoldCache) {
14738 auto I = FoldCacheUser.find(Expr);
14739 if (I == FoldCacheUser.end()) {
14740 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14741 << "!\n";
14742 std::abort();
14743 }
14744 if (!is_contained(I->second, FoldID)) {
14745 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14746 std::abort();
14747 }
14748 }
14749 for (auto [Expr, IDs] : FoldCacheUser) {
14750 for (auto &FoldID : IDs) {
14751 const SCEV *S = FoldCache.lookup(FoldID);
14752 if (!S) {
14753 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14754 << "!\n";
14755 std::abort();
14756 }
14757 if (S != Expr) {
14758 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14759 << " != " << *Expr << "!\n";
14760 std::abort();
14761 }
14762 }
14763 }
14764
14765 // Verify that ConstantMultipleCache computations are correct. We check that
14766 // cached multiples and recomputed multiples are multiples of each other to
14767 // verify correctness. It is possible that a recomputed multiple is different
14768 // from the cached multiple due to strengthened no-wrap flags or changes in
14769 // KnownBits computations.
14770 for (auto [S, Multiple] : ConstantMultipleCache) {
14771 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14772 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14773 Multiple.urem(RecomputedMultiple) != 0 &&
14774 RecomputedMultiple.urem(Multiple) != 0)) {
14775 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14776 << *S << " : Computed " << RecomputedMultiple
14777 << " but cache contains " << Multiple << "!\n";
14778 std::abort();
14779 }
14780 }
14781}
14782
14783 bool ScalarEvolutionAnalysis::invalidate(
14784     Function &F, const PreservedAnalyses &PA,
14785 FunctionAnalysisManager::Invalidator &Inv) {
14786 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14787 // of its dependencies is invalidated.
14788 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14789 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14790 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14791 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14792 Inv.invalidate<LoopAnalysis>(F, PA);
14793}
14794
14795AnalysisKey ScalarEvolutionAnalysis::Key;
14796
14797 ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,
14798                                              FunctionAnalysisManager &AM) {
14799   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14800 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14801 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14802 auto &LI = AM.getResult<LoopAnalysis>(F);
14803 return ScalarEvolution(F, TLI, AC, DT, LI);
14804}
14805
14806 PreservedAnalyses
14807 ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
14808   AM.getResult<ScalarEvolutionAnalysis>(F).verify();
14809   return PreservedAnalyses::all();
14810 }
14811
14812 PreservedAnalyses
14813 ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
14814   // For compatibility with opt's -analyze feature under the legacy pass
14815   // manager, which was not ported to the NPM. This keeps tests using
14816   // update_analyze_test_checks.py working.
14817 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14818 << F.getName() << "':\n";
14819   AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
14820   return PreservedAnalyses::all();
14821}
14822
14824 "Scalar Evolution Analysis", false, true)
14830 "Scalar Evolution Analysis", false, true)
14831
14833
14835
14836 bool ScalarEvolutionWrapperPass::runOnFunction(Function &F) {
14837   SE.reset(new ScalarEvolution(
14838       F, getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F),
14839       getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
14840       getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
14841       getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
14842 return false;
14843}
14844
14846
14847 void ScalarEvolutionWrapperPass::print(raw_ostream &OS, const Module *) const {
14848   SE->print(OS);
14849}
14850
14851 void ScalarEvolutionWrapperPass::verifyAnalysis() const {
14852   if (!VerifySCEV)
14853 return;
14854
14855 SE->verify();
14856}
14857
14865
14866 const SCEVPredicate *ScalarEvolution::getEqualPredicate(const SCEV *LHS,
14867                                                          const SCEV *RHS) {
14868 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
14869}
14870
14871const SCEVPredicate *
14872 ScalarEvolution::getComparePredicate(const ICmpInst::Predicate Pred,
14873                                      const SCEV *LHS, const SCEV *RHS) {
14874   FoldingSetNodeID ID;
14875   assert(LHS->getType() == RHS->getType() &&
14876 "Type mismatch between LHS and RHS");
14877 // Unique this node based on the arguments
14878 ID.AddInteger(SCEVPredicate::P_Compare);
14879 ID.AddInteger(Pred);
14880 ID.AddPointer(LHS);
14881 ID.AddPointer(RHS);
14882 void *IP = nullptr;
14883 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14884 return S;
14885 SCEVComparePredicate *Eq = new (SCEVAllocator)
14886 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
14887 UniquePreds.InsertNode(Eq, IP);
14888 return Eq;
14889}
14890
14891 const SCEVPredicate *ScalarEvolution::getWrapPredicate(
14892     const SCEVAddRecExpr *AR,
14893     SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14894   FoldingSetNodeID ID;
14895   // Unique this node based on the arguments
14896 ID.AddInteger(SCEVPredicate::P_Wrap);
14897 ID.AddPointer(AR);
14898 ID.AddInteger(AddedFlags);
14899 void *IP = nullptr;
14900 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
14901 return S;
14902 auto *OF = new (SCEVAllocator)
14903 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
14904 UniquePreds.InsertNode(OF, IP);
14905 return OF;
14906}
14907
14908namespace {
14909
14910class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
14911public:
14912
14913 /// Rewrites \p S in the context of a loop L and the SCEV predication
14914 /// infrastructure.
14915 ///
14916 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
14917 /// equivalences present in \p Pred.
14918 ///
14919 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
14920 /// \p NewPreds such that the result will be an AddRecExpr.
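  /// For example, given a compare predicate '%n == (zext i8 %len to i64)',
  /// occurrences of the unknown %n in \p S are replaced by the extended value.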
14921 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
14922                              SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14923                              const SCEVPredicate *Pred) {
14924 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
14925 return Rewriter.visit(S);
14926 }
14927
14928 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14929 if (Pred) {
14930 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
14931 for (const auto *Pred : U->getPredicates())
14932 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
14933 if (IPred->getLHS() == Expr &&
14934 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14935 return IPred->getRHS();
14936 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
14937 if (IPred->getLHS() == Expr &&
14938 IPred->getPredicate() == ICmpInst::ICMP_EQ)
14939 return IPred->getRHS();
14940 }
14941 }
14942 return convertToAddRecWithPreds(Expr);
14943 }
14944
14945 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
14946 const SCEV *Operand = visit(Expr->getOperand());
14947 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14948 if (AR && AR->getLoop() == L && AR->isAffine()) {
14949 // This couldn't be folded because the operand didn't have the nuw
14950 // flag. Add the nusw flag as an assumption that we could make.
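      // For example, zext(i32 {%start,+,%step}<%L>) becomes
      // {(zext i32 %start to i64),+,(sext i32 %step to i64)}<%L> once an
      // IncrementNUSW predicate for the addrec has been recorded.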
14951 const SCEV *Step = AR->getStepRecurrence(SE);
14952 Type *Ty = Expr->getType();
14953 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
14954 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
14955 SE.getSignExtendExpr(Step, Ty), L,
14956 AR->getNoWrapFlags());
14957 }
14958 return SE.getZeroExtendExpr(Operand, Expr->getType());
14959 }
14960
14961 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
14962 const SCEV *Operand = visit(Expr->getOperand());
14963 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
14964 if (AR && AR->getLoop() == L && AR->isAffine()) {
14965 // This couldn't be folded because the operand didn't have the nsw
14966 // flag. Add the nssw flag as an assumption that we could make.
14967 const SCEV *Step = AR->getStepRecurrence(SE);
14968 Type *Ty = Expr->getType();
14969 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
14970 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
14971 SE.getSignExtendExpr(Step, Ty), L,
14972 AR->getNoWrapFlags());
14973 }
14974 return SE.getSignExtendExpr(Operand, Expr->getType());
14975 }
14976
14977private:
14978 explicit SCEVPredicateRewriter(
14979 const Loop *L, ScalarEvolution &SE,
14980 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
14981 const SCEVPredicate *Pred)
14982 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
14983
14984 bool addOverflowAssumption(const SCEVPredicate *P) {
14985 if (!NewPreds) {
14986 // Check if we've already made this assumption.
14987 return Pred && Pred->implies(P, SE);
14988 }
14989 NewPreds->push_back(P);
14990 return true;
14991 }
14992
14993 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
14994                              SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {
14995     auto *A = SE.getWrapPredicate(AR, AddedFlags);
14996 return addOverflowAssumption(A);
14997 }
14998
14999 // If \p Expr represents a PHINode, we try to see if it can be represented
15000 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15001 // to add this predicate as a runtime overflow check, we return the AddRec.
15002 // If \p Expr does not meet these conditions (is not a PHI node, or we
15003 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15004 // return \p Expr.
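  // For example, a loop-header PHI that SCEV can only model as an addrec under
  // extra runtime assumptions may be returned as {%start,+,%step}<%L> together
  // with the wrap predicates produced by createAddRecFromPHIWithCasts.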
15005 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15006 if (!isa<PHINode>(Expr->getValue()))
15007 return Expr;
15008 std::optional<
15009 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15010 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15011 if (!PredicatedRewrite)
15012 return Expr;
15013 for (const auto *P : PredicatedRewrite->second){
15014 // Wrap predicates from outer loops are not supported.
15015 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15016 if (L != WP->getExpr()->getLoop())
15017 return Expr;
15018 }
15019 if (!addOverflowAssumption(P))
15020 return Expr;
15021 }
15022 return PredicatedRewrite->first;
15023 }
15024
15025 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15026 const SCEVPredicate *Pred;
15027 const Loop *L;
15028};
15029
15030} // end anonymous namespace
15031
15032const SCEV *
15034 const SCEVPredicate &Preds) {
15035 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15036}
15037
15038 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
15039     const SCEV *S, const Loop *L,
15040     SmallVectorImpl<const SCEVPredicate *> &Preds) {
15041   SmallVector<const SCEVPredicate *, 4> TransformPreds;
15042   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15043 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15044
15045 if (!AddRec)
15046 return nullptr;
15047
15048 // Check if any of the transformed predicates is known to be false. In that
15049 // case, it doesn't make sense to convert to a predicated AddRec, as the
15050 // versioned loop will never execute.
15051 for (const SCEVPredicate *Pred : TransformPreds) {
15052 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15053 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15054 continue;
15055
15056 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15057 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15058 if (isa<SCEVCouldNotCompute>(ExitCount))
15059 continue;
15060
15061 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15062 if (!Step->isOne())
15063 continue;
15064
15065 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15066 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15067 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15068 return nullptr;
15069 }
15070
15071 // Since the transformation was successful, we can now transfer the SCEV
15072 // predicates.
15073 Preds.append(TransformPreds.begin(), TransformPreds.end());
15074
15075 return AddRec;
15076}
15077
15078/// SCEV predicates
15079 SCEVPredicate::SCEVPredicate(const FoldingSetNodeIDRef ID,
15080                              SCEVPredicateKind Kind)
15081     : FastID(ID), Kind(Kind) {}
15082
15083 SCEVComparePredicate::SCEVComparePredicate(const FoldingSetNodeIDRef ID,
15084                                            const ICmpInst::Predicate Pred,
15085 const SCEV *LHS, const SCEV *RHS)
15086 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15087 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15088 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15089}
15090
15091 bool SCEVComparePredicate::implies(const SCEVPredicate *N,
15092                                    ScalarEvolution &SE) const {
15093 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15094
15095 if (!Op)
15096 return false;
15097
15098 if (Pred != ICmpInst::ICMP_EQ)
15099 return false;
15100
15101 return Op->LHS == LHS && Op->RHS == RHS;
15102}
15103
15104bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15105
15106 void SCEVComparePredicate::print(raw_ostream &OS, unsigned Depth) const {
15107   if (Pred == ICmpInst::ICMP_EQ)
15108 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15109 else
15110 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15111 << *RHS << "\n";
15112
15113}
15114
15115 SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
15116                                      const SCEVAddRecExpr *AR,
15117 IncrementWrapFlags Flags)
15118 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15119
15120const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15121
15122 bool SCEVWrapPredicate::implies(const SCEVPredicate *N,
15123                                 ScalarEvolution &SE) const {
15124 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15125 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15126 return false;
15127
15128 if (Op->AR == AR)
15129 return true;
15130
15131 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15132       Flags != SCEVWrapPredicate::IncrementNUSW)
15133     return false;
15134
15135 const SCEV *Start = AR->getStart();
15136 const SCEV *OpStart = Op->AR->getStart();
15137 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15138 return false;
15139
15140 // Reject pointers to different address spaces.
15141 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15142 return false;
15143
15144 const SCEV *Step = AR->getStepRecurrence(SE);
15145 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15146 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15147 return false;
15148
15149   // If both steps are positive, this predicate implies N if N's start and
15150   // step are ULE/SLE (for NUSW/NSSW respectively) compared to this one's.
15151 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15152 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15153 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15154
15155 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15156 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15157 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15158 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15159 : SE.getNoopOrSignExtend(Start, WiderTy);
15160   ICmpInst::Predicate Pred = IsNUW ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_SLE;
15161   return SE.isKnownPredicate(Pred, OpStep, Step) &&
15162 SE.isKnownPredicate(Pred, OpStart, Start);
15163}
15164
15165 bool SCEVWrapPredicate::isAlwaysTrue() const {
15166   SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15167 IncrementWrapFlags IFlags = Flags;
15168
15169 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15170 IFlags = clearFlags(IFlags, IncrementNSSW);
15171
15172 return IFlags == IncrementAnyWrap;
15173}
15174
15175void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15176 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15177   if (SCEVWrapPredicate::IncrementNUSW == setFlags(Flags, IncrementNUSW))
15178     OS << "<nusw>";
15179   if (SCEVWrapPredicate::IncrementNSSW == setFlags(Flags, IncrementNSSW))
15180     OS << "<nssw>";
15181 OS << "\n";
15182}
15183
15184 SCEVWrapPredicate::IncrementWrapFlags
15185 SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
15186                                    ScalarEvolution &SE) {
15187 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15188 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15189
15190 // We can safely transfer the NSW flag as NSSW.
15191 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15192 ImpliedFlags = IncrementNSSW;
15193
15194 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15195 // If the increment is positive, the SCEV NUW flag will also imply the
15196 // WrapPredicate NUSW flag.
15197 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15198 if (Step->getValue()->getValue().isNonNegative())
15199 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15200 }
15201
15202 return ImpliedFlags;
15203}
15204
15205/// Union predicates don't get cached so create a dummy set ID for it.
15206 SCEVUnionPredicate::SCEVUnionPredicate(ArrayRef<const SCEVPredicate *> Preds,
15207                                        ScalarEvolution &SE)
15208 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15209 for (const auto *P : Preds)
15210 add(P, SE);
15211}
15212
15213 bool SCEVUnionPredicate::isAlwaysTrue() const {
15214   return all_of(Preds,
15215 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15216}
15217
15218 bool SCEVUnionPredicate::implies(const SCEVPredicate *N,
15219                                  ScalarEvolution &SE) const {
15220 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15221 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15222 return this->implies(I, SE);
15223 });
15224
15225 return any_of(Preds,
15226 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15227}
15228
15229 void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const {
15230   for (const auto *Pred : Preds)
15231 Pred->print(OS, Depth);
15232}
15233
15234void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15235 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15236 for (const auto *Pred : Set->Preds)
15237 add(Pred, SE);
15238 return;
15239 }
15240
15241   // Implication checks are quadratic in the number of predicates. Stop doing
15242   // them once there are many predicates, as they become too expensive to be
15243   // worthwhile at that point.
15244 bool CheckImplies = Preds.size() < 16;
15245
15246 // Only add predicate if it is not already implied by this union predicate.
15247 if (CheckImplies && implies(N, SE))
15248 return;
15249
15250 // Build a new vector containing the current predicates, except the ones that
15251 // are implied by the new predicate N.
15252   SmallVector<const SCEVPredicate *> PrunedPreds;
15253   for (auto *P : Preds) {
15254 if (CheckImplies && N->implies(P, SE))
15255 continue;
15256 PrunedPreds.push_back(P);
15257 }
15258 Preds = std::move(PrunedPreds);
15259 Preds.push_back(N);
15260}
15261
15262 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
15263                                                      Loop &L)
15264 : SE(SE), L(L) {
15265   SmallVector<const SCEVPredicate *, 4> Empty;
15266   Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15267}
15268
15269 void ScalarEvolution::registerUser(const SCEV *User,
15270                                    ArrayRef<const SCEV *> Ops) {
15271   for (const auto *Op : Ops)
15272 // We do not expect that forgetting cached data for SCEVConstants will ever
15273 // open any prospects for sharpening or introduce any correctness issues,
15274 // so we don't bother storing their dependencies.
15275 if (!isa<SCEVConstant>(Op))
15276 SCEVUsers[Op].insert(User);
15277}
15278
15279 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
15280   const SCEV *Expr = SE.getSCEV(V);
15281 return getPredicatedSCEV(Expr);
15282}
15283
15284 const SCEV *PredicatedScalarEvolution::getPredicatedSCEV(const SCEV *Expr) {
15285   RewriteEntry &Entry = RewriteMap[Expr];
15286
15287 // If we already have an entry and the version matches, return it.
15288 if (Entry.second && Generation == Entry.first)
15289 return Entry.second;
15290
15291 // We found an entry but it's stale. Rewrite the stale entry
15292 // according to the current predicate.
15293 if (Entry.second)
15294 Expr = Entry.second;
15295
15296 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15297 Entry = {Generation, NewSCEV};
15298
15299 return NewSCEV;
15300}
15301
15302 const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() {
15303   if (!BackedgeCount) {
15304     SmallVector<const SCEVPredicate *, 4> Preds;
15305     BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15306 for (const auto *P : Preds)
15307 addPredicate(*P);
15308 }
15309 return BackedgeCount;
15310}
15311
15312 const SCEV *PredicatedScalarEvolution::getSymbolicMaxBackedgeTakenCount() {
15313   if (!SymbolicMaxBackedgeCount) {
15314     SmallVector<const SCEVPredicate *, 4> Preds;
15315     SymbolicMaxBackedgeCount =
15316 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15317 for (const auto *P : Preds)
15318 addPredicate(*P);
15319 }
15320 return SymbolicMaxBackedgeCount;
15321}
15322
15323 unsigned PredicatedScalarEvolution::getSmallConstantMaxTripCount() {
15324   if (!SmallConstantMaxTripCount) {
15325     SmallVector<const SCEVPredicate *, 4> Preds;
15326     SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15327 for (const auto *P : Preds)
15328 addPredicate(*P);
15329 }
15330 return *SmallConstantMaxTripCount;
15331}
15332
15333 void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) {
15334   if (Preds->implies(&Pred, SE))
15335 return;
15336
15337 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15338 NewPreds.push_back(&Pred);
15339 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15340 updateGeneration();
15341}
15342
15343 const SCEVPredicate &PredicatedScalarEvolution::getPredicate() const {
15344   return *Preds;
15345}
15346
15347void PredicatedScalarEvolution::updateGeneration() {
15348   // If the generation number wrapped, recompute everything.
15349 if (++Generation == 0) {
15350 for (auto &II : RewriteMap) {
15351 const SCEV *Rewritten = II.second.second;
15352 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15353 }
15354 }
15355}
15356
15357 void PredicatedScalarEvolution::setNoOverflow(
15358     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15359   const SCEV *Expr = getSCEV(V);
15360 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15361
15362 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15363
15364 // Clear the statically implied flags.
15365 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15366 addPredicate(*SE.getWrapPredicate(AR, Flags));
15367
15368 auto II = FlagsMap.insert({V, Flags});
15369 if (!II.second)
15370 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15371}
15372
15373 bool PredicatedScalarEvolution::hasNoOverflow(
15374     Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) {
15375   const SCEV *Expr = getSCEV(V);
15376 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15377
15378   Flags = SCEVWrapPredicate::clearFlags(
15379       Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15380
15381 auto II = FlagsMap.find(V);
15382
15383 if (II != FlagsMap.end())
15384 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15385
15386   return Flags == SCEVWrapPredicate::IncrementAnyWrap;
15387 }
15388
15389 const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {
15390   const SCEV *Expr = this->getSCEV(V);
15391   SmallVector<const SCEVPredicate *, 4> NewPreds;
15392   auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15393
15394 if (!New)
15395 return nullptr;
15396
15397 for (const auto *P : NewPreds)
15398 addPredicate(*P);
15399
15400 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15401 return New;
15402}
15403
15404 PredicatedScalarEvolution::PredicatedScalarEvolution(
15405     const PredicatedScalarEvolution &Init)
15406     : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15407 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15408 SE)),
15409 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15410 for (auto I : Init.FlagsMap)
15411 FlagsMap.insert(I);
15412}
15413
15414 void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
15415   // For each block.
15416 for (auto *BB : L.getBlocks())
15417 for (auto &I : *BB) {
15418 if (!SE.isSCEVable(I.getType()))
15419 continue;
15420
15421 auto *Expr = SE.getSCEV(&I);
15422 auto II = RewriteMap.find(Expr);
15423
15424 if (II == RewriteMap.end())
15425 continue;
15426
15427 // Don't print things that are not interesting.
15428 if (II->second.second == Expr)
15429 continue;
15430
15431 OS.indent(Depth) << "[PSE]" << I << ":\n";
15432 OS.indent(Depth + 2) << *Expr << "\n";
15433 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15434 }
15435}
15436
15437 ScalarEvolution::LoopGuards
15438 ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
15439   BasicBlock *Header = L->getHeader();
15440 BasicBlock *Pred = L->getLoopPredecessor();
15441 LoopGuards Guards(SE);
15442 if (!Pred)
15443 return Guards;
15444   SmallPtrSet<const BasicBlock *, 8> VisitedBlocks;
15445   collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15446 return Guards;
15447}
15448
15449void ScalarEvolution::LoopGuards::collectFromPHI(
15450     ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15451     const PHINode &Phi, SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks,
15452     SmallDenseMap<const BasicBlock *, LoopGuards> &IncomingGuards,
15453     unsigned Depth) {
15454 if (!SE.isSCEVable(Phi.getType()))
15455 return;
15456
15457 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15458 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15459 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15460 if (!VisitedBlocks.insert(InBlock).second)
15461 return {nullptr, scCouldNotCompute};
15462
15463     // Avoid analyzing unreachable blocks so that we don't get trapped
15464     // traversing cycles with ill-formed dominance, or infinite cycles.
15465 if (!SE.DT.isReachableFromEntry(InBlock))
15466 return {nullptr, scCouldNotCompute};
15467
15468 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15469 if (Inserted)
15470 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15471 Depth + 1);
15472 auto &RewriteMap = G->second.RewriteMap;
15473 if (RewriteMap.empty())
15474 return {nullptr, scCouldNotCompute};
15475 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15476 if (S == RewriteMap.end())
15477 return {nullptr, scCouldNotCompute};
15478 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15479 if (!SM)
15480 return {nullptr, scCouldNotCompute};
15481 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15482 return {C0, SM->getSCEVType()};
15483 return {nullptr, scCouldNotCompute};
15484 };
15485 auto MergeMinMaxConst = [](MinMaxPattern P1,
15486 MinMaxPattern P2) -> MinMaxPattern {
15487 auto [C1, T1] = P1;
15488 auto [C2, T2] = P2;
15489 if (!C1 || !C2 || T1 != T2)
15490 return {nullptr, scCouldNotCompute};
15491 switch (T1) {
15492 case scUMaxExpr:
15493 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15494 case scSMaxExpr:
15495 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15496 case scUMinExpr:
15497 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15498 case scSMinExpr:
15499 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15500 default:
15501 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15502 }
15503 };
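  // Merging keeps the weakest bound across the incoming values, e.g. two
  // incoming umax guards with constants 3 and 5 merge to a umax with 3.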
15504 auto P = GetMinMaxConst(0);
15505 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15506 if (!P.first)
15507 break;
15508 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15509 }
15510 if (P.first) {
15511 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15512     SmallVector<const SCEV *, 2> Ops({P.first, LHS});
15513     const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15514 Guards.RewriteMap.insert({LHS, RHS});
15515 }
15516}
15517
15518 // Return a new SCEV that modifies \p Expr to the closest number that is
15519 // divisible by \p Divisor and is less than or equal to \p Expr. For now,
15520 // only handle constant \p Expr.
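// For example, a constant Expr of 17 with a Divisor of 8 yields the constant
// 16.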
15521 static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15522                                                       const APInt &DivisorVal,
15523 ScalarEvolution &SE) {
15524 const APInt *ExprVal;
15525 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15526 DivisorVal.isNonPositive())
15527 return Expr;
15528 APInt Rem = ExprVal->urem(DivisorVal);
15529 // return the SCEV: Expr - Expr % Divisor
15530 return SE.getConstant(*ExprVal - Rem);
15531}
15532
15533 // Return a new SCEV that modifies \p Expr to the closest number that is
15534 // divisible by \p Divisor and is greater than or equal to \p Expr. For now,
15535 // only handle constant \p Expr.
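// For example, a constant Expr of 17 with a Divisor of 8 yields the constant
// 24.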
15536static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15537 const APInt &DivisorVal,
15538 ScalarEvolution &SE) {
15539 const APInt *ExprVal;
15540 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15541 DivisorVal.isNonPositive())
15542 return Expr;
15543 APInt Rem = ExprVal->urem(DivisorVal);
15544 if (Rem.isZero())
15545 return Expr;
15546 // return the SCEV: Expr + Divisor - Expr % Divisor
15547 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15548}
15549
15550 static bool collectDivisibilityInformation(
15551     ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15552     DenseMap<const SCEV *, const SCEV *> &DivInfo,
15553     DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15554   // If we have LHS == 0, check if LHS is computing a property of some unknown
15555   // SCEV %v, and if so rewrite %v to express that property explicitly.
15556   if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15557     return false;
15558 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15559 // explicitly express that.
15560 const SCEVUnknown *URemLHS = nullptr;
15561 const SCEV *URemRHS = nullptr;
15562 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15563 return false;
15564
15565 const SCEV *Multiple =
15566 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15567 DivInfo[URemLHS] = Multiple;
15568 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15569 Multiples[URemLHS] = C->getAPInt();
15570 return true;
15571}
15572
15573// Check if the condition is a divisibility guard (A % B == 0).
15574static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15575 ScalarEvolution &SE) {
15576 const SCEV *X, *Y;
15577 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15578}
15579
15580// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15581// recursively. This is done by aligning up/down the constant value to the
15582// Divisor.
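// For example, with a Divisor of 4, umin(6, %n) becomes umin(4, %n) and
// umax(6, %n) becomes umax(8, %n).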
15583static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15584 APInt Divisor,
15585 ScalarEvolution &SE) {
15586 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15587 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15588 // the non-constant operand and in \p LHS the constant operand.
15589 auto IsMinMaxSCEVWithNonNegativeConstant =
15590 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15591 const SCEV *&RHS) {
15592 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15593 if (MinMax->getNumOperands() != 2)
15594 return false;
15595 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15596 if (C->getAPInt().isNegative())
15597 return false;
15598 SCTy = MinMax->getSCEVType();
15599 LHS = MinMax->getOperand(0);
15600 RHS = MinMax->getOperand(1);
15601 return true;
15602 }
15603 }
15604 return false;
15605 };
15606
15607 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15608 SCEVTypes SCTy;
15609 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15610 MinMaxRHS))
15611 return MinMaxExpr;
15612 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15613 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15614 auto *DivisibleExpr =
15615 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15616 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15617   SmallVector<const SCEV *> Ops = {
15618       applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15619 return SE.getMinMaxExpr(SCTy, Ops);
15620}
15621
15622void ScalarEvolution::LoopGuards::collectFromBlock(
15623 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15624 const BasicBlock *Block, const BasicBlock *Pred,
15625 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15626
15628
15629 SmallVector<const SCEV *> ExprsToRewrite;
15630 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15631 const SCEV *RHS,
15632 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15633 const LoopGuards &DivGuards) {
15634 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15635 // replacement SCEV which isn't directly implied by the structure of that
15636 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15637 // legal. See the scoping rules for flags in the header to understand why.
15638
15639 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15640 // create this form when combining two checks of the form (X u< C2 + C1) and
15641 // (X >=u C1).
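    // For example, (-1 + %x) u< 7 encodes the range check 1 u<= %x u<= 7, so
    // %x can be rewritten to umax(1, umin(%x, 7)).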
15642 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15643 &ExprsToRewrite]() {
15644 const SCEVConstant *C1;
15645 const SCEVUnknown *LHSUnknown;
15646 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15647 if (!match(LHS,
15648 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15649 !C2)
15650 return false;
15651
15652 auto ExactRegion =
15653 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15654 .sub(C1->getAPInt());
15655
15656 // Bail out, unless we have a non-wrapping, monotonic range.
15657 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15658 return false;
15659 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15660 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15661 I->second = SE.getUMaxExpr(
15662 SE.getConstant(ExactRegion.getUnsignedMin()),
15663 SE.getUMinExpr(RewrittenLHS,
15664 SE.getConstant(ExactRegion.getUnsignedMax())));
15665 ExprsToRewrite.push_back(LHSUnknown);
15666 return true;
15667 };
15668 if (MatchRangeCheckIdiom())
15669 return;
15670
15671 // Do not apply information for constants or if RHS contains an AddRec.
15672     if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
15673       return;
15674
15675 // If RHS is SCEVUnknown, make sure the information is applied to it.
15676     if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
15677       std::swap(LHS, RHS);
15678       Predicate = CmpInst::getSwappedPredicate(Predicate);
15679     }
15680
15681 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15682 // and \p FromRewritten are the same (i.e. there has been no rewrite
15683 // registered for \p From), then puts this value in the list of rewritten
15684 // expressions.
15685 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15686 const SCEV *To) {
15687 if (From == FromRewritten)
15688 ExprsToRewrite.push_back(From);
15689 RewriteMap[From] = To;
15690 };
15691
15692 // Checks whether \p S has already been rewritten. In that case returns the
15693 // existing rewrite because we want to chain further rewrites onto the
15694 // already rewritten value. Otherwise returns \p S.
15695 auto GetMaybeRewritten = [&](const SCEV *S) {
15696 return RewriteMap.lookup_or(S, S);
15697 };
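// Chaining illustration (hypothetical values): if an earlier guard already
// registered X -> umin(X, 100), a later guard "X u>= 10" looks up that entry
// via GetMaybeRewritten and yields X -> umax(umin(X, 100), 10), so both facts
// end up in a single rewritten expression.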
15698
15699 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15700 // Apply divisibility information when computing the constant multiple.
15701 const APInt &DividesBy =
15702 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15703
15704 // Collect rewrites for LHS and its transitive operands based on the
15705 // condition.
15706 // For min/max expressions, also apply the guard to its operands:
15707 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15708 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15709 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15710 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
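// E.g. a dominating guard "umin(A, B) u>= 16" first rewrites the umin itself
// and then, via the worklist below, enqueues A and B so that each of them is
// also rewritten, to umax(A, 16) and umax(B, 16) respectively.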
15711
15712 // We cannot express strict predicates in SCEV, so instead we replace them
15713 // with non-strict ones against plus or minus one of RHS depending on the
15714 // predicate.
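// E.g. with Predicate ULT, RHS == 64 and DividesBy == 8, the guard becomes
// the non-strict "LHS u<= 63", and aligning 63 down to the divisor gives
// "LHS u<= 56" for an LHS known to be a multiple of 8.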
15715 const SCEV *One = SE.getOne(RHS->getType());
15716 switch (Predicate) {
15717 case CmpInst::ICMP_ULT:
15718 if (RHS->getType()->isPointerTy())
15719 return;
15720 RHS = SE.getUMaxExpr(RHS, One);
15721 [[fallthrough]];
15722 case CmpInst::ICMP_SLT: {
15723 RHS = SE.getMinusSCEV(RHS, One);
15724 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15725 break;
15726 }
15727 case CmpInst::ICMP_UGT:
15728 case CmpInst::ICMP_SGT:
15729 RHS = SE.getAddExpr(RHS, One);
15730 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15731 break;
15732 case CmpInst::ICMP_ULE:
15733 case CmpInst::ICMP_SLE:
15734 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15735 break;
15736 case CmpInst::ICMP_UGE:
15737 case CmpInst::ICMP_SGE:
15738 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15739 break;
15740 default:
15741 break;
15742 }
15743 
15744 SmallVector<const SCEV *, 16> Worklist(1, LHS);
15745 SmallPtrSet<const SCEV *, 16> Visited;
15746
15747 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15748 append_range(Worklist, S->operands());
15749 };
15750
15751 while (!Worklist.empty()) {
15752 const SCEV *From = Worklist.pop_back_val();
15753 if (isa<SCEVConstant>(From))
15754 continue;
15755 if (!Visited.insert(From).second)
15756 continue;
15757 const SCEV *FromRewritten = GetMaybeRewritten(From);
15758 const SCEV *To = nullptr;
15759
15760 switch (Predicate) {
15761 case CmpInst::ICMP_ULT:
15762 case CmpInst::ICMP_ULE:
15763 To = SE.getUMinExpr(FromRewritten, RHS);
15764 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15765 EnqueueOperands(UMax);
15766 break;
15767 case CmpInst::ICMP_SLT:
15768 case CmpInst::ICMP_SLE:
15769 To = SE.getSMinExpr(FromRewritten, RHS);
15770 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15771 EnqueueOperands(SMax);
15772 break;
15773 case CmpInst::ICMP_UGT:
15774 case CmpInst::ICMP_UGE:
15775 To = SE.getUMaxExpr(FromRewritten, RHS);
15776 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15777 EnqueueOperands(UMin);
15778 break;
15779 case CmpInst::ICMP_SGT:
15780 case CmpInst::ICMP_SGE:
15781 To = SE.getSMaxExpr(FromRewritten, RHS);
15782 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15783 EnqueueOperands(SMin);
15784 break;
15785 case CmpInst::ICMP_EQ:
15786 if (isa<SCEVConstant>(RHS))
15787 To = RHS;
15788 break;
15789 case CmpInst::ICMP_NE:
15790 if (match(RHS, m_scev_Zero())) {
15791 const SCEV *OneAlignedUp =
15792 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15793 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15794 } else {
15795 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15796 // but creating the subtraction eagerly is expensive. Track the
15797 // inequalities in a separate map, and materialize the rewrite lazily
15798 // when encountering a suitable subtraction while re-writing.
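// E.g. a guard "A != B" only records the unordered pair {A, B} here; the
// SCEVLoopGuardRewriter below turns a later occurrence of (A - B) into
// umax(1, A - B) on demand (with the 1 rounded up to any known constant
// multiple of the difference) instead of building the subtraction now.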
15799 if (LHS->getType()->isPointerTy()) {
15803 break;
15804 }
15805 const SCEVConstant *C;
15806 const SCEV *A, *B;
15809 RHS = A;
15810 LHS = B;
15811 }
15812 if (LHS > RHS)
15813 std::swap(LHS, RHS);
15814 Guards.NotEqual.insert({LHS, RHS});
15815 continue;
15816 }
15817 break;
15818 default:
15819 break;
15820 }
15821
15822 if (To)
15823 AddRewrite(From, FromRewritten, To);
15824 }
15825 };
15826 
15827 SmallVector<std::pair<Value *, bool>> Terms;
15828 // First, collect information from assumptions dominating the loop.
15829 for (auto &AssumeVH : SE.AC.assumptions()) {
15830 if (!AssumeVH)
15831 continue;
15832 auto *AssumeI = cast<CallInst>(AssumeVH);
15833 if (!SE.DT.dominates(AssumeI, Block))
15834 continue;
15835 Terms.emplace_back(AssumeI->getOperand(0), true);
15836 }
15837
15838 // Second, collect information from llvm.experimental.guards dominating the loop.
15839 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
15840 SE.F.getParent(), Intrinsic::experimental_guard);
15841 if (GuardDecl)
15842 for (const auto *GU : GuardDecl->users())
15843 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
15844 if (Guard->getFunction() == Block->getParent() &&
15845 SE.DT.dominates(Guard, Block))
15846 Terms.emplace_back(Guard->getArgOperand(0), true);
15847
15848 // Third, collect conditions from dominating branches. Starting at the loop
15849 // predecessor, climb up the predecessor chain as long as we can find
15850 // predecessors that have a unique successor leading to the original
15851 // header.
15852 // TODO: share this logic with isLoopEntryGuardedByCond.
15853 unsigned NumCollectedConditions = 0;
15855 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
15856 for (; Pair.first;
15857 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
15858 VisitedBlocks.insert(Pair.second);
15859 const BranchInst *LoopEntryPredicate =
15860 dyn_cast<BranchInst>(Pair.first->getTerminator());
15861 if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
15862 continue;
15863
15864 Terms.emplace_back(LoopEntryPredicate->getCondition(),
15865 LoopEntryPredicate->getSuccessor(0) == Pair.second);
15866 NumCollectedConditions++;
15867
15868 // If we are recursively collecting guards, stop after 2
15869 // conditions to limit compile-time impact for now.
15870 if (Depth > 0 && NumCollectedConditions == 2)
15871 break;
15872 }
15873 // Finally, if we stopped climbing the predecessor chain because
15874 // there wasn't a unique one to continue, try to collect conditions
15875 // for PHINodes by recursively following all of their incoming
15876 // blocks and trying to merge the found conditions to build a new
15877 // one for the PHI.
15878 if (Pair.second->hasNPredecessorsOrMore(2) &&
15879 Depth < MaxLoopGuardCollectionDepth) {
15880 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
15881 for (auto &Phi : Pair.second->phis())
15882 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
15883 }
15884
15885 // Now apply the information from the collected conditions to
15886 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15887 // earliest condition is processed first, except guards with divisibility
15888 // information, which are moved to the back. This ensures the SCEVs with the
15889 // shortest dependency chains are constructed first.
15890 SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15891 GuardsToProcess;
15892 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
15893 SmallVector<Value *, 8> Worklist;
15894 SmallPtrSet<Value *, 8> Visited;
15895 Worklist.push_back(Term);
15896 while (!Worklist.empty()) {
15897 Value *Cond = Worklist.pop_back_val();
15898 if (!Visited.insert(Cond).second)
15899 continue;
15900
15901 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
15902 auto Predicate =
15903 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
15904 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
15905 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15906 // If LHS is a constant, apply information to the other expression.
15907 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
15908 // can improve results.
15909 if (isa<SCEVConstant>(LHS)) {
15910 std::swap(LHS, RHS);
15911 Predicate = ICmpInst::getSwappedPredicate(Predicate);
15912 }
15913 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
15914 continue;
15915 }
15916
15917 Value *L, *R;
15918 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
15919 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
15920 Worklist.push_back(L);
15921 Worklist.push_back(R);
15922 }
15923 }
15924 }
15925
15926 // Process divisibility guards in reverse order to populate DivGuards early.
15927 DenseMap<const SCEV *, APInt> Multiples;
15928 LoopGuards DivGuards(SE);
15929 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15930 if (!isDivisibilityGuard(LHS, RHS, SE))
15931 continue;
15932 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
15933 Multiples, SE);
15934 }
15935
15936 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15937 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
15938
15939 // Apply divisibility information last. This ensures it is applied to the
15940 // outermost expression after other rewrites for the given value.
15941 for (const auto &[K, Divisor] : Multiples) {
15942 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
15943 Guards.RewriteMap[K] =
15944 SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15945 Guards.rewrite(K), Divisor, SE),
15946 DivisorSCEV),
15947 DivisorSCEV);
15948 ExprsToRewrite.push_back(K);
15949 }
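// E.g. if Multiples records that N is divisible by 8, N's final rewrite is
// wrapped so that it rounds down to the nearest multiple of 8, i.e.
// ((rewritten N) /u 8) * 8, on top of any umin/umax rewrites collected above.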
15950
15951 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
15952 // the replacement expressions are contained in the ranges of the replaced
15953 // expressions.
15954 Guards.PreserveNUW = true;
15955 Guards.PreserveNSW = true;
15956 for (const SCEV *Expr : ExprsToRewrite) {
15957 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15958 Guards.PreserveNUW &=
15959 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
15960 Guards.PreserveNSW &=
15961 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
15962 }
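// E.g. if X has unsigned range [0, 100] and was rewritten to umin(X, 42),
// the replacement's range [0, 42] is contained in the original one, so NUW
// can still be applied when the rewriter substitutes operands below.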
15963
15964 // Now that all rewrite information is collected, rewrite the collected
15965 // expressions with the information in the map. This applies information to
15966 // sub-expressions.
15967 if (ExprsToRewrite.size() > 1) {
15968 for (const SCEV *Expr : ExprsToRewrite) {
15969 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
15970 Guards.RewriteMap.erase(Expr);
15971 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
15972 }
15973 }
15974}
15975
15976 const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
15977 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
15978 /// in the map. It skips AddRecExpr because we cannot guarantee that the
15979 /// replacement is loop invariant in the loop of the AddRec.
15980 class SCEVLoopGuardRewriter
15981 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
15982 const DenseMap<const SCEV *, const SCEV *> &Map;
15983 const SmallDenseSet<std::pair<const SCEV *, const SCEV *>> &NotEqual;
15984 
15985 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
15986 
15987 public:
15988 SCEVLoopGuardRewriter(ScalarEvolution &SE,
15989 const ScalarEvolution::LoopGuards &Guards)
15990 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15991 NotEqual(Guards.NotEqual) {
15992 if (Guards.PreserveNUW)
15993 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
15994 if (Guards.PreserveNSW)
15995 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
15996 }
15997
15998 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
15999
16000 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16001 return Map.lookup_or(Expr, Expr);
16002 }
16003
16004 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16005 if (const SCEV *S = Map.lookup(Expr))
16006 return S;
16007
16008 // If we didn't find the exact ZExt expr in the map, check if there's
16009 // an entry for a smaller ZExt we can use instead.
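// E.g. when rewriting (zext i8 %v to i64) and the map only has an entry for
// (zext i8 %v to i32), the loop below finds that narrower extension and
// returns it zero-extended back to i64.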
16010 Type *Ty = Expr->getType();
16011 const SCEV *Op = Expr->getOperand(0);
16012 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16013 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16014 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16015 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16016 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16017 if (const SCEV *S = Map.lookup(NarrowExt))
16018 return SE.getZeroExtendExpr(S, Ty);
16019 Bitwidth = Bitwidth / 2;
16020 }
16021
16022 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
16023 Expr);
16024 }
16025
16026 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16027 if (const SCEV *S = Map.lookup(Expr))
16028 return S;
16029 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
16030 Expr);
16031 }
16032
16033 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16034 if (const SCEV *S = Map.lookup(Expr))
16035 return S;
16036 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
16037 }
16038
16039 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16040 if (const SCEV *S = Map.lookup(Expr))
16041 return S;
16042 return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
16043 }
16044
16045 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16046 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16047 // return UMax(S, 1).
16048 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16049 const SCEV *LHS, *RHS;
16050 if (MatchBinarySub(S, LHS, RHS)) {
16051 if (LHS > RHS)
16052 std::swap(LHS, RHS);
16053 if (NotEqual.contains({LHS, RHS})) {
16054 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16055 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16056 return SE.getUMaxExpr(OneAlignedUp, S);
16057 }
16058 }
16059 return nullptr;
16060 };
16061
16062 // Check if Expr itself is a subtraction pattern with guard info.
16063 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16064 return Rewritten;
16065
16066 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16067 // (Const + A + B). There may be guard info for A + B, and if so, apply
16068 // it.
16069 // TODO: Could more generally apply guards to Add sub-expressions.
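// E.g. a backedge-taken count of the form (-1 + A + (-1 * B)) contains the
// sub-expression (A + (-1 * B)); if the guards recorded A != B, that
// difference is rewritten to umax(1, A - B) and the constant is re-added.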
16070 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16071 Expr->getNumOperands() == 3) {
16072 const SCEV *Add =
16073 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16074 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16075 return SE.getAddExpr(
16076 Expr->getOperand(0), Rewritten,
16077 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16078 if (const SCEV *S = Map.lookup(Add))
16079 return SE.getAddExpr(Expr->getOperand(0), S);
16080 }
16081 SmallVector<const SCEV *, 2> Operands;
16082 bool Changed = false;
16083 for (const auto *Op : Expr->operands()) {
16084 Operands.push_back(
16085 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16086 Changed |= Op != Operands.back();
16087 }
16088 // We are only replacing operands with equivalent values, so transfer the
16089 // flags from the original expression.
16090 return !Changed ? Expr
16091 : SE.getAddExpr(Operands,
16092 ScalarEvolution::maskFlags(
16093 Expr->getNoWrapFlags(), FlagMask));
16094 }
16095
16096 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16097 SmallVector<const SCEV *, 2> Operands;
16098 bool Changed = false;
16099 for (const auto *Op : Expr->operands()) {
16100 Operands.push_back(
16101 SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
16102 Changed |= Op != Operands.back();
16103 }
16104 // We are only replacing operands with equivalent values, so transfer the
16105 // flags from the original expression.
16106 return !Changed ? Expr
16107 : SE.getMulExpr(Operands,
16108 ScalarEvolution::maskFlags(
16109 Expr->getNoWrapFlags(), FlagMask));
16110 }
16111 };
16112
16113 if (RewriteMap.empty() && NotEqual.empty())
16114 return Expr;
16115
16116 SCEVLoopGuardRewriter Rewriter(SE, *this);
16117 return Rewriter.visit(Expr);
16118}
16119
16120const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16121 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16122}
16123
16124 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
16125 const LoopGuards &Guards) {
16126 return Guards.rewrite(Expr);
16127}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:638
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
void visit(MachineFunction &MF, MachineBasicBlock &Start, std::function< void(MachineBasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static void GroupByComplexity(SmallVectorImpl< const SCEV * > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< const SCEV * > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS)
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool hasHugeExpression(ArrayRef< const SCEV * > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< const SCEV * > Ops, SCEV::NoWrapFlags Flags)
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool CollectAddOperandsWithScales(SmallDenseMap< const SCEV *, APInt, 16 > &M, SmallVectorImpl< const SCEV * > &NewOps, APInt &AccumulatedConstant, ArrayRef< const SCEV * > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static TableGen::Emitter::OptClass< SkeletonEmitter > X("gen-skeleton-class", "Generate example skeleton class")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:1971
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1012
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1541
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1392
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1513
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:936
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1796
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1202
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1183
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1666
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1489
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1112
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1167
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1648
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1762
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:828
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1274
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1151
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:985
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:874
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1131
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1222
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM_ABI bool isSingleEdge() const
Check if this is the only edge between Start and End.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:459
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:482
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
Conditional or Unconditional Branch instruction.
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
static LLVM_ABI Constant * getNot(Constant *C)
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static Constant * getGetElementPtr(Type *Ty, Constant *C, ArrayRef< Constant * > IdxList, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReducedTy=nullptr)
Getelementptr form.
Definition Constants.h:1284
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:771
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:283
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:321
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:164
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:318
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:61
Metadata node.
Definition Metadata.h:1078
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
const SCEV * getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
This is the base class for unary cast operator classes.
const SCEV * getOperand() const
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op, Type *ty)
This node is the base class for min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
const SCEV * getOperand(unsigned i) const
ArrayRef< const SCEV * > operands() const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visit(const SCEV *S)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
LLVM_ABI ArrayRef< const SCEV * > operands() const
Return operands of this SCEV expression.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
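A minimal sketch of inspecting an arbitrary SCEV node through the base-class interface listed above; the helper is hypothetical.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// Hypothetical helper: dump kind, operand count, and simple constant facts.
static void describeSCEV(const SCEV *S, raw_ostream &OS) {
  OS << "kind " << unsigned(S->getSCEVType()) << ", "
     << S->operands().size() << " operand(s), expression size "
     << S->getExpressionSize() << "\n";
  if (S->isZero() || S->isOne() || S->isAllOnesValue())
    OS << "  (trivial constant)\n";
  S->print(OS);
  OS << "\n";
}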
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overridden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by an analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getSMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI const SCEV * getURemExpr(const SCEV *LHS, const SCEV *RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getSMinExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getUMaxExpr(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Return true if operation BinOp between LHS and RHS provably does not have signed/unsigned overflow (as selected by Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op, unsigned Depth=0)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getUDivExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getUMinExpr(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< const SCEV * > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< const SCEV * > &Operands)
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getUDivExactExpr(const SCEV *LHS, const SCEV *RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< const SCEV * > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
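A minimal sketch of the typical ScalarEvolution client pattern, composed only from the queries listed above; the helper and its arguments are assumptions for illustration.

#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Support/Casting.h"

using namespace llvm;

// Hypothetical helper: is V analyzable, loop-invariant in L, and known
// non-negative, with a computable backedge-taken count for L?
static bool sketchQueries(ScalarEvolution &SE, const Loop *L, Value *V) {
  // Only integer- and pointer-typed values are analyzable by SCEV.
  if (!SE.isSCEVable(V->getType()))
    return false;
  const SCEV *S = SE.getSCEV(V);
  // getBackedgeTakenCount returns SCEVCouldNotCompute when unknown.
  if (isa<SCEVCouldNotCompute>(SE.getBackedgeTakenCount(L)))
    return false;
  return SE.isLoopInvariant(S, L) && SE.isKnownNonNegative(S);
}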
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:723
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:754
TypeSize getSizeInBits() const
Definition DataLayout.h:734
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:45
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:267
static LLVM_ABI IntegerType * getInt8Ty(LLVMContext &C)
Definition Type.cpp:294
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:293
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:255
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:240
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:300
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:293
Use & Op()
Definition User.h:197
Value * getOperand(unsigned i) const
Definition User.h:233
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.cpp:1106
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2249
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2254
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2259
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2812
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2264
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:798
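A minimal sketch exercising the APIntOps helpers referenced above (signed/unsigned min/max, GCD, and the wrapped quadratic solver); the concrete numbers are illustrative only.

#include "llvm/ADT/APInt.h"
#include <optional>

using namespace llvm;

static void sketchAPIntOps() {
  APInt A(32, 6), B(32, 15);
  // Signed and unsigned comparisons of the same two values.
  const APInt &SMin = APIntOps::smin(A, B); // 6
  const APInt &UMax = APIntOps::umax(A, B); // 15
  // gcd(6, 15) == 3.
  APInt G = APIntOps::GreatestCommonDivisor(A, B);
  // Smallest non-negative root of n^2 + 3n - 10 in 32-bit wrapped
  // arithmetic; mathematically that root is 2.
  std::optional<APInt> Root = APIntOps::SolveQuadraticEquationWrap(
      APInt(32, 1), APInt(32, 3), APInt(32, -10, /*isSigned=*/true), 32);
  (void)SMin; (void)UMax; (void)G; (void)Root;
}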
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
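A minimal sketch of the SCEV pattern matchers referenced above; it assumes they are provided by llvm/Analysis/ScalarEvolutionPatternMatch.h in namespace SCEVPatternMatch, and the helper name is hypothetical.

#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"

using namespace llvm;
using namespace llvm::SCEVPatternMatch;

// Hypothetical helper: does S have the shape {Unknown,+,ConstantStep}?
static bool isUnknownPlusConstantStep(const SCEV *S, const APInt *&Step) {
  const SCEVUnknown *Start;
  return match(S,
               m_scev_AffineAddRec(m_SCEVUnknown(Start), m_scev_APInt(Step)));
}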
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2106
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B.
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2184
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
unsigned short computeExpressionSize(ArrayRef< const SCEV * > Args)
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2101
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit up to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2176
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1744
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:406
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:361
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
@ First
Helpers to iterate all locations in the MemoryEffectsBase class.
Definition ModRef.h:74
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2002
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2078
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1915
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2009
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1945
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
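A minimal sketch of SCEVExprContains with a lambda predicate, which is how queries such as containsAddRecurrence above can be phrased; the helper name is hypothetical.

#include "llvm/Analysis/ScalarEvolutionExpressions.h"

using namespace llvm;

// Hypothetical helper: does any node reachable from S perform a UDiv?
static bool containsUDiv(const SCEV *S) {
  return SCEVExprContains(S,
                          [](const SCEV *N) { return isa<SCEVUDivExpr>(N); });
}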
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:301
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:196
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:145
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:129
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
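A minimal sketch of the KnownBits operations listed above, on fully-known (constant) inputs so the results are exact; purely illustrative.

#include "llvm/ADT/APInt.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

static bool sketchKnownBits() {
  // All bits of the 8-bit constant 12 are known.
  KnownBits K = KnownBits::makeConstant(APInt(8, 12));
  // shl with both operands fully known: 12 << 1 == 24 exactly, so the
  // unsigned min and max of the result coincide.
  KnownBits Shifted = KnownBits::shl(K, KnownBits::makeConstant(APInt(8, 1)));
  return K.isNonNegative() && Shifted.getMinValue() == Shifted.getMaxValue();
}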
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.