LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
// Chains of recurrences -- a method to expedite the evaluation
// of closed-form functions
// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
// On computational properties of chains of recurrences
// Eugene V. Zima
//
// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
// Robert A. van Engelen
//
// Efficient Symbolic Analysis for Optimizing Compilers
// Robert A. van Engelen
//
// Using the chains of recurrences algebra for data dependence testing and
// induction variable substitution
// MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
// Command-line options controlling SCEV analysis limits and verification.
// NOTE(review): this extract dropped the first line of most declarations below
// (the `static cl::opt<...> Name(` head). For example, the first option starts
// at its variable name and the next one starts at its flag string. Restore the
// heads from upstream before building.
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
// Stores into the global VerifySCEV via cl::location rather than owning a value.
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
// NOTE(review): the desc text below reads oddly ("with ... is passed");
// "when" was probably intended. It is a runtime string, so it is left untouched.
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
// NOTE(review): the `static cl::opt<...> Name(` heads of the two options below
// are missing from this extract.
// Off by default: the desc notes it may be costly in compile time.
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
// NOTE(review): the `static cl::opt<bool> Name(` head of this option is missing
// from this extract. The flag string's spelling ("strenghening") is
// user-visible and is left as-is.
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
// Body of the routine that fills in CanonicalSCEV for this node (presumably
// SCEV::computeAndSetCanonical(ScalarEvolution &SE); its signature line is
// missing from this extract).
// - Leaves, and nodes whose operands are all already canonical, are their own
//   canonical form.
// - Otherwise the node is rebuilt through SE from its operands' canonical
//   forms, reusing this node's no-wrap flags.
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
// NOTE(review): the declaration of CanonOps is missing from this extract.
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
// Only n-ary nodes carry no-wrap flags; every other kind uses FlagAnyWrap.
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
// NOTE(review): the start of the statement continued on the next line is
// missing from this extract.
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
// NOTE(review): the case label for this sequential-umin branch is missing from
// this extract.
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
// Debug printer: writes this SCEV followed by a newline to dbgs().
// NOTE(review): the function signature line (presumably SCEV::dump()) is
// missing from this extract.
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
/// Print a textual form of this SCEV to \p OS:
/// - constants and unknowns: as operands;
/// - casts: "(<kind> <src type> <operand> to <dst type>)";
/// - add-recs: "{start,+,step...}<flags...><loop header>";
/// - n-ary nodes: "(...)" followed by <nuw>/<nsw> for add/mul;
/// - udiv: "(lhs /u rhs)".
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
// NOTE(review): the declaration of ZExt is missing from this extract.
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
// NOTE(review): the declaration of SExt is missing from this extract.
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
// "nw" is printed only when it is not already implied by nuw or nsw.
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
// NOTE(review): a further case label and the opening `{` of this case block
// appear to be missing here (the block is closed by the `}` numbered 447).
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
// NOTE(review): the case label for the " umin_seq " branch is missing from this
// extract.
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
// NOTE(review): the line that streams the operands (presumably joined with
// OpStr) is missing from this extract.
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
// NOTE(review): the case label for the could-not-compute branch is missing
// from this extract.
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
// Body of the SCEV type accessor: dispatches on getSCEVType() and forwards to
// the concrete subclass's getType().
// NOTE(review): the signature line (presumably SCEV::getType() const) and two
// case labels are missing from this extract.
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
// NOTE(review): the case label for the sequential min/max kind is missing here.
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
// NOTE(review): the case label for the could-not-compute kind is missing here.
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
// Body of the SCEV operand accessor: leaf kinds return an empty list; the
// others forward to the concrete subclass's operands().
// NOTE(review): the signature line and two case labels are missing from this
// extract.
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
// NOTE(review): one further case label (presumably the sequential min/max kind)
// is missing here.
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
// NOTE(review): the case label for the could-not-compute kind is missing here.
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
// Tail of a predicate that returns true when this is a multiply whose leading
// operand is a negative constant.
// NOTE(review): the signature and the line binding `Mul` are missing from this
// extract.
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
// Kind check: true iff S is the could-not-compute SCEV kind.
// NOTE(review): the signature line (presumably
// SCEVCouldNotCompute::classof(const SCEV *S)) is missing from this extract.
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
// Uniquing factory for SCEVConstant:
// - returns the existing node keyed by (scConstant, V) from UniqueSCEVs;
// - otherwise allocates a new node, inserts it, and computes its canonical form.
// NOTE(review): the signature and FoldingSetNodeID declaration lines are
// missing from this extract.
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
// Overload wrapping Val in a ConstantInt of this context, then uniquing it.
// NOTE(review): the signature line is missing from this extract (Val is
// presumably an APInt).
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
// Overload building an ITy constant from the integer V, explicitly allowing
// implicit truncation (see the TODO).
// NOTE(review): the name/parameter lines are missing from this extract.
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
// Uniquing factory for the vscale SCEV of type Ty, keyed by (scVScale, Ty).
// NOTE(review): the signature and FoldingSetNodeID declaration lines are
// missing from this extract.
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
// NOTE(review): the first signature line is missing from this extract.
591 SCEV::NoWrapFlags Flags) {
// Known-minimum element count as a Ty constant; scaled by vscale (using Flags)
// when EC is scalable.
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
// Integer-cast node constructors. Each asserts that both the operand type and
// the result type are integer or pointer types.
// NOTE(review): each constructor below is missing its mem-initializer line
// (`: SCEVIntegralCastExpr(...)` or similar) in this extract.
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
// Body of a SCEVUnknown callback (signature missing from this extract;
// presumably SCEVUnknown::deleted()). It drops memoized results for this node,
// unlinks the node from UniqueSCEVs, and clears the tracked value.
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
///
/// Returns a negative, zero, or positive value when \p LV orders before,
/// alongside, or after \p RV. The keys, in order, are:
///   1. pointer-ness;
///   2. value ID;
///   3. argument number;
///   4. global linkage and (when semantic) name;
///   5. for instructions: loop depth, operand count, then operands recursively
///      with \p Depth incremented.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
// NOTE(review): the condition guarding this early `return 0;` (presumably a
// Depth limit) is missing from this extract.
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
// NOTE(review): the second operand of this `||` is missing from this extract.
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
// NOTE(review): the depth-limit condition guarding this `return std::nullopt;`
// is missing from this extract (see the function comment above).
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
// NOTE(review): the declarations of LC and RC are missing from this extract.
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
// Distinct uniqued constants of equal width have different values, so the
// result is never 0 here.
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
// NOTE(review): unlike the other cases, this subtracts the unsigned bit widths
// before the implicit conversion to int.
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
// NOTE(review): the declarations of LA and RA are missing from this extract.
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
// NOTE(review): one further case label and the opening `{` of this case block
// are missing from this extract (the block is closed by the `}` numbered 846).
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
// A std::nullopt result also compares != 0, so depth exhaustion propagates.
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
// NOTE(review): the case label for the could-not-compute kind is missing here.
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
// NOTE(review): the first signature line (name, Ops and LoopInfo parameters) is
// missing from this extract.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
// NOTE(review): the start of the sort call continued on the next line is
// missing from this extract.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
// The scan stops before index e-2: a duplicate of Ops[e-2] could only be
// Ops[e-1], which is already adjacent. Ops.size() >= 3 here, so e-2 cannot
// underflow.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
// NOTE(review): the signature line and the body of the any_of predicate are
// missing from this extract.
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
///
/// If every operand was a constant, the folded constant is returned even if it
/// is an identity. Returns nullptr when more than one operand remains.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
// NOTE(review): the line with the function name and leading parameters is
// missing from this extract.
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
// Pull every constant out of Ops, accumulating them into Folded.
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
// A non-identity folded constant goes back in front of the sorted operands.
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for the binomial coefficient "It choose k".
///
/// This member form simply forwards this recurrence's own operands to the
/// static overload, so it returns SCEVCouldNotCompute in the same cases.
1076 ScalarEvolution &SE) const {
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
// Static form: evaluate the chain of recurrences
// {Operands[0],+,Operands[1],+,...} at iteration It.
// Operands must be non-empty (asserted below). Every coefficient BC(It, i) is
// requested in the type of the running Result (i.e. that of Operands[0]); if
// BinomialCoefficient cannot compute one (it gives up for K > 1000), that
// SCEVCouldNotCompute is returned immediately instead of a partial sum.
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
  // BC(It, 0) == 1, so the start value seeds the sum.
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
///
/// Non-pointer-typed sub-expressions are returned unchanged (see visit()),
/// null pointer leaves fold to a zero of TargetTy, and rebuilt add/mul nodes
/// keep their original no-wrap flags. CreatePtrCast is a function_ref, which
/// does not own the callable it refers to; rewrite() only uses it during the
/// call.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
  // Type of the zero constant substituted for null pointer leaves; callers
  // pass the same integer type that their CreatePtrCast callback produces.
1113 Type *TargetTy;
  // Builds the integer-typed cast SCEV for one pointer-typed SCEVUnknown leaf.
1114 ConversionFn CreatePtrCast;
1115
1116public:
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
  // One-shot entry point: rewrite Scev with a temporary rewriter.
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
  // Same as visitAddExpr: rebuild only if some operand changed, and keep the
  // original no-wrap flags on the rebuilt multiply.
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
1165 return SE.getZero(TargetTy);
  // Otherwise delegate to the caller-supplied conversion.
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
// Rewrite the pointer-typed SCEV Op as an integer SCEV of the DataLayout's
// intptr type by sinking ptrtoint down to its SCEVUnknown leaves. Returns
// SCEVCouldNotCompute for non-integral pointers, and when SCEV's effective
// integer type is not wide enough for the intptr type (see below).
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1185 getDataLayout().getTypeSizeInBits(IntPtrTy))
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
// Each leaf U becomes a uniqued SCEVPtrToIntExpr of type IntPtrTy, keyed on
// (scPtrToInt, U). NOTE(review): IntPtrTy is not part of the key, which
// presumably relies on it being implied by U's pointer type; confirm.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
// Rewrite the pointer-typed SCEV Op as an integer SCEV of the DataLayout's
// address type by sinking ptrtoaddr down to its SCEVUnknown leaves. Returns
// SCEVCouldNotCompute for pointers with an unstable representation.
// NOTE(review): unlike getLosslessPtrToIntExpr, no explicit width comparison
// against SCEV's effective type is performed here; presumably the address
// type always fits -- confirm.
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
// Each leaf U becomes a uniqued SCEVPtrToAddrExpr keyed on (scPtrToAddr, U, Ty).
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
// Return ptrtoint(Op) as an integer SCEV of type Ty. The lossless conversion
// from getLosslessPtrToIntExpr is truncated or zero-extended to Ty. A
// SCEVCouldNotCompute from the lossless step is returned unchanged.
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
// Return trunc(Op) to Ty. Op must be a non-pointer SCEV strictly wider than
// Ty. The function tries, in order:
// - an existing uniqued node;
// - constant folding;
// - collapsing trunc/sext/zext operands;
// - distributing over add/mul when that adds at most one new truncate;
// - truncating each AddRec operand;
// - folding to zero when all kept bits are known zero.
// Otherwise it creates a SCEVTruncateExpr. Beyond MaxCastDepth, every fold
// after the cast-collapsing ones is skipped.
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although ID was not in the cache when we checked at the beginning, the
1320 // recursive getTruncateExpr calls above may have inserted it (or otherwise
1321 // invalidated IP). Look it up again, which also refreshes IP; reuse if found.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
  // Note: the loop variable below shadows the parameter Op. The rebuilt
  // AddRec gets FlagAnyWrap; the original's wrap flags are not carried over.
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
//
// If Step is known positive, *Pred is set to ICMP_SLT and the limit is built
// from the maximum of Step's signed range. If Step is known negative, *Pred is
// set to ICMP_SGT and the limit is built from the minimum of that range.
// Otherwise nullptr is returned and *Pred is left unchanged.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
//
// Unlike the signed variant, this always sets *Pred to ICMP_ULT. The limit is
// built from the maximum of Step's unsigned range.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
1385struct ExtendOpTraitsBase {
  // Pointer to a ScalarEvolution member that extends a SCEV to a type, taking a
  // recursion depth last; it is called as (SE->*GetExtendExpr)(Op, Ty, Depth).
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
1390// Used to make code generic over signed and unsigned overflow.
1391template <typename ExtendOp> struct ExtendOpTraits {
1392 // Members present:
1393 //
1394 // static const SCEV::NoWrapFlags WrapType;
1395 //
1396 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1397 //
1398 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1399 // ICmpInst::Predicate *Pred,
1400 // ScalarEvolution *SE);
1401};
1402
// Sign extension: the relevant no-overflow flag is NSW, and overflow limits use
// signed predicates.
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
// Zero extension: the relevant no-overflow flag is NUW, and overflow limits use
// unsigned predicates.
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended (ext) AddRec as such:
1441//   {ext(Step + PreStart),+,Step} => {ext(Step) + ext(PreStart),+,Step}
1442// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)".
//
// Returns PreStart, which is AR's start with one occurrence of Step removed
// from its top-level add, when one of the three checks below shows that
// PreStart + Step does not overflow in the WrapType sense. Returns nullptr if
// AR's start is not an add with Step among its operands, or if no check
// succeeds.
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
  // No operand was removed, so Step is not a top-level term of Start.
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not signed/unsigned-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
  // Extending Start and (PreStart, Step) separately to twice the width gives
  // the same SCEV only if PreStart + Step did not overflow at the narrow width.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
  // OverflowLimit is nullptr when no limit is available (e.g. a signed step of
  // unknown sign), in which case this check cannot succeed.
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
// Returns true if WrapType (NSW for sign extension, NUW for zero extension)
// can be proven for {Start,+,Step} on L using an already-existing nearby
// recurrence {Start-Delta,+,Step}, as derived above. Only constant Start
// values are handled.
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
  // NOTE(review): Delta is declared `unsigned`, so -2 and -1 convert to
  // 0xFFFFFFFE and 0xFFFFFFFF. PreStart (StartAI - Delta) and DeltaS
  // (getConstant(Type, Delta)) appear to be built from the same converted
  // value, so PreStart + DeltaS should still equal Start modulo the bit width.
  // For types wider than 32 bits, however, the "negative" candidates are not
  // -2/-1. Consider `int Delta` with a sign-extending getConstant -- TODO
  // confirm the intended semantics.
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Give up if we don't already have the add recurrence we need because
1597 // actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// Returns 0 when some non-constant operand has no known trailing zero bits.
// NOTE(review): operand 0 of WholeAddExpr is skipped below, which assumes it
// is ConstantTerm; confirm at the callers.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
1622 uint32_t TZ = BitWidth;
  // Stop early once TZ reaches 0; the minimum cannot drop any further.
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
  // D is the low TZ bits of C, so C - D and every x, y, ... have at least TZ
  // trailing zeros, and so does their sum. Adding D (< 2^TZ) only fills those
  // zero bits without producing a carry, so it cannot wrap.
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// D is the low TZ bits of ConstantStart, where TZ is Step's known trailing-zero
// count. Every value C - D + x * n therefore keeps at least TZ trailing zeros,
// and adding D cannot carry. Returns 0 when TZ is 0.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
// Record FoldCache[ID] = S and keep the reverse index FoldCacheUser in sync.
// FoldCacheUser maps S to the list of IDs whose cached fold result is S. If ID
// already had a cached result, ID is first removed from that result's user
// list (where it is expected to appear exactly once), then the cache entry is
// overwritten with S.
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
  // Swap-and-pop removal: the order of UserIDs is not preserved. pop_back
  // relies on ID having been found, which only the (debug-only) assert checks.
  // The loop index shadows the outer insert result I, used again after the loop.
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
// Return a SCEV for Op zero-extended to Ty, which must be strictly wider than
// Op's type. Results are looked up in, and recorded into, FoldCache under
// (scZeroExtend, Op, effective Ty). The folding itself is done by
// getZeroExtendExprImpl.
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the latter case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two possible causes of the issue with a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 const APInt *C, *C2;
1914 // zext (C + A)<nsw> -> (sext(C) + sext(A))<nsw> if zext (C + A)<nsw> >=s 0.
1915 // Currently the non-negative check is done manually, as isKnownNonNegative
1916 // is too expensive.
1917 if (SA->hasNoSignedWrap() &&
1919 m_scev_SMax(m_scev_APInt(C2), m_SCEV()))) &&
1920 C->isNegative() && !C->isMinSignedValue() && C2->sge(C->abs())) {
1921 assert(isKnownNonNegative(SA) && "incorrectly determined non-negative");
1922 return getAddExpr(getSignExtendExpr(SA->getOperand(0), Ty, Depth + 1),
1923 getSignExtendExpr(SA->getOperand(1), Ty, Depth + 1),
1924 SCEV::FlagNSW, Depth + 1);
1925 }
1926
1927 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1928 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1929 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1930 //
1931 // Often address arithmetics contain expressions like
1932 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1933 // This transformation is useful while proving that such expressions are
1934 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1935 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1936 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1937 if (D != 0) {
1938 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1939 const SCEV *SResidual =
1941 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1942 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1943 Depth + 1);
1944 }
1945 }
1946 }
1947
1948 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1949 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1950 if (SM->hasNoUnsignedWrap()) {
1951 // If the multiply does not unsign overflow then we can, by definition,
1952 // commute the zero extension with the multiply operation.
1954 for (SCEVUse Op : SM->operands())
1955 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1956 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1957 }
1958
1959 // zext(2^K * (trunc X to iN)) to iM ->
1960 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1961 //
1962 // Proof:
1963 //
1964 // zext(2^K * (trunc X to iN)) to iM
1965 // = zext((trunc X to iN) << K) to iM
1966 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1967 // (because shl removes the top K bits)
1968 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1969 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1970 //
1971 const APInt *C;
1972 const SCEV *TruncRHS;
1973 if (match(SM,
1974 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1975 C->isPowerOf2()) {
1976 int NewTruncBits =
1977 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1978 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1979 return getMulExpr(
1980 getZeroExtendExpr(SM->getOperand(0), Ty),
1981 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1982 SCEV::FlagNUW, Depth + 1);
1983 }
1984 }
1985
1986 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1987 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1990 SmallVector<SCEVUse, 4> Operands;
1991 for (SCEVUse Operand : MinMax->operands())
1992 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1994 return getUMinExpr(Operands);
1995 return getUMaxExpr(Operands);
1996 }
1997
1998 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
2000 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
2001 SmallVector<SCEVUse, 4> Operands;
2002 for (SCEVUse Operand : MinMax->operands())
2003 Operands.push_back(getZeroExtendExpr(Operand, Ty));
2004 return getUMinExpr(Operands, /*Sequential*/ true);
2005 }
2006
2007 // The cast wasn't folded; create an explicit cast node.
2008 // Recompute the insert position, as it may have been invalidated.
2009 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2010 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
2011 Op, Ty);
2012 UniqueSCEVs.InsertNode(S, IP);
2013 S->computeAndSetCanonical(*this);
2014 registerUser(S, Op);
2015 return S;
2016}
2017
/// Return a SCEV for Op sign-extended to Ty, consulting FoldCache first.
///
/// Asserts that Ty is strictly wider than Op's type, that Ty is SCEVable and
/// that Op is not a pointer. Ty is normalized with getEffectiveSCEVType before
/// the (scSignExtend, Op, Ty) key is looked up in FoldCache; on a miss the
/// actual folding is delegated to getSignExtendExprImpl and the result is
/// recorded via insertFoldCacheEntry.
/// NOTE(review): this listing omits original lines 2019 (rest of the
/// signature) and 2032 (the statement immediately before
/// insertFoldCacheEntry), so the cache insertion may be guarded by a
/// condition in the real source — verify.
2018const SCEV *
2020 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2021 "This is not an extending conversion!");
2022 assert(isSCEVable(Ty) &&
2023 "This is not a conversion to a SCEVable type!");
2024 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2025 Ty = getEffectiveSCEVType(Ty);
2026
// Memoize on the normalized (kind, operand, type) triple.
2027 FoldID ID(scSignExtend, Op, Ty);
2028 if (const SCEV *S = FoldCache.lookup(ID))
2029 return S;
2030
2031 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2033 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2034 return S;
2035}
2036
2038 unsigned Depth) {
// Uncached worker behind getSignExtendExpr: sign-extend Op to Ty. Folds are
// attempted in this order:
//   1. constant folding;
//   2. collapsing sext(sext(x)) and sext(zext(x));
//   3. reusing an already-uniqued SCEVSignExtendExpr node;
//   4. a MaxCastDepth recursion cutoff that emits an explicit node;
//   5. simplifying sext(trunc(x));
//   6. distributing over nsw adds / peeling a constant out of adds;
//   7. pushing the extension into affine addrecs when nsw (or nw with an
//      unsigned step) can be established;
//   8. a zext fallback;
//   9. distributing over smin/smax;
//  10. otherwise creating an explicit SCEVSignExtendExpr node.
// NOTE(review): this listing omits several original lines (e.g. 2037, 2050,
// 2054, 2059, 2076, 2080, 2124, 2254, 2259-2260), including the `if` guards
// of some folds, so the code shown here is incomplete.
2039 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2040 "This is not an extending conversion!");
2041 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2042 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2043 Ty = getEffectiveSCEVType(Ty);
2044
2045 // Fold if the operand is constant.
2046 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2047 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2048
2049 // sext(sext(x)) --> sext(x)
2051 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2052
2053 // sext(zext(x)) --> zext(x)
2055 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2056
2057 // Before doing any expensive analysis, check to see if we've already
2058 // computed a SCEV for this Op and Ty.
2060 ID.AddInteger(scSignExtend);
2061 ID.AddPointer(Op);
2062 ID.AddPointer(Ty);
2063 void *IP = nullptr;
2064 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2065 // Limit recursion depth.
2066 if (Depth > MaxCastDepth) {
2067 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2068 Op, Ty);
2069 UniqueSCEVs.InsertNode(S, IP);
2070 S->computeAndSetCanonical(*this);
2071 registerUser(S, Op);
2072 return S;
2073 }
2074
2075 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2077 // It's possible the bits taken off by the truncate were all sign bits. If
2078 // so, we should be able to simplify this further.
2079 const SCEV *X = ST->getOperand();
// NOTE(review): CR is declared on original line 2080, which this listing
// omits; it presumably holds a range of X — confirm. If truncating CR to
// TruncBits and sign-extending back to NewBits still covers CR resized to
// NewBits, the truncated-away bits added nothing beyond the sign, so the
// result can be rebuilt directly from X.
2081 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2082 unsigned NewBits = getTypeSizeInBits(Ty);
2083 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2084 CR.sextOrTrunc(NewBits)))
2085 return getTruncateOrSignExtend(X, Ty, Depth);
2086 }
2087
2088 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2089 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2090 if (SA->hasNoSignedWrap()) {
2091 // If the addition does not sign overflow then we can, by definition,
2092 // commute the sign extension with the addition operation.
2094 for (SCEVUse Op : SA->operands())
2095 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2096 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2097 }
2098
2099 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2100 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2101 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2102 //
2103 // For instance, this will bring two seemingly different expressions:
2104 // 1 + sext(5 + 20 * %x + 24 * %y) and
2105 // sext(6 + 20 * %x + 24 * %y)
2106 // to the same form:
2107 // 2 + sext(4 + 20 * %x + 24 * %y)
2108 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2109 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2110 if (D != 0) {
2111 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2112 const SCEV *SResidual =
2114 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2115 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2116 Depth + 1);
2117 }
2118 }
2119 }
2120 // If the input value is a chrec scev, and we can prove that the value
2121 // did not overflow the old, smaller, value, we can sign extend all of the
2122 // operands (often constants). This allows analysis of something like
2123 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2125 if (AR->isAffine()) {
2126 const SCEV *Start = AR->getStart();
2127 const SCEV *Step = AR->getStepRecurrence(*this);
2128 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2129 const Loop *L = AR->getLoop();
2130
2131 // If we have special knowledge that this addrec won't overflow,
2132 // we don't need to do any further analysis.
2133 if (AR->hasNoSignedWrap()) {
2134 Start =
2136 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2137 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2138 }
2139
2140 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2141 // Note that this serves two purposes: It filters out loops that are
2142 // simply not analyzable, and it covers the case where this code is
2143 // being called from within backedge-taken count analysis, such that
2144 // attempting to ask for the backedge-taken count would likely result
2145 // in infinite recursion. In the latter case, the analysis code will
2146 // cope with a conservative value, and it will take care to purge
2147 // that value once it has finished.
2148 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2149 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2150 // Manually compute the final value for AR, checking for
2151 // overflow.
2152
2153 // Check whether the backedge-taken count can be losslessly casted to
2154 // the addrec's type. The count is always unsigned.
2155 const SCEV *CastedMaxBECount =
2156 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2157 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2158 CastedMaxBECount, MaxBECount->getType(), Depth);
2159 if (MaxBECount == RecastedMaxBECount) {
// Evaluate in a type twice as wide as AR. If sign-extending the narrow
// Start + Step*MaxBECount yields the same SCEV as computing that sum on
// extended operands, the narrow computation did not wrap.
2160 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2161 // Check whether Start+Step*MaxBECount has no signed overflow.
2162 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2164 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2166 Depth + 1),
2167 WideTy, Depth + 1);
2168 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2169 const SCEV *WideMaxBECount =
2170 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2171 const SCEV *OperandExtendedAdd =
2172 getAddExpr(WideStart,
2173 getMulExpr(WideMaxBECount,
2174 getSignExtendExpr(Step, WideTy, Depth + 1),
2177 if (SAdd == OperandExtendedAdd) {
2178 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2179 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2180 // Return the expression with the addrec on the outside.
2181 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2182 Depth + 1);
2183 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2184 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2185 }
2186 // Similar to above, only this time treat the step value as unsigned.
2187 // This covers loops that count up with an unsigned step.
2188 OperandExtendedAdd =
2189 getAddExpr(WideStart,
2190 getMulExpr(WideMaxBECount,
2191 getZeroExtendExpr(Step, WideTy, Depth + 1),
2194 if (SAdd == OperandExtendedAdd) {
2195 // If AR wraps around then
2196 //
2197 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2198 // => SAdd != OperandExtendedAdd
2199 //
2200 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2201 // (SAdd == OperandExtendedAdd => AR is NW)
2202
2203 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2204
2205 // Return the expression with the addrec on the outside.
2206 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2207 Depth + 1);
2208 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2209 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2210 }
2211 }
2212 }
2213
// Fall back to an induction-based nsw proof; whatever flags it returns are
// recorded on AR before re-checking for nsw.
2214 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2215 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2216 if (AR->hasNoSignedWrap()) {
2217 // Same as nsw case above - duplicated here to avoid a compile time
2218 // issue. It's not clear that the order of checks does matter, but
2219 // it's one of two possible causes for a change which was
2220 // reverted. Be conservative for the moment.
2221 Start =
2223 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2224 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2225 }
2226
2227 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2228 // if D + (C - D + Step * n) could be proven to not signed wrap
2229 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2230 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2231 const APInt &C = SC->getAPInt();
2232 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2233 if (D != 0) {
2234 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2235 const SCEV *SResidual =
2236 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2237 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2238 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2239 Depth + 1);
2240 }
2241 }
2242
// Last addrec attempt: if proveNoWrapByVaryingStart succeeds, mark AR nsw
// and push the sign extension into its start and step.
2243 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2244 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2245 Start =
2247 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2248 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2249 }
2250 }
2251
2252 // If the input value is provably positive and we could not simplify
2253 // away the sext, build a zext instead.
// NOTE(review): the guarding condition (original line 2254) is missing from
// this listing.
2255 return getZeroExtendExpr(Op, Ty, Depth + 1);
2256
2257 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2258 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2261 SmallVector<SCEVUse, 4> Operands;
2262 for (SCEVUse Operand : MinMax->operands())
2263 Operands.push_back(getSignExtendExpr(Operand, Ty));
2265 return getSMinExpr(Operands);
2266 return getSMaxExpr(Operands);
2267 }
2268
2269 // The cast wasn't folded; create an explicit cast node.
2270 // Recompute the insert position, as it may have been invalidated.
2271 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2272 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2273 Op, Ty);
2274 UniqueSCEVs.InsertNode(S, IP);
2275 S->computeAndSetCanonical(*this);
2276 registerUser(S, Op);
2277 return S;
2278}
2279
2281 Type *Ty) {
// Build the cast of Op to Ty selected by Kind by forwarding to the matching
// builder (truncate, zero-extend, sign-extend or ptrtoint). Any other Kind is
// a programming error and hits llvm_unreachable.
// NOTE(review): original line 2280 (start of the signature) is missing from
// this listing.
2282 switch (Kind) {
2283 case scTruncate:
2284 return getTruncateExpr(Op, Ty);
2285 case scZeroExtend:
2286 return getZeroExtendExpr(Op, Ty);
2287 case scSignExtend:
2288 return getSignExtendExpr(Op, Ty);
2289 case scPtrToInt:
2290 return getPtrToIntExpr(Op, Ty);
2291 default:
2292 llvm_unreachable("Not a SCEV cast expression!");
2293 }
2294}
2295
2296/// Return a SCEV for the given operand extended with
2297/// unspecified bits out to the given type.
///
/// Strategy, in order:
///   - negative constants are sign-extended;
///   - a truncate is peeled off and its operand is any-extended or
///     truncated/no-op'd to Ty;
///   - otherwise a zext, then a sext, is used if it folds to something other
///     than a bare extension node;
///   - an addrec is rebuilt from any-extended operands and marked FlagNW;
///   - an smax uses the sext form;
///   - everything else uses the zext form.
/// NOTE(review): original lines 2298 (start of the signature), 2312 (the
/// guard of the truncate case) and 2331 (declaration of Ops) are missing from
/// this listing.
2299 Type *Ty) {
2300 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2301 "This is not an extending conversion!");
2302 assert(isSCEVable(Ty) &&
2303 "This is not a conversion to a SCEVable type!");
2304 Ty = getEffectiveSCEVType(Ty);
2305
2306 // Sign-extend negative constants.
2307 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2308 if (SC->getAPInt().isNegative())
2309 return getSignExtendExpr(Op, Ty);
2310
2311 // Peel off a truncate cast.
2313 const SCEV *NewOp = T->getOperand();
2314 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2315 return getAnyExtendExpr(NewOp, Ty);
2316 return getTruncateOrNoop(NewOp, Ty);
2317 }
2318
2319 // Next try a zext cast. If the cast is folded, use it.
2320 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2321 if (!isa<SCEVZeroExtendExpr>(ZExt))
2322 return ZExt;
2323
2324 // Next try a sext cast. If the cast is folded, use it.
2325 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2326 if (!isa<SCEVSignExtendExpr>(SExt))
2327 return SExt;
2328
2329 // Force the cast to be folded into the operands of an addrec.
// Only reached when neither the zext nor the sext above folded.
2330 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2332 for (const SCEV *Op : AR->operands())
2333 Ops.push_back(getAnyExtendExpr(Op, Ty));
2334 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2335 }
2336
2337 // If the expression is obviously signed, use the sext cast value.
2338 if (isa<SCEVSMaxExpr>(Op))
2339 return SExt;
2340
2341 // Absent any other information, use the zext cast value.
2342 return ZExt;
2343}
2344
2345/// Process the given Ops list, which is a list of operands to be added under
2346/// the given scale, update the given map. This is a helper function for
2347/// getAddRecExpr. As an example of what it does, given a sequence of operands
2348/// that would form an add expression like this:
2349///
2350/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2351///
2352/// where A and B are constants, update the map with these values:
2353///
2354/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2355///
2356/// and add 13 + A*B*29 to AccumulatedConstant.
2357/// This will allow getAddRecExpr to produce this:
2358///
2359/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2360///
2361/// This form often exposes folding opportunities that are hidden in
2362/// the original operand list.
2363///
2364/// Return true iff it appears that any interesting folding opportunities
2365/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2366/// the common case where no interesting opportunities are present, and
2367/// is also used as a check to avoid infinite recursion.
///
/// Parameters, as used below:
///   - M: accumulates, per operand, the total scale it is added with. For a
///     constant-times-X multiply, the key is the product with the leading
///     constant stripped.
///   - NewOps: receives each key the first time it is inserted into M, so it
///     lists distinct operands in first-seen order.
///   - AccumulatedConstant: receives Scale times each leading constant
///     operand.
/// NOTE(review): the folder that consumes this helper is not visible in this
/// chunk. The add-expression folder (getAddExpr) looks like a more likely
/// caller than getAddRecExpr; confirm against the call site and update the
/// text above if so.
/// NOTE(review): original lines 2368, 2369 and 2371 (most of the parameter
/// list) and 2389 (the declaration of Mul) are missing from this listing.
2370 APInt &AccumulatedConstant,
2372 const APInt &Scale,
2373 ScalarEvolution &SE) {
2374 bool Interesting = false;
2375
2376 // Iterate over the add operands. They are sorted, with constants first.
// NOTE(review): this loop has no bounds check on i; it relies on Ops
// containing at least one non-constant operand.
2377 unsigned i = 0;
2378 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2379 ++i;
2380 // Pull a buried constant out to the outside.
2381 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2382 Interesting = true;
2383 AccumulatedConstant += Scale * C->getAPInt();
2384 }
2385
2386 // Next comes everything else. We're especially interested in multiplies
2387 // here, but they're in the middle, so just visit the rest with one loop.
2388 for (; i != Ops.size(); ++i) {
2390 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2391 APInt NewScale =
2392 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2393 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2394 // A multiplication of a constant with another add; recurse.
2395 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2396 Interesting |= CollectAddOperandsWithScales(
2397 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2398 } else {
2399 // A multiplication of a constant with some other value. Update
2400 // the map.
2401 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2402 const SCEV *Key = SE.getMulExpr(MulOps);
2403 auto Pair = M.insert({Key, NewScale});
2404 if (Pair.second) {
2405 NewOps.push_back(Pair.first->first);
2406 } else {
2407 Pair.first->second += NewScale;
2408 // The map already had an entry for this value, which may indicate
2409 // a folding opportunity.
2410 Interesting = true;
2411 }
2412 }
2413 } else {
2414 // An ordinary operand. Update the map.
2415 auto Pair = M.insert({Ops[i], Scale});
2416 if (Pair.second) {
2417 NewOps.push_back(Pair.first->first);
2418 } else {
2419 Pair.first->second += Scale;
2420 // The map already had an entry for this value, which may indicate
2421 // a folding opportunity.
2422 Interesting = true;
2423 }
2424 }
2425 }
2426
2427 return Interesting;
2428}
2429
2431 const SCEV *LHS, const SCEV *RHS,
2432 const Instruction *CtxI) {
// Return true if LHS BinOp RHS (Add, Sub or Mul only) is known not to
// overflow in the signed or unsigned sense selected by Signed.
//
// First check: compare ext(LHS op RHS) with ext(LHS) op ext(RHS) in a type
// twice as wide; if the two SCEVs are identical, the narrow operation
// cannot have overflowed.
//
// Second check: only for Add/Sub with a constant RHS and a context
// instruction CtxI. It asks isKnownPredicateAt whether LHS stays at least
// |C| away from the bound it moves toward.
// NOTE(review): original lines 2430, 2433, 2439, 2442, 2445, 2450-2451 and
// 2490 are missing from this listing. Extension is presumably chosen from
// Signed (sext vs zext), and Pred presumably is a signed/unsigned
// less-or-equal predicate — confirm.
2434 unsigned);
2435 switch (BinOp) {
2436 default:
2437 llvm_unreachable("Unsupported binary op");
2438 case Instruction::Add:
2440 break;
2441 case Instruction::Sub:
2443 break;
2444 case Instruction::Mul:
2446 break;
2447 }
2448
2449 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2452
2453 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2454 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2455 auto *WideTy =
2456 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2457
2458 const SCEV *A = (this->*Extension)(
2459 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2460 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2461 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2462 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2463 if (A == B)
2464 return true;
2465 // Can we use context to prove the fact we need?
2466 if (!CtxI)
2467 return false;
2468 // TODO: Support mul.
2469 if (BinOp == Instruction::Mul)
2470 return false;
2471 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2472 // TODO: Lift this limitation.
2473 if (!RHSC)
2474 return false;
2475 APInt C = RHSC->getAPInt();
2476 unsigned NumBits = C.getBitWidth();
2477 bool IsSub = (BinOp == Instruction::Sub);
2478 bool IsNegativeConst = (Signed && C.isNegative());
2479 // Compute the direction and magnitude by which we need to check overflow.
// Subtracting a non-negative amount, or (signed) adding a negative constant,
// moves toward the minimum value. Adding a non-negative amount, or (signed)
// subtracting a negative constant, moves toward the maximum value.
2480 bool OverflowDown = IsSub ^ IsNegativeConst;
2481 APInt Magnitude = C;
2482 if (IsNegativeConst) {
2483 if (C == APInt::getSignedMinValue(NumBits))
2484 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2485 // want to deal with that.
2486 return false;
2487 Magnitude = -C;
2488 }
2489
2491 if (OverflowDown) {
2492 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2493 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2494 : APInt::getMinValue(NumBits);
2495 APInt Limit = Min + Magnitude;
2496 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2497 } else {
2498 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2499 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2500 : APInt::getMaxValue(NumBits);
2501 APInt Limit = Max - Magnitude;
2502 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2503 }
2504}
2505
/// Try to prove no-wrap flags for OBO (an Add, Sub or Mul) beyond those the
/// instruction already carries.
///
/// Starts from OBO's existing flags. For each of nuw/nsw that OBO lacks, it
/// asks an overflow query (presumably willNotOverflow, judging by the visible
/// "Signed, LHS, RHS, CtxI" argument tails) about the SCEVs of OBO's two
/// operands. Returns the combined flags only if something new was deduced.
/// Returns std::nullopt when OBO already has both nuw and nsw, when its
/// opcode is not Add/Sub/Mul, or when nothing new could be proven.
/// NOTE(review): original lines 2507, 2516, 2518, 2531, 2533, 2535, 2540 and
/// 2542 are missing from this listing, including the statements that record
/// the flags, so the `if` statements below appear without their bodies.
2506std::optional<SCEV::NoWrapFlags>
2508 const OverflowingBinaryOperator *OBO) {
2509 // Both nuw and nsw are already present; nothing more can be inferred.
2510 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2511 return std::nullopt;
2512
2513 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2514
2515 if (OBO->hasNoUnsignedWrap())
2517 if (OBO->hasNoSignedWrap())
2519
2520 bool Deduced = false;
2521
2522 if (OBO->getOpcode() != Instruction::Add &&
2523 OBO->getOpcode() != Instruction::Sub &&
2524 OBO->getOpcode() != Instruction::Mul)
2525 return std::nullopt;
2526
2527 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2528 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2529
2530 const Instruction *CtxI =
2532 if (!OBO->hasNoUnsignedWrap() &&
2534 /* Signed */ false, LHS, RHS, CtxI)) {
2536 Deduced = true;
2537 }
2538
2539 if (!OBO->hasNoSignedWrap() &&
2541 /* Signed */ true, LHS, RHS, CtxI)) {
2543 Deduced = true;
2544 }
2545
2546 if (Deduced)
2547 return Flags;
2548 return std::nullopt;
2549}
2550
2551// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2552// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2553// can't-overflow flags for the operation if possible.
//
// Returns Flags, possibly strengthened with nuw and/or nsw, using these rules:
//   - nsw plus all-non-negative operands implies nuw;
//   - for a binary add/mul with a constant first operand, the other
//     operand's signed/unsigned range is checked against the guaranteed
//     no-wrap region;
//   - <0,+,nonnegative><nw> addrecs;
//   - (udiv X, Y) * Y products.
// NOTE(review): original lines 2554-2556 (the signature), 2563 (the
// CanAnalyze condition) and the statements that set the new flags (e.g.
// 2600, 2603, 2608, 2611, 2617, 2620, 2623, 2627, 2630) are missing from this
// listing.
2557 SCEV::NoWrapFlags Flags) {
2558 using namespace std::placeholders;
2559
2560 using OBO = OverflowingBinaryOperator;
2561
2562 bool CanAnalyze =
2564 (void)CanAnalyze;
2565 assert(CanAnalyze && "don't call from other places!");
2566
2567 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2568 SCEV::NoWrapFlags SignOrUnsignWrap =
2569 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2570
2571 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2572 auto IsKnownNonNegative = [&](SCEVUse U) {
2573 return SE->isKnownNonNegative(U);
2574 };
2575
2576 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2577 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2578
2579 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2580
2581 if (SignOrUnsignWrap != SignOrUnsignMask &&
2582 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2583 isa<SCEVConstant>(Ops[0])) {
2584
2585 auto Opcode = [&] {
2586 switch (Type) {
2587 case scAddExpr:
2588 return Instruction::Add;
2589 case scMulExpr:
2590 return Instruction::Mul;
2591 default:
2592 llvm_unreachable("Unexpected SCEV op.");
2593 }
2594 }();
2595
2596 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2597
2598 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't overflow as signed.
2599 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2601 Opcode, C, OBO::NoSignedWrap);
2602 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2604 }
2605
2606 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't overflow as unsigned.
2607 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2609 Opcode, C, OBO::NoUnsignedWrap);
2610 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2612 }
2613 }
2614
2615 // <0,+,nonnegative><nw> is also nuw
2616 // TODO: Add corresponding nsw case
2618 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2619 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2621
2622 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
// (X /u Y) * Y never exceeds X, so the unsigned product cannot wrap.
2624 Ops.size() == 2) {
2625 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2626 if (UDiv->getOperand(1) == Ops[1])
2628 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2629 if (UDiv->getOperand(1) == Ops[0])
2631 }
2632
2633 return Flags;
2634}
2635
// S counts as available on entry to L when it is invariant in L and properly
// dominates L's header.
// NOTE(review): original line 2636 (the signature) is missing from this
// listing.
2637 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2638}
2639
2640/// Get a canonical add expression, or something simpler if possible.
2642 SCEV::NoWrapFlags OrigFlags,
2643 unsigned Depth) {
2644 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2645 "only nuw or nsw allowed");
2646 assert(!Ops.empty() && "Cannot get empty add!");
2647 if (Ops.size() == 1) return Ops[0];
2648#ifndef NDEBUG
2649 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2650 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2651 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2652 "SCEVAddExpr operand types don't match!");
2653 unsigned NumPtrs = count_if(
2654 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2655 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2656#endif
2657
2658 const SCEV *Folded = constantFoldAndGroupOps(
2659 *this, LI, DT, Ops,
2660 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2661 [](const APInt &C) { return C.isZero(); }, // identity
2662 [](const APInt &C) { return false; }); // absorber
2663 if (Folded)
2664 return Folded;
2665
2666 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2667
2668 // Delay expensive flag strengthening until necessary.
2669 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2670 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2671 };
2672
2673 // Limit recursion calls depth.
2675 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2676
2677 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2678 // Don't strengthen flags if we have no new information.
2679 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2680 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2681 Add->setNoWrapFlags(ComputeFlags(Ops));
2682 return S;
2683 }
2684
2685 // Okay, check to see if the same value occurs in the operand list more than
2686 // once. If so, merge them together into an multiply expression. Since we
2687 // sorted the list, these values are required to be adjacent.
2688 Type *Ty = Ops[0]->getType();
2689 bool FoundMatch = false;
2690 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2691 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2692 // Scan ahead to count how many equal operands there are.
2693 unsigned Count = 2;
2694 while (i+Count != e && Ops[i+Count] == Ops[i])
2695 ++Count;
2696 // Merge the values into a multiply.
2697 SCEVUse Scale = getConstant(Ty, Count);
2698 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2699 if (Ops.size() == Count)
2700 return Mul;
2701 Ops[i] = Mul;
2702 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2703 --i; e -= Count - 1;
2704 FoundMatch = true;
2705 }
2706 if (FoundMatch)
2707 return getAddExpr(Ops, OrigFlags, Depth + 1);
2708
2709 // Check for truncates. If all the operands are truncated from the same
2710 // type, see if factoring out the truncate would permit the result to be
2711 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2712 // if the contents of the resulting outer trunc fold to something simple.
2713 auto FindTruncSrcType = [&]() -> Type * {
2714 // We're ultimately looking to fold an addrec of truncs and muls of only
2715 // constants and truncs, so if we find any other types of SCEV
2716 // as operands of the addrec then we bail and return nullptr here.
2717 // Otherwise, we return the type of the operand of a trunc that we find.
2718 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2719 return T->getOperand()->getType();
2720 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2721 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2722 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2723 return T->getOperand()->getType();
2724 }
2725 return nullptr;
2726 };
2727 if (auto *SrcType = FindTruncSrcType()) {
2728 SmallVector<SCEVUse, 8> LargeOps;
2729 bool Ok = true;
2730 // Check all the operands to see if they can be represented in the
2731 // source type of the truncate.
2732 for (const SCEV *Op : Ops) {
2734 if (T->getOperand()->getType() != SrcType) {
2735 Ok = false;
2736 break;
2737 }
2738 LargeOps.push_back(T->getOperand());
2739 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2740 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2741 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2742 SmallVector<SCEVUse, 8> LargeMulOps;
2743 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2744 if (const SCEVTruncateExpr *T =
2745 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2746 if (T->getOperand()->getType() != SrcType) {
2747 Ok = false;
2748 break;
2749 }
2750 LargeMulOps.push_back(T->getOperand());
2751 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2752 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2753 } else {
2754 Ok = false;
2755 break;
2756 }
2757 }
2758 if (Ok)
2759 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2760 } else {
2761 Ok = false;
2762 break;
2763 }
2764 }
2765 if (Ok) {
2766 // Evaluate the expression in the larger type.
2767 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2768 // If it folds to something simple, use it. Otherwise, don't.
2769 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2770 return getTruncateExpr(Fold, Ty);
2771 }
2772 }
2773
2774 if (Ops.size() == 2) {
2775 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2776 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2777 // C1).
2778 const SCEV *A = Ops[0];
2779 const SCEV *B = Ops[1];
2780 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2781 auto *C = dyn_cast<SCEVConstant>(A);
2782 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2783 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2784 auto C2 = C->getAPInt();
2785 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2786
2787 APInt ConstAdd = C1 + C2;
2788 auto AddFlags = AddExpr->getNoWrapFlags();
2789 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2791 ConstAdd.ule(C1)) {
2792 PreservedFlags =
2794 }
2795
2796 // Adding a constant with the same sign and small magnitude is NSW, if the
2797 // original AddExpr was NSW.
2799 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2800 ConstAdd.abs().ule(C1.abs())) {
2801 PreservedFlags =
2803 }
2804
2805 if (PreservedFlags != SCEV::FlagAnyWrap) {
2806 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2807 NewOps[0] = getConstant(ConstAdd);
2808 return getAddExpr(NewOps, PreservedFlags);
2809 }
2810 }
2811
2812 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2813 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2814 const SCEVAddExpr *InnerAdd;
2815 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2816 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2817 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2818 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2819 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2821 SCEV::FlagNUW)) {
2822 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2823 }
2824 }
2825 }
2826
2827 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2828 const SCEV *Y;
2829 if (Ops.size() == 2 &&
2830 match(Ops[0],
2832 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2833 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2834
2835 // Skip past any other cast SCEVs.
2836 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2837 ++Idx;
2838
2839 // If there are add operands they would be next.
2840 if (Idx < Ops.size()) {
2841 bool DeletedAdd = false;
2842 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2843 // common NUW flag for expression after inlining. Other flags cannot be
2844 // preserved, because they may depend on the original order of operations.
2845 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2846 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2847 if (Ops.size() > AddOpsInlineThreshold ||
2848 Add->getNumOperands() > AddOpsInlineThreshold)
2849 break;
2850 // If we have an add, expand the add operands onto the end of the operands
2851 // list.
2852 Ops.erase(Ops.begin()+Idx);
2853 append_range(Ops, Add->operands());
2854 DeletedAdd = true;
2855 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2856 }
2857
2858 // If we deleted at least one add, we added operands to the end of the list,
2859 // and they are not necessarily sorted. Recurse to resort and resimplify
2860 // any operands we just acquired.
2861 if (DeletedAdd)
2862 return getAddExpr(Ops, CommonFlags, Depth + 1);
2863 }
2864
2865 // Skip over the add expression until we get to a multiply.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2867 ++Idx;
2868
2869 // Check to see if there are any folding opportunities present with
2870 // operands multiplied by constant values.
2871 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2875 APInt AccumulatedConstant(BitWidth, 0);
2876 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2877 Ops, APInt(BitWidth, 1), *this)) {
2878 struct APIntCompare {
2879 bool operator()(const APInt &LHS, const APInt &RHS) const {
2880 return LHS.ult(RHS);
2881 }
2882 };
2883
2884 // Some interesting folding opportunity is present, so its worthwhile to
2885 // re-generate the operands list. Group the operands by constant scale,
2886 // to avoid multiplying by the same constant scale multiple times.
2887 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2888 for (const SCEV *NewOp : NewOps)
2889 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2890 // Re-generate the operands list.
2891 Ops.clear();
2892 if (AccumulatedConstant != 0)
2893 Ops.push_back(getConstant(AccumulatedConstant));
2894 for (auto &MulOp : MulOpLists) {
2895 if (MulOp.first == 1) {
2896 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2897 } else if (MulOp.first != 0) {
2898 Ops.push_back(getMulExpr(
2899 getConstant(MulOp.first),
2900 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2901 SCEV::FlagAnyWrap, Depth + 1));
2902 }
2903 }
2904 if (Ops.empty())
2905 return getZero(Ty);
2906 if (Ops.size() == 1)
2907 return Ops[0];
2908 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2909 }
2910 }
2911
2912 // If we are adding something to a multiply expression, make sure the
2913 // something is not already an operand of the multiply. If so, merge it into
2914 // the multiply.
2915 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2916 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2917 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2918 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2919 if (isa<SCEVConstant>(MulOpSCEV))
2920 continue;
2921 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2922 if (MulOpSCEV == Ops[AddOp]) {
2923 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2924 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2925 if (Mul->getNumOperands() != 2) {
2926 // If the multiply has more than two operands, we must get the
2927 // Y*Z term.
2928 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2929 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2930 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2931 }
2932 const SCEV *AddOne =
2933 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2934 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2936 if (Ops.size() == 2) return OuterMul;
2937 if (AddOp < Idx) {
2938 Ops.erase(Ops.begin()+AddOp);
2939 Ops.erase(Ops.begin()+Idx-1);
2940 } else {
2941 Ops.erase(Ops.begin()+Idx);
2942 Ops.erase(Ops.begin()+AddOp-1);
2943 }
2944 Ops.push_back(OuterMul);
2945 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2946 }
2947
2948 // Check this multiply against other multiplies being added together.
2949 for (unsigned OtherMulIdx = Idx+1;
2950 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2951 ++OtherMulIdx) {
2952 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2953 // If MulOp occurs in OtherMul, we can fold the two multiplies
2954 // together.
2955 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2956 OMulOp != e; ++OMulOp)
2957 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2958 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2959 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2960 if (Mul->getNumOperands() != 2) {
2961 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2962 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2963 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2966 if (OtherMul->getNumOperands() != 2) {
2968 OtherMul->operands().take_front(OMulOp));
2969 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2970 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 const SCEV *InnerMulSum =
2973 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2974 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2976 if (Ops.size() == 2) return OuterMul;
2977 Ops.erase(Ops.begin()+Idx);
2978 Ops.erase(Ops.begin()+OtherMulIdx-1);
2979 Ops.push_back(OuterMul);
2980 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2981 }
2982 }
2983 }
2984 }
2985
2986 // If there are any add recurrences in the operands list, see if any other
2987 // added values are loop invariant. If so, we can fold them into the
2988 // recurrence.
2989 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2990 ++Idx;
2991
2992 // Scan over all recurrences, trying to fold loop invariants into them.
2993 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2994 // Scan all of the other operands to this add and add them to the vector if
2995 // they are loop invariant w.r.t. the recurrence.
2997 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2998 const Loop *AddRecLoop = AddRec->getLoop();
2999 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3000 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3001 LIOps.push_back(Ops[i]);
3002 Ops.erase(Ops.begin()+i);
3003 --i; --e;
3004 }
3005
3006 // If we found some loop invariants, fold them into the recurrence.
3007 if (!LIOps.empty()) {
3008 // Compute nowrap flags for the addition of the loop-invariant ops and
3009 // the addrec. Temporarily push it as an operand for that purpose. These
3010 // flags are valid in the scope of the addrec only.
3011 LIOps.push_back(AddRec);
3012 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3013 LIOps.pop_back();
3014
3015 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3016 LIOps.push_back(AddRec->getStart());
3017
3018 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3019
3020 // It is not in general safe to propagate flags valid on an add within
3021 // the addrec scope to one outside it. We must prove that the inner
3022 // scope is guaranteed to execute if the outer one does to be able to
3023 // safely propagate. We know the program is undefined if poison is
3024 // produced on the inner scoped addrec. We also know that *for this use*
3025 // the outer scoped add can't overflow (because of the flags we just
3026 // computed for the inner scoped add) without the program being undefined.
3027 // Proving that entry to the outer scope neccesitates entry to the inner
3028 // scope, thus proves the program undefined if the flags would be violated
3029 // in the outer scope.
3030 SCEV::NoWrapFlags AddFlags = Flags;
3031 if (AddFlags != SCEV::FlagAnyWrap) {
3032 auto *DefI = getDefiningScopeBound(LIOps);
3033 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3034 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3035 AddFlags = SCEV::FlagAnyWrap;
3036 }
3037 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3038
3039 // Build the new addrec. Propagate the NUW and NSW flags if both the
3040 // outer add and the inner addrec are guaranteed to have no overflow.
3041 // Always propagate NW.
3042 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3043 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3044
3045 // If all of the other operands were loop invariant, we are done.
3046 if (Ops.size() == 1) return NewRec;
3047
3048 // Otherwise, add the folded AddRec by the non-invariant parts.
3049 for (unsigned i = 0;; ++i)
3050 if (Ops[i] == AddRec) {
3051 Ops[i] = NewRec;
3052 break;
3053 }
3054 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3055 }
3056
3057 // Okay, if there weren't any loop invariants to be folded, check to see if
3058 // there are multiple AddRec's with the same loop induction variable being
3059 // added together. If so, we can fold them.
3060 for (unsigned OtherIdx = Idx+1;
3061 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3062 ++OtherIdx) {
3063 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3064 // so that the 1st found AddRecExpr is dominated by all others.
3065 assert(DT.dominates(
3066 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3067 AddRec->getLoop()->getHeader()) &&
3068 "AddRecExprs are not sorted in reverse dominance order?");
3069 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3070 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3071 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3072 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3073 ++OtherIdx) {
3074 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3075 if (OtherAddRec->getLoop() == AddRecLoop) {
3076 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3077 i != e; ++i) {
3078 if (i >= AddRecOps.size()) {
3079 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3080 break;
3081 }
3082 AddRecOps[i] =
3083 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3085 }
3086 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3087 }
3088 }
3089 // Step size has changed, so we cannot guarantee no self-wraparound.
3090 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3091 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3092 }
3093 }
3094
3095 // Otherwise couldn't fold anything into this recurrence. Move onto the
3096 // next one.
3097 }
3098
3099 // Okay, it looks like we really DO need an add expr. Check to see if we
3100 // already have one, otherwise create a new one.
3101 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3102}
3103
/// Return the uniqued SCEVAddExpr whose operand list is exactly \p Ops,
/// creating it if it does not exist yet. No folding or canonicalization of
/// \p Ops is attempted here. \p Flags is applied to the resulting node whether
/// it was found or newly created.
const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // The node is keyed on its kind plus the pointer identity of each operand,
  // in order.
  ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Not found. Allocate operand storage and the node from SCEVAllocator,
    // then insert the node at the position IP reported by the failed lookup.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator)
        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Presumably records S as a user of each operand (see registerUser).
    registerUser(S, Ops);
  }
  // Applied to both pre-existing and freshly created nodes.
  S->setNoWrapFlags(Flags);
  return S;
}
3125
/// Return the uniqued SCEVAddRecExpr over loop \p L whose operand list is
/// exactly \p Ops, creating it if it does not exist yet. No folding is
/// attempted here. \p Flags is applied through ScalarEvolution::setNoWrapFlags
/// whether the node was found or newly created.
const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
                                                   const Loop *L,
                                                   SCEV::NoWrapFlags Flags) {
  // Key on the kind, each operand's pointer identity, and the loop. The same
  // operand list over two different loops yields two distinct nodes.
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Not found. Allocate operand storage and the node from SCEVAllocator,
    // then insert the node at the position IP reported by the failed lookup.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator)
        SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Unlike add/mul nodes, recurrences are also recorded per loop in
    // LoopUsers.
    LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  // This calls the ScalarEvolution member rather than S->setNoWrapFlags, as
  // the add/mul variants do. Presumably this lets cached state derived from S
  // be updated when its flags change -- confirm.
  setNoWrapFlags(S, Flags);
  return S;
}
3150
/// Return the uniqued SCEVMulExpr whose operand list is exactly \p Ops,
/// creating it if it does not exist yet. No folding or canonicalization of
/// \p Ops is attempted here. \p Flags is applied to the resulting node whether
/// it was found or newly created.
const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // The node is keyed on its kind plus the pointer identity of each operand,
  // in order.
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Not found. Allocate operand storage and the node from SCEVAllocator,
    // then insert the node at the position IP reported by the failed lookup.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Presumably records S as a user of each operand (see registerUser).
    registerUser(S, Ops);
  }
  // Applied to both pre-existing and freshly created nodes.
  S->setNoWrapFlags(Flags);
  return S;
}
3172
/// Multiply two unsigned 64-bit values with wrap-around semantics.
///
/// \p Overflow is set to true when the mathematical product of \p i and \p j
/// does not fit in 64 bits. It is never reset to false, so one flag can
/// accumulate overflow across a sequence of calls.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  const uint64_t Product = i * j;
  // A factor of 0 or 1 can never overflow. Otherwise, dividing the wrapped
  // product back by j recovers i exactly only if no high bits were lost.
  const bool LostBits = j > 1 && Product / j != i;
  if (LostBits)
    Overflow = true;
  return Product;
}
3178
3179/// Compute the result of "n choose k", the binomial coefficient. If an
3180/// intermediate computation overflows, Overflow will be set and the return will
3181/// be garbage. Overflow is not cleared on absence of overflow.
3182static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3183 // We use the multiplicative formula:
3184 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3185 // At each iteration, we take the n-th term of the numeral and divide by the
3186 // (k-n)th term of the denominator. This division will always produce an
3187 // integral result, and helps reduce the chance of overflow in the
3188 // intermediate computations. However, we can still overflow even when the
3189 // final result would fit.
3190
3191 if (n == 0 || n == k) return 1;
3192 if (k > n) return 0;
3193
3194 if (k > n/2)
3195 k = n-k;
3196
3197 uint64_t r = 1;
3198 for (uint64_t i = 1; i <= k; ++i) {
3199 r = umul_ov(r, n-(i-1), Overflow);
3200 r /= i;
3201 }
3202 return r;
3203}
3204
/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
///
/// The search only continues through add and multiply nodes. This assumes the
/// traversal descends into a node's operands only when follow() returns true,
/// so a constant nested under any other kind of expression is not reported.
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
  // Traversal visitor that remembers whether a constant has been seen.
  struct FindConstantInAddMulChain {
    bool FoundConstant = false;

    // Record whether S is a constant. Only add/mul nodes ask for their
    // operands to be visited.
    bool follow(const SCEV *S) {
      FoundConstant |= isa<SCEVConstant>(S);
      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
    }

    // Stop as soon as one constant has been found.
    bool isDone() const {
      return FoundConstant;
    }
  };

  FindConstantInAddMulChain F;
  ST.visitAll(StartExpr);
  return F.FoundConstant;
}
3226
/// Get a canonical multiply expression, or something simpler if possible.
///
/// \p OrigFlags may contain only NUW and/or NSW (asserted below). Stronger
/// flags are computed lazily through StrengthenNoWrapFlags, only on the paths
/// that build or update a node. Recursive folding calls generally pass
/// \p Depth + 1. The operand list \p Ops may be rewritten in place: operands
/// are erased, appended and replaced while folding.
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  // All operands must share a single, non-pointer type.
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  // Let the shared helper fold constant operands. The lambdas below define
  // how constants combine (multiplication), which constant is the identity (1)
  // and which is absorbing (0). A non-null result is returned as-is.
  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
      [](const APInt &C) { return C.isOne(); }, // identity
      [](const APInt &C) { return C.isZero(); }); // absorber
  if (Folded)
    return Folded;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion calls depth.
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  // Reuse an identical mul node if one already exists. getNoWrapFlags(Mask)
  // restricts the node's flags to Mask, so flags are recomputed only when the
  // cached node lacks some flag requested in OrigFlags.
  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  // Folds that apply when the first operand is a constant.
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      // If any of Add's ops are Adds or Muls with a constant, apply this
      // transformation as well.
      //
      // TODO: There are some cases where this transformation is not
      // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
      // this transformation should be narrowed down.
      const SCEV *Op0, *Op1;
      if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
        const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
        const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
        return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
      }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          // Only rewrite if at least one negated operand simplified to
          // something other than a mul. Otherwise we would just rebuild an
          // equivalent expression.
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<SCEVUse, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
                                          SCEV::FlagAnyWrap, Depth + 1));
          // Let M be the minimum representable signed value. AddRec with nsw
          // multiplied by -1 can have signed overflow if and only if it takes a
          // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
          // maximum signed value. In all other cases signed overflow is
          // impossible.
          auto FlagsMask = SCEV::FlagNW;
          if (AddRec->hasNoSignedWrap()) {
            auto MinInt =
                APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
            if (getSignedRangeMin(AddRec) != MinInt)
              FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(FlagsMask));
        }
      }

      // Try to push the constant operand into a ZExt: C * zext (A + B) ->
      // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
      const SCEVAddExpr *InnerAdd;
      if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
        // NarrowC must zero-extend back to exactly C. Otherwise truncating C
        // would have dropped set bits.
        const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
        if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
            getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
            hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
                     SCEV::FlagNUW)) {
          auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
          return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
        };
      }

      // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
      // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
      // of C1, fold to (D /u (C2 /u C1)).
      const SCEV *D;
      APInt C1V = LHSC->getAPInt();
      // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
      // as -1 * 1, as it won't enable additional folds.
      if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
        C1V = C1V.abs();
      const SCEVConstant *C2;
      if (C1V.isPowerOf2() &&
          C2->getAPInt().isPowerOf2() &&
          C1V.logBase2() <= getMinTrailingZeros(D)) {
        const SCEV *NewMul = nullptr;
        if (C1V.uge(C2->getAPInt())) {
          NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
        } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
          assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
          NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
        }
        // If C1 was replaced by its absolute value above, re-apply the sign.
        if (NewMul)
          return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
      }
    }
  }

  // Skip over the add expression until we get to a multiply.
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  // Flattening stops once Ops grows beyond MulOpsInlineThreshold operands.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have an mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Mul->operands());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    // Ops is compacted in place. After each erase, i and e are stepped back so
    // the element that shifted into slot i is examined next.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);

      // If both the mul and addrec are nuw, we can preserve nuw.
      // If both the mul and addrec are nsw, we can only preserve nsw if either
      // a) they are also nuw, or
      // b) all multiplications of addrec operands with scale are nsw.
      SCEV::NoWrapFlags Flags =
          AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));

      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
                                    SCEV::FlagAnyWrap, Depth + 1));

        if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
              Instruction::Mul, getSignedRange(Scale),
          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
            Flags = clearFlags(Flags, SCEV::FlagNSW);
        }
      }

      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      // AddRec itself is never loop invariant w.r.t. its own loop, so it is
      // still present in Ops and this search terminates.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
    //   ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    // The loops below index operands from 0: x ranges over
    // [0, NumOps(AddRec) + NumOps(OtherAddRec) - 1).
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
        dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
        continue;

      // Limit max number of arguments to avoid creation of unreasonably big
      // SCEVAddRecs with very complex operands.
      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
          MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
        continue;

      // Any overflow while computing coefficients abandons this pairing:
      // AddRecOps is only used when Overflow stays false.
      bool Overflow = false;
      Type *Ty = AddRec->getType();
      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
      SmallVector<SCEVUse, 7> AddRecOps;
      for (int x = 0, xe = AddRec->getNumOperands() +
             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
                   ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
               z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            // For types of at most 64 bits, a wrapped 64-bit product is still
            // correct modulo the type width (assuming getConstant(Ty, Coeff)
            // keeps only Ty's low bits), so only wider types need the check.
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
            const SCEV *CoeffTerm = getConstant(Ty, Coeff);
            const SCEV *Term1 = AddRec->getOperand(y-z);
            const SCEV *Term2 = OtherAddRec->getOperand(z);
            SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                        SCEV::FlagAnyWrap, Depth + 1));
          }
        }
        if (SumOps.empty())
          SumOps.push_back(getZero(Ty));
        AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
      }
      if (!Overflow) {
        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
        if (Ops.size() == 2) return NewAddRec;
        Ops[Idx] = NewAddRec;
        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        OpsModified = true;
        // The merged result may have folded to something other than a
        // recurrence. In that case no further pairing is possible.
        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
        if (!AddRec)
          break;
      }
    }
    if (OpsModified)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an mul expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
3530
/// Represents an unsigned remainder expression based on unsigned division.
///
/// Constant divisors equal to 1 or to a power of two are folded directly.
/// Any other divisor is expanded as LHS - (LHS /u RHS) * RHS.
  assert(getEffectiveSCEVType(LHS->getType()) ==
         getEffectiveSCEVType(RHS->getType()) &&
         "SCEVURemExpr operand types don't match!");

  // Short-circuit easy cases
  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    // If constant is one, the result is trivial.
    // This check must come before the power-of-two case: 1 is a power of two,
    // and logBase2(1) == 0 would request a zero-bit truncation type.
    if (RHSC->getValue()->isOne())
      return getZero(LHS->getType()); // X urem 1 --> 0

    // If constant is a power of two, fold into a zext(trunc(LHS)).
    // X urem 2^k keeps exactly the low k bits of X.
    if (RHSC->getAPInt().isPowerOf2()) {
      Type *FullTy = LHS->getType();
      Type *TruncTy =
          IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
    }
  }

  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
  // Both NUW flags hold because (%x udiv %y) * %y never exceeds %x.
  const SCEV *UDiv = getUDivExpr(LHS, RHS);
  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
}
3557
3558/// Get a canonical unsigned division expression, or something simpler if
3559/// possible.
///
/// LHS and RHS must have the same type, and LHS must not be a pointer (both
/// are asserted below). The result is uniqued in UniqueSCEVs: if a
/// SCEVUDivExpr for these operands already exists, it is returned directly.
3561 assert(!LHS->getType()->isPointerTy() &&
3562 "SCEVUDivExpr operand can't be pointer!");
3563 assert(LHS->getType() == RHS->getType() &&
3564 "SCEVUDivExpr operand types don't match!");
3565
3567 ID.AddInteger(scUDivExpr);
3568 ID.AddPointer(LHS);
3569 ID.AddPointer(RHS);
3570 void *IP = nullptr;
3571 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3572 return S;
3573
3574 // 0 udiv Y == 0
3575 if (match(LHS, m_scev_Zero()))
3576 return LHS;
3577
3578 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3579 if (RHSC->getValue()->isOne())
3580 return LHS; // X udiv 1 --> x
3581 // If the denominator is zero, the result of the udiv is undefined. Don't
3582 // try to analyze it, because the resolution chosen here may differ from
3583 // the resolution chosen in other parts of the compiler.
3584 if (!RHSC->getValue()->isZero()) {
3585 // Determine if the division can be distributed into the operands of
3586 // LHS.
3587 // TODO: Generalize this to non-constants by using known-bits information.
3588 Type *Ty = LHS->getType();
3589 unsigned LZ = RHSC->getAPInt().countl_zero();
3590 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3591 // For non-power-of-two values, effectively round the value up to the
3592 // nearest power of two.
3593 if (!RHSC->getAPInt().isPowerOf2())
3594 ++MaxShiftAmt;
// ExtTy is Ty widened by MaxShiftAmt bits. The folds below use it to check
// that LHS does not wrap: zext(LHS) to ExtTy must equal the same expression
// rebuilt from zero-extended operands.
3595 IntegerType *ExtTy =
3596 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3597 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3598 if (const SCEVConstant *Step =
3599 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3600 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3601 const APInt &StepInt = Step->getAPInt();
3602 const APInt &DivInt = RHSC->getAPInt();
3603 if (!StepInt.urem(DivInt) &&
3604 getZeroExtendExpr(AR, ExtTy) ==
3605 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3606 getZeroExtendExpr(Step, ExtTy),
3607 AR->getLoop(), SCEV::FlagAnyWrap)) {
3608 SmallVector<SCEVUse, 4> Operands;
3609 for (const SCEV *Op : AR->operands())
3610 Operands.push_back(getUDivExpr(Op, RHS));
3611 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3612 }
3613 // Get a canonical UDivExpr for a recurrence.
3614 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3615 const APInt *StartRem;
3616 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3617 m_scev_APInt(StartRem))) {
3618 bool NoWrap =
3619 getZeroExtendExpr(AR, ExtTy) ==
3620 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3621 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3623
3624 // With N <= C and both N, C as powers-of-2, the transformation
3625 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3626 // if wrapping occurs, as the division results remain equivalent for
3627 // all offsets in [(X - X%N), X).
3628 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3629 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3630 // Only fold if the subtraction can be folded in the start
3631 // expression.
3632 const SCEV *NewStart =
3633 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3634 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3635 !isa<SCEVAddExpr>(NewStart)) {
3636 const SCEV *NewLHS =
3637 getAddRecExpr(NewStart, Step, AR->getLoop(),
3638 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3639 if (LHS != NewLHS) {
3640 LHS = NewLHS;
3641
3642 // Reset the ID to include the new LHS, and check if it is
3643 // already cached.
3644 ID.clear();
3645 ID.AddInteger(scUDivExpr);
3646 ID.AddPointer(LHS);
3647 ID.AddPointer(RHS);
3648 IP = nullptr;
3649 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3650 return S;
3651 }
3652 }
3653 }
3654 }
3655 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3656 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3657 SmallVector<SCEVUse, 4> Operands;
3658 for (const SCEV *Op : M->operands())
3659 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3660 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) {
3661 // Find an operand that's safely divisible.
3662 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3663 const SCEV *Op = M->getOperand(i);
3664 const SCEV *Div = getUDivExpr(Op, RHSC);
3665 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3666 Operands = SmallVector<SCEVUse, 4>(M->operands());
3667 Operands[i] = Div;
3668 return getMulExpr(Operands);
3669 }
3670 }
3671
3672 // Even if it's not divisible, try to remove a common factor.
3673 if (const auto *LHSC = dyn_cast<SCEVConstant>(M->getOperand(0))) {
3674 APInt Factor = APIntOps::GreatestCommonDivisor(LHSC->getAPInt(),
3675 RHSC->getAPInt());
3676 if (!Factor.isIntN(1)) {
3677 SmallVector<SCEVUse, 2> NewOperands;
3678 NewOperands.push_back(getConstant(LHSC->getAPInt().udiv(Factor)));
3679 append_range(NewOperands, M->operands().drop_front());
3680 const SCEV *NewMul = getMulExpr(NewOperands);
3681 return getUDivExpr(NewMul,
3682 getConstant(RHSC->getAPInt().udiv(Factor)));
3683 }
3684 }
3685 }
3686 }
3687
3688 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3689 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3690 if (auto *DivisorConstant =
3691 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3692 bool Overflow = false;
3693 APInt NewRHS =
3694 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3695 if (Overflow) {
// B*C does not fit in the type, so it exceeds every value A can take
// and (A/B)/C is always 0.
3696 return getConstant(RHSC->getType(), 0, false);
3697 }
3698 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3699 }
3700 }
3701
3702 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3703 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3704 SmallVector<SCEVUse, 4> Operands;
3705 for (const SCEV *Op : A->operands())
3706 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3707 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3708 Operands.clear();
3709 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3710 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3711 if (isa<SCEVUDivExpr>(Op) ||
3712 getMulExpr(Op, RHS) != A->getOperand(i))
3713 break;
3714 Operands.push_back(Op);
3715 }
3716 if (Operands.size() == A->getNumOperands())
3717 return getAddExpr(Operands);
3718 }
3719 }
3720
3721 // ((N - M) + (M * A)) / N --> ((N - 1) + (M * A)) / N
3722 // This is an idiom for rounding A up to the next multiple of N, where A
3723 // is already known to be a multiple of M. In this case, instcombine can
3724 // see that some low bits of the added constant are unused, so can clear
3725 // them, but we want to canonicalise to set the low bits. This makes the
3726 // pattern easier to match, without needing to check for known bits in
3727 // A*M.
3728 const APInt &N = RHSC->getAPInt();
3729 const APInt *NMinusM, *M;
3730 const SCEV *A;
3731 if (match(LHS, m_scev_Add(m_scev_APInt(NMinusM),
3732 m_scev_Mul(m_scev_APInt(M), m_SCEV(A))))) {
3733 if (N.isPowerOf2() && M->isPowerOf2() && M->ult(N) &&
3734 *NMinusM == N - *M) {
3735 return getUDivExpr(
3737 RHS);
3738 }
3739 }
3740
3741 // Fold if both operands are constant.
3742 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3743 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3744 }
3745 }
3746
3747 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3748 const APInt *NegC, *C;
3749 if (match(LHS,
3752 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3753 return getZero(LHS->getType());
3754
3755 // (%a * %b)<nuw> / %b -> %a
// The nuw flag guarantees the product did not wrap, so removing the factor
// that equals RHS gives the exact quotient.
3756 const auto *Mul = dyn_cast<SCEVMulExpr>(LHS);
3757 if (Mul && Mul->hasNoUnsignedWrap()) {
3758 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3759 if (Mul->getOperand(i) == RHS) {
3760 SmallVector<SCEVUse, 2> Operands;
3761 append_range(Operands, Mul->operands().take_front(i));
3762 append_range(Operands, Mul->operands().drop_front(i + 1));
3763 return getMulExpr(Operands);
3764 }
3765 }
3766 }
3767
3768 // TODO: Generalize to handle any common factors.
3769 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3770 const SCEV *NewLHS, *NewRHS;
3771 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3772 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3773 return getUDivExpr(NewLHS, NewRHS);
3774
3775 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3776 // changes). Make sure we get a new one.
3777 IP = nullptr;
3778 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3779 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3780 LHS, RHS);
3781 UniqueSCEVs.InsertNode(S, IP);
3782 S->computeAndSetCanonical(*this);
3783 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3784 return S;
3785}
3786
3787APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3788 APInt A = C1->getAPInt().abs();
3789 APInt B = C2->getAPInt().abs();
3790 uint32_t ABW = A.getBitWidth();
3791 uint32_t BBW = B.getBitWidth();
3792
3793 if (ABW > BBW)
3794 B = B.zext(ABW);
3795 else if (ABW < BBW)
3796 A = A.zext(BBW);
3797
3798 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3799}
3800
3801/// Get a canonical unsigned division expression, or something simpler if
3802/// possible. There is no representation for an exact udiv in SCEV IR, but we
3803/// can attempt to optimize it prior to construction.
3805 // Currently there is no exact-specific logic, so this is equivalent to
// getUDivExpr(LHS, RHS).
3806
3807 return getUDivExpr(LHS, RHS);
3808}
3809
3810/// Get an add recurrence expression for the specified loop. Simplify the
3811/// expression as much as possible.
///
/// Builds {Start,+,Step} for loop L with the requested no-wrap Flags.
3813 const Loop *L,
3814 SCEV::NoWrapFlags Flags) {
3815 SmallVector<SCEVUse, 4> Operands;
3816 Operands.push_back(Start);
// A step that is itself a recurrence in the same loop is flattened:
// {Start,+,{A,+,B}} becomes {Start,+,A,+,B}. Only the NW flag is kept in
// that case.
3817 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3818 if (StepChrec->getLoop() == L) {
3819 append_range(Operands, StepChrec->operands());
3820 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3821 }
3822
3823 Operands.push_back(Step);
3824 return getAddRecExpr(Operands, L, Flags);
3825}
3826
3827/// Get an add recurrence expression for the specified loop. Simplify the
3828/// expression as much as possible.
///
/// Operands are {Start, Step, ...} of the recurrence in loop L. A single
/// operand is returned unchanged, a trailing zero step is dropped, and nested
/// recurrences are reordered by loop depth before the node is uniqued.
3830 const Loop *L,
3831 SCEV::NoWrapFlags Flags) {
3832 if (Operands.size() == 1) return Operands[0];
3833#ifndef NDEBUG
3834 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3835 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3836 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3837 "SCEVAddRecExpr operand types don't match!");
3838 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3839 }
3840 for (const SCEV *Op : Operands)
3842 "SCEVAddRecExpr operand is not available at loop entry!");
3843#endif
3844
3845 if (Operands.back()->isZero()) {
3846 Operands.pop_back();
3847 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3848 }
3849
3850 // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3851 // use that information to infer NUW and NSW flags. However, computing a
3852 // BE count requires calling getAddRecExpr, so we may not yet have a
3853 // meaningful BE count at this point (and if we don't, we'd be stuck
3854 // with a SCEVCouldNotCompute as the cached BE count).
3855
3856 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3857
3858 // Canonicalize nested AddRecs by nesting them in order of loop depth.
3859 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3860 const Loop *NestedLoop = NestedAR->getLoop();
3861 if (L->contains(NestedLoop)
3862 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3863 : (!NestedLoop->contains(L) &&
3864 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3865 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3866 Operands[0] = NestedAR->getStart();
3867 // AddRecs require their operands be loop-invariant with respect to their
3868 // loops. Don't perform this transformation if it would break this
3869 // requirement.
3870 bool AllInvariant = all_of(
3871 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3872
3873 if (AllInvariant) {
3874 // Create a recurrence for the outer loop with the same step size.
3875 //
3876 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3877 // inner recurrence has the same property.
3878 SCEV::NoWrapFlags OuterFlags =
3879 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3880
3881 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3882 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3883 return isLoopInvariant(Op, NestedLoop);
3884 });
3885
3886 if (AllInvariant) {
3887 // Ok, both add recurrences are valid after the transformation.
3888 //
3889 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3890 // the outer recurrence has the same property.
3891 SCEV::NoWrapFlags InnerFlags =
3892 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3893 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3894 }
3895 }
3896 // Reset Operands to its original state.
3897 Operands[0] = NestedAR;
3898 }
3899 }
3900
3901 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3902 // already have one, otherwise create a new one.
3903 return getOrCreateAddRecExpr(Operands, L, Flags);
3904}
3905
3907 ArrayRef<SCEVUse> IndexExprs) {
// Compute the SCEV for GEP from its pointer operand and IndexExprs. The GEP's
// no-wrap flags are kept only when GEP is an instruction whose SCEV is known
// never to be poison; otherwise they are dropped.
3908 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3909 // getSCEV(Base)->getType() has the same address space as Base->getType()
3910 // because SCEV::getType() preserves the address space.
3911 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3912 if (NW != GEPNoWrapFlags::none()) {
3913 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3914 // but to do that, we have to ensure that said flag is valid in the entire
3915 // defined scope of the SCEV.
3916 // TODO: non-instructions have global scope. We might be able to prove
3917 // some global scope cases
3918 auto *GEPI = dyn_cast<Instruction>(GEP);
3919 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3920 NW = GEPNoWrapFlags::none();
3921 }
3922
3923 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3924}
3925
3927 ArrayRef<SCEVUse> IndexExprs,
3928 Type *SrcElementTy, GEPNoWrapFlags NW) {
// Sum BaseExpr with the byte offsets selected by IndexExprs, walking the
// indexed types starting from SrcElementTy. nusw/nuw on the GEP become
// nsw/nuw (OffsetWrap) on the offset arithmetic.
3930 if (NW.hasNoUnsignedSignedWrap())
3931 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3932 if (NW.hasNoUnsignedWrap())
3933 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3934
3935 Type *CurTy = BaseExpr->getType();
3936 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3937 bool FirstIter = true;
3939 for (SCEVUse IndexExpr : IndexExprs) {
3940 // Compute the (potentially symbolic) offset in bytes for this index.
3941 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3942 // For a struct, add the member offset.
3943 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3944 unsigned FieldNo = Index->getZExtValue();
3945 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3946 Offsets.push_back(FieldOffset);
3947
3948 // Update CurTy to the type of the field at Index.
3949 CurTy = STy->getTypeAtIndex(Index);
3950 } else {
3951 // Update CurTy to its element type.
3952 if (FirstIter) {
3953 assert(isa<PointerType>(CurTy) &&
3954 "The first index of a GEP indexes a pointer");
3955 CurTy = SrcElementTy;
3956 FirstIter = false;
3957 } else {
3959 }
3960 // For an array (or the leading pointer index), add the element offset,
// explicitly scaled.
3961 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3962 // Getelementptr indices are signed.
3963 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3964
3965 // Multiply the index by the element size to compute the element offset.
3966 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3967 Offsets.push_back(LocalOffset);
3968 }
3969 }
3970
3971 // Handle degenerate case of GEP without offsets.
3972 if (Offsets.empty())
3973 return BaseExpr;
3974
3975 // Add the offsets together, with nsw if the GEP is nusw (implied by
// inbounds) and nuw if the GEP is nuw.
3976 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3977 // Add the base address and the offset. We cannot use the nsw flag, as the
3978 // base address is unsigned. However, if we know that the offset is
3979 // non-negative, we can use nuw.
3980 bool NUW = NW.hasNoUnsignedWrap() ||
3983 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3984 assert(BaseExpr->getType() == GEPExpr->getType() &&
3985 "GEP should not change type mid-flight.");
3986 return GEPExpr;
3987}
3988
/// Look up an already-uniqued SCEV of kind SCEVType with operands Ops.
/// Returns null if no such node exists; nothing is created or inserted (the
/// insert position computed by FindNodeOrInsertPos is discarded).
3989SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3992 ID.AddInteger(SCEVType);
3993 for (const SCEV *Op : Ops)
3994 ID.AddPointer(Op);
3995 void *IP = nullptr;
3996 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3997}
3998
/// Look up an already-uniqued SCEV of kind SCEVType with operands Ops.
/// Returns null if no such node exists; nothing is created or inserted (the
/// insert position computed by FindNodeOrInsertPos is discarded).
3999SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
4002 ID.AddInteger(SCEVType);
4003 for (const SCEV *Op : Ops)
4004 ID.AddPointer(Op);
4005 void *IP = nullptr;
4006 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4007}
4008
/// Returns the absolute value of Op, expressed as smax(Op, -Op).
/// NOTE(review): Flags (used for the negation) is presumably derived from
/// IsNSW -- confirm.
4009const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
4011 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
4012}
4013
/// Build (or reuse) a min/max expression of kind Kind over Ops.
/// The operands are constant-folded and grouped, nested expressions of the
/// same kind are flattened, duplicates and operands made redundant by a known
/// ordering are removed, and the result is uniqued in UniqueSCEVs.
4016 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4017 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4018 if (Ops.size() == 1) return Ops[0];
4019#ifndef NDEBUG
4020 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4021 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4022 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4023 "Operand types don't match!");
4024 assert(Ops[0]->getType()->isPointerTy() ==
4025 Ops[i]->getType()->isPointerTy() &&
4026 "min/max should be consistently pointerish");
4027 }
4028#endif
4029
4030 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4031 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4032
4033 const SCEV *Folded = constantFoldAndGroupOps(
4034 *this, LI, DT, Ops,
4035 [&](const APInt &C1, const APInt &C2) {
4036 switch (Kind) {
4037 case scSMaxExpr:
4038 return APIntOps::smax(C1, C2);
4039 case scSMinExpr:
4040 return APIntOps::smin(C1, C2);
4041 case scUMaxExpr:
4042 return APIntOps::umax(C1, C2);
4043 case scUMinExpr:
4044 return APIntOps::umin(C1, C2);
4045 default:
4046 llvm_unreachable("Unknown SCEV min/max opcode");
4047 }
4048 },
4049 [&](const APInt &C) {
4050 // identity
4051 if (IsMax)
4052 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4053 else
4054 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4055 },
4056 [&](const APInt &C) {
4057 // absorber
4058 if (IsMax)
4059 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4060 else
4061 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4062 });
4063 if (Folded)
4064 return Folded;
4065
4066 // Check if we have created the same expression before.
4067 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4068 return S;
4069 }
4070
4071 // Find the first operation of the same kind
4072 unsigned Idx = 0;
4073 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4074 ++Idx;
4075
4076 // Check to see if one of the operands is of the same kind. If so, expand its
4077 // operands onto our operand list, and recurse to simplify.
4078 if (Idx < Ops.size()) {
4079 bool DeletedAny = false;
4080 while (Ops[Idx]->getSCEVType() == Kind) {
4081 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4082 Ops.erase(Ops.begin()+Idx);
4083 append_range(Ops, SMME->operands());
4084 DeletedAny = true;
4085 }
4086
4087 if (DeletedAny)
4088 return getMinMaxExpr(Kind, Ops);
4089 }
4090
4091 // Okay, check to see if the same value occurs in the operand list twice. If
4092 // so, delete one. Since we sorted the list, these values are required to
4093 // be adjacent.
4098 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4099 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4100 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4101 if (Ops[i] == Ops[i + 1] ||
4102 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4103 // X op Y op Y --> X op Y
4104 // X op Y --> X, if we know X, Y are ordered appropriately
4105 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4106 --i;
4107 --e;
4108 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4109 Ops[i + 1])) {
4110 // X op Y --> Y, if we know X, Y are ordered appropriately
4111 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4112 --i;
4113 --e;
4114 }
4115 }
4116
4117 if (Ops.size() == 1) return Ops[0];
4118
4119 assert(!Ops.empty() && "Reduced smax down to nothing!");
4120
4121 // Okay, it looks like we really DO need an expr. Check to see if we
4122 // already have one, otherwise create a new one.
4124 ID.AddInteger(Kind);
4125 for (const SCEV *Op : Ops)
4126 ID.AddPointer(Op);
4127 void *IP = nullptr;
4128 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4129 if (ExistingSCEV)
4130 return ExistingSCEV;
4131 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4133 SCEV *S = new (SCEVAllocator)
4134 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4135
4136 UniqueSCEVs.InsertNode(S, IP);
4137 S->computeAndSetCanonical(*this);
4138 registerUser(S, Ops);
4139 return S;
4140}
4141
4142namespace {
4143
// Removes repeated operands from a sequential min/max expression, keeping
// only the first occurrence. Nested min/max expressions of RootKind, or of its
// non-sequential equivalent, are searched recursively. A visit result of
// std::nullopt means "drop this operand entirely".
4144class SCEVSequentialMinMaxDeduplicatingVisitor final
4145 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4146 std::optional<const SCEV *>> {
4147 using RetVal = std::optional<const SCEV *>;
4149
4150 ScalarEvolution &SE;
4151 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4152 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4154
4155 bool canRecurseInto(SCEVTypes Kind) const {
4156 // We can only recurse into the SCEV expression of the same effective type
4157 // as the type of our root SCEV expression.
4158 return RootKind == Kind || NonSequentialRootKind == Kind;
4159 };
4160
// Deduplicates the operands of a nested min/max S and rebuilds it if any
// operand changed. Kinds that cannot be recursed into are returned as-is.
4161 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4163 "Only for min/max expressions.");
4164 SCEVTypes Kind = S->getSCEVType();
4165
4166 if (!canRecurseInto(Kind))
4167 return S;
4168
4169 auto *NAry = cast<SCEVNAryExpr>(S);
4170 SmallVector<SCEVUse> NewOps;
4171 bool Changed = visit(Kind, NAry->operands(), NewOps);
4172
4173 if (!Changed)
4174 return S;
// Every operand was a duplicate: drop the whole nested expression.
4175 if (NewOps.empty())
4176 return std::nullopt;
4177
4179 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4180 : SE.getMinMaxExpr(Kind, NewOps);
4181 }
4182
// Returns std::nullopt for an operand that was already seen, so the caller
// drops it; otherwise dispatches on the SCEV kind.
4183 RetVal visit(const SCEV *S) {
4184 // Has the whole operand been seen already?
4185 if (!SeenOps.insert(S).second)
4186 return std::nullopt;
4187 return Base::visit(S);
4188 }
4189
4190public:
4191 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4192 SCEVTypes RootKind)
4193 : SE(SE), RootKind(RootKind),
4194 NonSequentialRootKind(
4195 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4196 RootKind)) {}
4197
// Visits OrigOps in order. Returns true if any operand was rewritten or
// dropped; NewOps is assigned only in that case.
4198 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4199 SmallVectorImpl<SCEVUse> &NewOps) {
4200 bool Changed = false;
4202 Ops.reserve(OrigOps.size());
4203
4204 for (const SCEV *Op : OrigOps) {
4205 RetVal NewOp = visit(Op);
4206 if (NewOp != Op)
4207 Changed = true;
4208 if (NewOp)
4209 Ops.emplace_back(*NewOp);
4210 }
4211
4212 if (Changed)
4213 NewOps = std::move(Ops);
4214 return Changed;
4215 }
4216
// Non-min/max expressions are leaves for deduplication: returned unchanged.
4217 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4218
4219 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4220
4221 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4222
4223 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4224
4225 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4226
4227 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4228
4229 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4230
4231 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4232
4233 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4234
4235 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4236
4237 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4238
4239 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4240 return visitAnyMinMaxExpr(Expr);
4241 }
4242
4243 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4244 return visitAnyMinMaxExpr(Expr);
4245 }
4246
4247 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4248 return visitAnyMinMaxExpr(Expr);
4249 }
4250
4251 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4252 return visitAnyMinMaxExpr(Expr);
4253 }
4254
4255 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4256 return visitAnyMinMaxExpr(Expr);
4257 }
4258
4259 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4260
4261 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4262};
4263
4264} // namespace
4265
/// Returns true if, for an expression of kind Kind, poison in any operand
/// makes the whole expression poison. The one kind that answers false
/// (presumably umin_seq) only propagates poison from its first operand, so it
/// is pessimistically treated as not propagating.
4267 switch (Kind) {
4268 case scConstant:
4269 case scVScale:
4270 case scTruncate:
4271 case scZeroExtend:
4272 case scSignExtend:
4273 case scPtrToAddr:
4274 case scPtrToInt:
4275 case scAddExpr:
4276 case scMulExpr:
4277 case scUDivExpr:
4278 case scAddRecExpr:
4279 case scUMaxExpr:
4280 case scSMaxExpr:
4281 case scUMinExpr:
4282 case scSMinExpr:
4283 case scUnknown:
4284 // If any operand is poison, the whole expression is poison.
4285 return true;
4287 // FIXME: if the *first* operand is poison, the whole expression is poison.
4288 return false; // Pessimistically, say that it does not propagate poison.
4289 case scCouldNotCompute:
4290 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4291 }
4292 llvm_unreachable("Unknown SCEV kind!");
4293}
4294
4295namespace {
4296// The only way poison may be introduced in a SCEV expression is from a
4297// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4298// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4299// introduce poison -- they encode guaranteed, non-speculated knowledge.
4300//
4301// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4302// with the notable exception of umin_seq, where only poison from the first
4303// operand is (unconditionally) propagated.
//
// SCEVPoisonCollector is a visitAll() visitor that records every SCEVUnknown
// whose IR value is not known to be non-poison.
4304struct SCEVPoisonCollector {
// If false, the walk does not descend into potentially poison-blocking
// operations, so only unconditionally propagated poison sources are found.
4305 bool LookThroughMaybePoisonBlocking;
// SCEVUnknowns reached by the walk that might be poison.
4306 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4307 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4308 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4309
// Returning false keeps the traversal from visiting S's operands.
4310 bool follow(const SCEV *S) {
4311 if (!LookThroughMaybePoisonBlocking &&
4313 return false;
4314
4315 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4316 if (!isGuaranteedNotToBePoison(SU->getValue()))
4317 MaybePoison.insert(SU);
4318 }
4319 return true;
4320 }
// The walk never terminates early.
4321 bool isDone() const { return false; }
4322};
4323} // namespace
4324
4325/// Return true if V is poison given that AssumedPoison is already poison.
4326static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4327 // First collect all SCEVs that might result in AssumedPoison to be poison.
4328 // We need to look through potentially poison-blocking operations here,
4329 // because we want to find all SCEVs that *might* result in poison, not only
4330 // those that are *required* to.
4331 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4332 visitAll(AssumedPoison, PC1);
4333
4334 // AssumedPoison is never poison. As the assumption is false, the implication
4335 // is true. Don't bother walking the other SCEV in this case.
4336 if (PC1.MaybePoison.empty())
4337 return true;
4338
4339 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4340 // as well. We cannot look through potentially poison-blocking operations
4341 // here, as their arguments only *may* make the result poison.
4342 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4343 visitAll(S, PC2);
4344
4345 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4346 // it will also make S poison by being part of PC2.MaybePoison.
4347 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4348}
4349
4351 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
// Add to Result the IR value of every SCEVUnknown in S that, if poison, is
// guaranteed to make S poison. Potentially poison-blocking operations are not
// looked through.
4352 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4353 visitAll(S, PC);
4354 for (const SCEVUnknown *SU : PC.MaybePoison)
4355 Result.insert(SU->getValue());
4356}
4357
4359 const SCEV *S, Instruction *I,
4360 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
// Returns true if I may be reused in place of S without being more poisonous
// than S. Instructions whose poison-generating flags/metadata must be dropped
// for this to hold are appended to DropPoisonGeneratingInsts.
4361 // If the instruction cannot be poison, it's always safe to reuse.
4363 return true;
4364
4365 // Otherwise, it is possible that I is more poisonous than S. Collect the
4366 // poison-contributors of S, and then check whether I has any additional
4367 // poison-contributors. Poison that is contributed through poison-generating
4368 // flags is handled by dropping those flags instead.
4370 getPoisonGeneratingValues(PoisonVals, S);
4371
4372 SmallVector<Value *> Worklist;
4374 Worklist.push_back(I);
4375 while (!Worklist.empty()) {
4376 Value *V = Worklist.pop_back_val();
4377 if (!Visited.insert(V).second)
4378 continue;
4379
4380 // Avoid walking large instruction graphs.
4381 if (Visited.size() > 16)
4382 return false;
4383
4384 // Either the value can't be poison, or the S would also be poison if it
4385 // is.
4386 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4387 continue;
4388
// Note: this I shadows the parameter I for the rest of the loop body.
4389 auto *I = dyn_cast<Instruction>(V);
4390 if (!I)
4391 return false;
4392
4393 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4394 // can't replace an arbitrary add with disjoint or, even if we drop the
4395 // flag. We would need to convert the or into an add.
4396 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4397 if (PDI->isDisjoint())
4398 return false;
4399
4400 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4401 // because SCEV currently assumes it can't be poison. Remove this special
4402 // case once we proper model when vscale can be poison.
4403 if (auto *II = dyn_cast<IntrinsicInst>(I);
4404 II && II->getIntrinsicID() == Intrinsic::vscale)
4405 continue;
4406
4407 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4408 return false;
4409
4410 // If the instruction can't create poison, we can recurse to its operands.
4411 if (I->hasPoisonGeneratingAnnotations())
4412 DropPoisonGeneratingInsts.push_back(I);
4413
4414 llvm::append_range(Worklist, I->operands());
4415 }
4416 return true;
4417}
4418
/// Build (or reuse) a sequential min/max expression of kind Kind over Ops.
/// Operand order is significant and is preserved. Repeated operands are
/// removed (keeping the first), nested expressions of the same kind are
/// spliced in place, adjacent operands are folded where that is known to be
/// safe, and the result is uniqued in UniqueSCEVs.
4419const SCEV *
4422 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4423 "Not a SCEVSequentialMinMaxExpr!");
4424 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4425 if (Ops.size() == 1)
4426 return Ops[0];
4427#ifndef NDEBUG
4428 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4429 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4430 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4431 "Operand types don't match!");
4432 assert(Ops[0]->getType()->isPointerTy() ==
4433 Ops[i]->getType()->isPointerTy() &&
4434 "min/max should be consistently pointerish");
4435 }
4436#endif
4437
4438 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4439 // so we can *NOT* do any kind of sorting of the expressions!
4440
4441 // Check if we have created the same expression before.
4442 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4443 return S;
4444
4445 // FIXME: there are *some* simplifications that we can do here.
4446
4447 // Keep only the first instance of an operand.
4448 {
4449 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4450 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4451 if (Changed)
4452 return getSequentialMinMaxExpr(Kind, Ops);
4453 }
4454
4455 // Check to see if one of the operands is of the same kind. If so, expand its
4456 // operands onto our operand list, and recurse to simplify.
4457 {
4458 unsigned Idx = 0;
4459 bool DeletedAny = false;
4460 while (Idx < Ops.size()) {
4461 if (Ops[Idx]->getSCEVType() != Kind) {
4462 ++Idx;
4463 continue;
4464 }
// Splice the nested operands in at the same position. Idx is not
// advanced, so the spliced-in operands are themselves re-checked.
4465 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4466 Ops.erase(Ops.begin() + Idx);
4467 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4468 SMME->operands().end());
4469 DeletedAny = true;
4470 }
4471
4472 if (DeletedAny)
4473 return getSequentialMinMaxExpr(Kind, Ops);
4474 }
4475
4476 const SCEV *SaturationPoint;
4478 switch (Kind) {
4480 SaturationPoint = getZero(Ops[0]->getType());
4481 Pred = ICmpInst::ICMP_ULE;
4482 break;
4483 default:
4484 llvm_unreachable("Not a sequential min/max type.");
4485 }
4486
4487 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4488 if (!isGuaranteedNotToCauseUB(Ops[i]))
4489 continue;
4490 // We can replace %x umin_seq %y with %x umin %y if either:
4491 // * %y being poison implies %x is also poison.
4492 // * %x cannot be the saturating value (e.g. zero for umin).
4493 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4494 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4495 SaturationPoint)) {
4496 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4497 Ops[i - 1] = getMinMaxExpr(
4499 SeqOps);
4500 Ops.erase(Ops.begin() + i);
4501 return getSequentialMinMaxExpr(Kind, Ops);
4502 }
4503 // Fold %x umin_seq %y to %x if %x ule %y.
4504 // TODO: We might be able to prove the predicate for a later operand.
4505 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4506 Ops.erase(Ops.begin() + i);
4507 return getSequentialMinMaxExpr(Kind, Ops);
4508 }
4509 }
4510
4511 // Okay, it looks like we really DO need an expr. Check to see if we
4512 // already have one, otherwise create a new one.
4514 ID.AddInteger(Kind);
4515 for (const SCEV *Op : Ops)
4516 ID.AddPointer(Op);
4517 void *IP = nullptr;
4518 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4519 if (ExistingSCEV)
4520 return ExistingSCEV;
4521
4522 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4524 SCEV *S = new (SCEVAllocator)
4525 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4526
4527 UniqueSCEVs.InsertNode(S, IP);
4528 S->computeAndSetCanonical(*this);
4529 registerUser(S, Ops);
4530 return S;
4531}
4532
4537
4541
4546
4550
4555
4559
4561 bool Sequential) {
4562 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4563 return getUMinExpr(Ops, Sequential);
4564}
4565
4571
4572const SCEV *
4574 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4575 if (Size.isScalable())
4576 Res = getMulExpr(Res, getVScale(IntTy));
4577 return Res;
4578}
4579
4581 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4582}
4583
4585 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4586}
4587
4589 StructType *STy,
4590 unsigned FieldNo) {
4591 // We can bypass creating a target-independent constant expression and then
4592 // folding it back into a ConstantInt. This is just a compile-time
4593 // optimization.
4594 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4595 assert(!SL->getSizeInBits().isScalable() &&
4596 "Cannot get offset for structure containing scalable vector types");
4597 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4598}
4599
4601 // Don't attempt to do anything other than create a SCEVUnknown object
4602 // here. createSCEV only calls getUnknown after checking for all other
4603 // interesting possibilities, and any other code that calls getUnknown
4604 // is doing so in order to hide a value from SCEV canonicalization.
4605
4607 ID.AddInteger(scUnknown);
4608 ID.AddPointer(V);
4609 void *IP = nullptr;
4610 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4611 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4612 "Stale SCEVUnknown in uniquing map!");
4613 return S;
4614 }
4615 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4616 FirstUnknown);
4617 FirstUnknown = cast<SCEVUnknown>(S);
4618 UniqueSCEVs.InsertNode(S, IP);
4619 S->computeAndSetCanonical(*this);
4620 return S;
4621}
4622
4623//===----------------------------------------------------------------------===//
4624// Basic SCEV Analysis and PHI Idiom Recognition Code
4625//
4626
/// Test if values of the given type are analyzable within the SCEV
/// framework. Integer and pointer types are always analyzable; no other
/// types are.
4632 // Integers and pointers are always SCEVable.
4633 return Ty->isIntOrPtrTy();
4634}
4635
4636/// Return the size in bits of the specified type, for which isSCEVable must
4637/// return true.
4639 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4640 if (Ty->isPointerTy())
4642 return getDataLayout().getTypeSizeInBits(Ty);
4643}
4644
4645/// Return a type with the same bitwidth as the given type and which represents
4646/// how SCEV will treat the given type, for which isSCEVable must return
4647/// true. For pointer types, this is the pointer index sized integer type.
4649 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4650
4651 if (Ty->isIntegerTy())
4652 return Ty;
4653
4654 // The only other support type is pointer.
4655 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4656 return getDataLayout().getIndexType(Ty);
4657}
4658
4660 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4661}
4662
4664 const SCEV *B) {
4665 /// For a valid use point to exist, the defining scope of one operand
4666 /// must dominate the other.
4667 bool PreciseA, PreciseB;
4668 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4669 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4670 if (!PreciseA || !PreciseB)
4671 // Can't tell.
4672 return false;
4673 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4674 DT.dominates(ScopeB, ScopeA);
4675}
4676
4678 return CouldNotCompute.get();
4679}
4680
4681bool ScalarEvolution::checkValidity(const SCEV *S) const {
4682 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4683 auto *SU = dyn_cast<SCEVUnknown>(S);
4684 return SU && SU->getValue() == nullptr;
4685 });
4686
4687 return !ContainsNulls;
4688}
4689
4691 HasRecMapType::iterator I = HasRecMap.find(S);
4692 if (I != HasRecMap.end())
4693 return I->second;
4694
4695 bool FoundAddRec =
4696 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4697 HasRecMap.insert({S, FoundAddRec});
4698 return FoundAddRec;
4699}
4700
4701/// Return the ValueOffsetPair set for \p S. \p S can be represented
4702/// by the value and offset from any ValueOffsetPair in the set.
4703ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4704 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4705 if (SI == ExprValueMap.end())
4706 return {};
4707 return SI->second.getArrayRef();
4708}
4709
4710/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4711/// cannot be used separately. eraseValueFromMap should be used to remove
4712/// V from ValueExprMap and ExprValueMap at the same time.
4713void ScalarEvolution::eraseValueFromMap(Value *V) {
4714 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4715 if (I != ValueExprMap.end()) {
4716 auto EVIt = ExprValueMap.find(I->second);
4717 bool Removed = EVIt->second.remove(V);
4718 (void) Removed;
4719 assert(Removed && "Value not in ExprValueMap?");
4720 ValueExprMap.erase(I);
4721 }
4722}
4723
4724void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4725 // A recursive query may have already computed the SCEV. It should be
4726 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4727 // inferred nowrap flags.
4728 auto It = ValueExprMap.find_as(V);
4729 if (It == ValueExprMap.end()) {
4730 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4731 ExprValueMap[S].insert(V);
4732 }
4733}
4734
4735/// Return an existing SCEV if it exists, otherwise analyze the expression and
4736/// create a new one.
4738 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4739
4740 if (const SCEV *S = getExistingSCEV(V))
4741 return S;
4742 return createSCEVIter(V);
4743}
4744
4746 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4747
4748 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4749 if (I != ValueExprMap.end()) {
4750 const SCEV *S = I->second;
4751 assert(checkValidity(S) &&
4752 "existing SCEV has not been properly invalidated");
4753 return S;
4754 }
4755 return nullptr;
4756}
4757
4758/// Return a SCEV corresponding to -V = -1*V
4760 SCEV::NoWrapFlags Flags) {
4761 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4762 return getConstant(
4763 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4764
4765 Type *Ty = V->getType();
4766 Ty = getEffectiveSCEVType(Ty);
4767 return getMulExpr(V, getMinusOne(Ty), Flags);
4768}
4769
4770/// If Expr computes ~A, return A else return nullptr
4771static const SCEV *MatchNotExpr(const SCEV *Expr) {
4772 const SCEV *MulOp;
4773 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4774 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4775 return MulOp;
4776 return nullptr;
4777}
4778
4779/// Return a SCEV corresponding to ~V = -1-V
4781 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4782
4783 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4784 return getConstant(
4785 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4786
4787 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4788 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4789 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4790 SmallVector<SCEVUse, 2> MatchedOperands;
4791 for (const SCEV *Operand : MME->operands()) {
4792 const SCEV *Matched = MatchNotExpr(Operand);
4793 if (!Matched)
4794 return (const SCEV *)nullptr;
4795 MatchedOperands.push_back(Matched);
4796 }
4797 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4798 MatchedOperands);
4799 };
4800 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4801 return Replaced;
4802 }
4803
4804 Type *Ty = V->getType();
4805 Ty = getEffectiveSCEVType(Ty);
4806 return getMinusSCEV(getMinusOne(Ty), V);
4807}
4808
4810 assert(P->getType()->isPointerTy());
4811
4812 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4813 // The base of an AddRec is the first operand.
4814 SmallVector<SCEVUse> Ops{AddRec->operands()};
4815 Ops[0] = removePointerBase(Ops[0]);
4816 // Don't try to transfer nowrap flags for now. We could in some cases
4817 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4818 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4819 }
4820 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4821 // The base of an Add is the pointer operand.
4822 SmallVector<SCEVUse> Ops{Add->operands()};
4823 SCEVUse *PtrOp = nullptr;
4824 for (SCEVUse &AddOp : Ops) {
4825 if (AddOp->getType()->isPointerTy()) {
4826 assert(!PtrOp && "Cannot have multiple pointer ops");
4827 PtrOp = &AddOp;
4828 }
4829 }
4830 *PtrOp = removePointerBase(*PtrOp);
4831 // Don't try to transfer nowrap flags for now. We could in some cases
4832 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4833 return getAddExpr(Ops);
4834 }
4835 // Any other expression must be a pointer base.
4836 return getZero(P->getType());
4837}
4838
4840 SCEV::NoWrapFlags Flags,
4841 unsigned Depth) {
4842 // Fast path: X - X --> 0.
4843 if (LHS == RHS)
4844 return getZero(LHS->getType());
4845
4846 // If we subtract two pointers with different pointer bases, bail.
4847 // Eventually, we're going to add an assertion to getMulExpr that we
4848 // can't multiply by a pointer.
4849 if (RHS->getType()->isPointerTy()) {
4850 if (!LHS->getType()->isPointerTy() ||
4851 getPointerBase(LHS) != getPointerBase(RHS))
4852 return getCouldNotCompute();
4853 LHS = removePointerBase(LHS);
4854 RHS = removePointerBase(RHS);
4855 }
4856
4857 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4858 // makes it so that we cannot make much use of NUW.
4859 auto AddFlags = SCEV::FlagAnyWrap;
4860 const bool RHSIsNotMinSigned =
4862 if (hasFlags(Flags, SCEV::FlagNSW)) {
4863 // Let M be the minimum representable signed value. Then (-1)*RHS
4864 // signed-wraps if and only if RHS is M. That can happen even for
4865 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4866 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4867 // (-1)*RHS, we need to prove that RHS != M.
4868 //
4869 // If LHS is non-negative and we know that LHS - RHS does not
4870 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4871 // either by proving that RHS > M or that LHS >= 0.
4872 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4873 AddFlags = SCEV::FlagNSW;
4874 }
4875 }
4876
4877 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4878 // RHS is NSW and LHS >= 0.
4879 //
4880 // The difficulty here is that the NSW flag may have been proven
4881 // relative to a loop that is to be found in a recurrence in LHS and
4882 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4883 // larger scope than intended.
4884 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4885
4886 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4887}
4888
4890 unsigned Depth) {
4891 Type *SrcTy = V->getType();
4892 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4893 "Cannot truncate or zero extend with non-integer arguments!");
4894 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4895 return V; // No conversion
4896 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4897 return getTruncateExpr(V, Ty, Depth);
4898 return getZeroExtendExpr(V, Ty, Depth);
4899}
4900
4902 unsigned Depth) {
4903 Type *SrcTy = V->getType();
4904 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4905 "Cannot truncate or zero extend with non-integer arguments!");
4906 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4907 return V; // No conversion
4908 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4909 return getTruncateExpr(V, Ty, Depth);
4910 return getSignExtendExpr(V, Ty, Depth);
4911}
4912
4913const SCEV *
4915 Type *SrcTy = V->getType();
4916 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4917 "Cannot noop or zero extend with non-integer arguments!");
4919 "getNoopOrZeroExtend cannot truncate!");
4920 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4921 return V; // No conversion
4922 return getZeroExtendExpr(V, Ty);
4923}
4924
4925const SCEV *
4927 Type *SrcTy = V->getType();
4928 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4929 "Cannot noop or sign extend with non-integer arguments!");
4931 "getNoopOrSignExtend cannot truncate!");
4932 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4933 return V; // No conversion
4934 return getSignExtendExpr(V, Ty);
4935}
4936
4937const SCEV *
4939 Type *SrcTy = V->getType();
4940 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4941 "Cannot noop or any extend with non-integer arguments!");
4943 "getNoopOrAnyExtend cannot truncate!");
4944 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4945 return V; // No conversion
4946 return getAnyExtendExpr(V, Ty);
4947}
4948
4949const SCEV *
4951 Type *SrcTy = V->getType();
4952 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4953 "Cannot truncate or noop with non-integer arguments!");
4955 "getTruncateOrNoop cannot extend!");
4956 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4957 return V; // No conversion
4958 return getTruncateExpr(V, Ty);
4959}
4960
4962 const SCEV *RHS) {
4963 const SCEV *PromotedLHS = LHS;
4964 const SCEV *PromotedRHS = RHS;
4965
4966 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4967 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4968 else
4969 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4970
4971 return getUMaxExpr(PromotedLHS, PromotedRHS);
4972}
4973
4975 const SCEV *RHS,
4976 bool Sequential) {
4977 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4978 return getUMinFromMismatchedTypes(Ops, Sequential);
4979}
4980
4981const SCEV *
4983 bool Sequential) {
4984 assert(!Ops.empty() && "At least one operand must be!");
4985 // Trivial case.
4986 if (Ops.size() == 1)
4987 return Ops[0];
4988
4989 // Find the max type first.
4990 Type *MaxType = nullptr;
4991 for (SCEVUse S : Ops)
4992 if (MaxType)
4993 MaxType = getWiderType(MaxType, S->getType());
4994 else
4995 MaxType = S->getType();
4996 assert(MaxType && "Failed to find maximum type!");
4997
4998 // Extend all ops to max type.
4999 SmallVector<SCEVUse, 2> PromotedOps;
5000 for (SCEVUse S : Ops)
5001 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
5002
5003 // Generate umin.
5004 return getUMinExpr(PromotedOps, Sequential);
5005}
5006
5008 // A pointer operand may evaluate to a nonpointer expression, such as null.
5009 if (!V->getType()->isPointerTy())
5010 return V;
5011
5012 while (true) {
5013 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5014 V = AddRec->getStart();
5015 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5016 const SCEV *PtrOp = nullptr;
5017 for (const SCEV *AddOp : Add->operands()) {
5018 if (AddOp->getType()->isPointerTy()) {
5019 assert(!PtrOp && "Cannot have multiple pointer ops");
5020 PtrOp = AddOp;
5021 }
5022 }
5023 assert(PtrOp && "Must have pointer op");
5024 V = PtrOp;
5025 } else // Not something we can look further into.
5026 return V;
5027 }
5028}
5029
5030/// Push users of the given Instruction onto the given Worklist.
5034 // Push the def-use children onto the Worklist stack.
5035 for (User *U : I->users()) {
5036 auto *UserInsn = cast<Instruction>(U);
5037 if (Visited.insert(UserInsn).second)
5038 Worklist.push_back(UserInsn);
5039 }
5040}
5041
5042namespace {
5043
5044/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5045/// expression in case its Loop is L. If it is not L then
5046/// if IgnoreOtherLoops is true then use AddRec itself
5047/// otherwise rewrite cannot be done.
5048/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5049class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5050public:
5051 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5052 bool IgnoreOtherLoops = true) {
5053 SCEVInitRewriter Rewriter(L, SE);
5054 const SCEV *Result = Rewriter.visit(S);
5055 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5056 return SE.getCouldNotCompute();
5057 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5058 ? SE.getCouldNotCompute()
5059 : Result;
5060 }
5061
5062 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5063 if (!SE.isLoopInvariant(Expr, L))
5064 SeenLoopVariantSCEVUnknown = true;
5065 return Expr;
5066 }
5067
5068 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5069 // Only re-write AddRecExprs for this loop.
5070 if (Expr->getLoop() == L)
5071 return Expr->getStart();
5072 SeenOtherLoops = true;
5073 return Expr;
5074 }
5075
5076 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5077
5078 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5079
5080private:
5081 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5082 : SCEVRewriteVisitor(SE), L(L) {}
5083
5084 const Loop *L;
5085 bool SeenLoopVariantSCEVUnknown = false;
5086 bool SeenOtherLoops = false;
5087};
5088
5089/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5090/// increment expression in case its Loop is L. If it is not L then
5091/// use AddRec itself.
5092/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5093class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5094public:
5095 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5096 SCEVPostIncRewriter Rewriter(L, SE);
5097 const SCEV *Result = Rewriter.visit(S);
5098 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5099 ? SE.getCouldNotCompute()
5100 : Result;
5101 }
5102
5103 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5104 if (!SE.isLoopInvariant(Expr, L))
5105 SeenLoopVariantSCEVUnknown = true;
5106 return Expr;
5107 }
5108
5109 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5110 // Only re-write AddRecExprs for this loop.
5111 if (Expr->getLoop() == L)
5112 return Expr->getPostIncExpr(SE);
5113 SeenOtherLoops = true;
5114 return Expr;
5115 }
5116
5117 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5118
5119 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5120
5121private:
5122 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5123 : SCEVRewriteVisitor(SE), L(L) {}
5124
5125 const Loop *L;
5126 bool SeenLoopVariantSCEVUnknown = false;
5127 bool SeenOtherLoops = false;
5128};
5129
5130/// This class evaluates the compare condition by matching it against the
5131/// condition of loop latch. If there is a match we assume a true value
5132/// for the condition while building SCEV nodes.
5133class SCEVBackedgeConditionFolder
5134 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5135public:
5136 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5137 ScalarEvolution &SE) {
5138 bool IsPosBECond = false;
5139 Value *BECond = nullptr;
5140 if (BasicBlock *Latch = L->getLoopLatch()) {
5141 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5142 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5143 "Both outgoing branches should not target same header!");
5144 BECond = BI->getCondition();
5145 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5146 } else {
5147 return S;
5148 }
5149 }
5150 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5151 return Rewriter.visit(S);
5152 }
5153
5154 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5155 const SCEV *Result = Expr;
5156 bool InvariantF = SE.isLoopInvariant(Expr, L);
5157
5158 if (!InvariantF) {
5160 switch (I->getOpcode()) {
5161 case Instruction::Select: {
5162 SelectInst *SI = cast<SelectInst>(I);
5163 std::optional<const SCEV *> Res =
5164 compareWithBackedgeCondition(SI->getCondition());
5165 if (Res) {
5166 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5167 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5168 }
5169 break;
5170 }
5171 default: {
5172 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5173 if (Res)
5174 Result = *Res;
5175 break;
5176 }
5177 }
5178 }
5179 return Result;
5180 }
5181
5182private:
5183 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5184 bool IsPosBECond, ScalarEvolution &SE)
5185 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5186 IsPositiveBECond(IsPosBECond) {}
5187
5188 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5189
5190 const Loop *L;
5191 /// Loop back condition.
5192 Value *BackedgeCond = nullptr;
5193 /// Set to true if loop back is on positive branch condition.
5194 bool IsPositiveBECond;
5195};
5196
5197std::optional<const SCEV *>
5198SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5199
5200 // If value matches the backedge condition for loop latch,
5201 // then return a constant evolution node based on loopback
5202 // branch taken.
5203 if (BackedgeCond == IC)
5204 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5206 return std::nullopt;
5207}
5208
5209class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5210public:
5211 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5212 ScalarEvolution &SE) {
5213 SCEVShiftRewriter Rewriter(L, SE);
5214 const SCEV *Result = Rewriter.visit(S);
5215 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5216 }
5217
5218 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5219 // Only allow AddRecExprs for this loop.
5220 if (!SE.isLoopInvariant(Expr, L))
5221 Valid = false;
5222 return Expr;
5223 }
5224
5225 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5226 if (Expr->getLoop() == L && Expr->isAffine())
5227 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5228 Valid = false;
5229 return Expr;
5230 }
5231
5232 bool isValid() { return Valid; }
5233
5234private:
5235 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5236 : SCEVRewriteVisitor(SE), L(L) {}
5237
5238 const Loop *L;
5239 bool Valid = true;
5240};
5241
5242} // end anonymous namespace
5243
5245ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5246 if (!AR->isAffine())
5247 return SCEV::FlagAnyWrap;
5248
5249 using OBO = OverflowingBinaryOperator;
5250
5252
5253 if (!AR->hasNoSelfWrap()) {
5254 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5255 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5256 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5257 const APInt &BECountAP = BECountMax->getAPInt();
5258 unsigned NoOverflowBitWidth =
5259 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5260 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5262 }
5263 }
5264
5265 if (!AR->hasNoSignedWrap()) {
5266 ConstantRange AddRecRange = getSignedRange(AR);
5267 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5268
5270 Instruction::Add, IncRange, OBO::NoSignedWrap);
5271 if (NSWRegion.contains(AddRecRange))
5273 }
5274
5275 if (!AR->hasNoUnsignedWrap()) {
5276 ConstantRange AddRecRange = getUnsignedRange(AR);
5277 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5278
5280 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5281 if (NUWRegion.contains(AddRecRange))
5283 }
5284
5285 return Result;
5286}
5287
5289ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5291
5292 if (AR->hasNoSignedWrap())
5293 return Result;
5294
5295 if (!AR->isAffine())
5296 return Result;
5297
5298 // This function can be expensive, only try to prove NSW once per AddRec.
5299 if (!SignedWrapViaInductionTried.insert(AR).second)
5300 return Result;
5301
5302 const SCEV *Step = AR->getStepRecurrence(*this);
5303 const Loop *L = AR->getLoop();
5304
5305 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5306 // Note that this serves two purposes: It filters out loops that are
5307 // simply not analyzable, and it covers the case where this code is
5308 // being called from within backedge-taken count analysis, such that
5309 // attempting to ask for the backedge-taken count would likely result
5310 // in infinite recursion. In the later case, the analysis code will
5311 // cope with a conservative value, and it will take care to purge
5312 // that value once it has finished.
5313 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5314
5315 // Normally, in the cases we can prove no-overflow via a
5316 // backedge guarding condition, we can also compute a backedge
5317 // taken count for the loop. The exceptions are assumptions and
5318 // guards present in the loop -- SCEV is not great at exploiting
5319 // these to compute max backedge taken counts, but can still use
5320 // these to prove lack of overflow. Use this fact to avoid
5321 // doing extra work that may not pay off.
5322
5323 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5324 AC.assumptions().empty())
5325 return Result;
5326
5327 // If the backedge is guarded by a comparison with the pre-inc value the
5328 // addrec is safe. Also, if the entry is guarded by a comparison with the
5329 // start value and the backedge is guarded by a comparison with the post-inc
5330 // value, the addrec is safe.
5332 const SCEV *OverflowLimit =
5333 getSignedOverflowLimitForStep(Step, &Pred, this);
5334 if (OverflowLimit &&
5335 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5336 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5337 Result = setFlags(Result, SCEV::FlagNSW);
5338 }
5339 return Result;
5340}
5342ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5344
5345 if (AR->hasNoUnsignedWrap())
5346 return Result;
5347
5348 if (!AR->isAffine())
5349 return Result;
5350
5351 // This function can be expensive, only try to prove NUW once per AddRec.
5352 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5353 return Result;
5354
5355 const SCEV *Step = AR->getStepRecurrence(*this);
5356 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5357 const Loop *L = AR->getLoop();
5358
5359 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5360 // Note that this serves two purposes: It filters out loops that are
5361 // simply not analyzable, and it covers the case where this code is
5362 // being called from within backedge-taken count analysis, such that
5363 // attempting to ask for the backedge-taken count would likely result
5364 // in infinite recursion. In the later case, the analysis code will
5365 // cope with a conservative value, and it will take care to purge
5366 // that value once it has finished.
5367 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5368
5369 // Normally, in the cases we can prove no-overflow via a
5370 // backedge guarding condition, we can also compute a backedge
5371 // taken count for the loop. The exceptions are assumptions and
5372 // guards present in the loop -- SCEV is not great at exploiting
5373 // these to compute max backedge taken counts, but can still use
5374 // these to prove lack of overflow. Use this fact to avoid
5375 // doing extra work that may not pay off.
5376
5377 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5378 AC.assumptions().empty())
5379 return Result;
5380
5381 // If the backedge is guarded by a comparison with the pre-inc value the
5382 // addrec is safe. Also, if the entry is guarded by a comparison with the
5383 // start value and the backedge is guarded by a comparison with the post-inc
5384 // value, the addrec is safe.
5385 if (isKnownPositive(Step)) {
5386 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5387 getUnsignedRangeMax(Step));
5390 Result = setFlags(Result, SCEV::FlagNUW);
5391 }
5392 }
5393
5394 return Result;
5395}
5396
5397namespace {
5398
5399/// Represents an abstract binary operation. This may exist as a
5400/// normal instruction or constant expression, or may have been
5401/// derived from an expression tree.
5402struct BinaryOp {
5403 unsigned Opcode;
5404 Value *LHS;
5405 Value *RHS;
5406 bool IsNSW = false;
5407 bool IsNUW = false;
5408
5409 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5410 /// constant expression.
5411 Operator *Op = nullptr;
5412
5413 explicit BinaryOp(Operator *Op)
5414 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5415 Op(Op) {
5416 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5417 IsNSW = OBO->hasNoSignedWrap();
5418 IsNUW = OBO->hasNoUnsignedWrap();
5419 }
5420 }
5421
5422 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5423 bool IsNUW = false)
5424 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5425};
5426
5427} // end anonymous namespace
5428
5429/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5430static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5431 AssumptionCache &AC,
5432 const DominatorTree &DT,
5433 const Instruction *CxtI) {
5434 auto *Op = dyn_cast<Operator>(V);
5435 if (!Op)
5436 return std::nullopt;
5437
5438 // Implementation detail: all the cleverness here should happen without
5439 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5440 // SCEV expressions when possible, and we should not break that.
5441
5442 switch (Op->getOpcode()) {
5443 case Instruction::Add:
5444 case Instruction::Sub:
5445 case Instruction::Mul:
5446 case Instruction::UDiv:
5447 case Instruction::URem:
5448 case Instruction::And:
5449 case Instruction::AShr:
5450 case Instruction::Shl:
5451 return BinaryOp(Op);
5452
5453 case Instruction::Or: {
5454 // Convert or disjoint into add nuw nsw.
5455 if (cast<PossiblyDisjointInst>(Op)->isDisjoint()) {
5456 BinaryOp BinOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5457 /*IsNSW=*/true, /*IsNUW=*/true);
5458 // Keep the reference to the original instruction so that we can later
5459 // check whether it can produce poison value or not.
5460 BinOp.Op = Op;
5461 return BinOp;
5462 }
5463 return BinaryOp(Op);
5464 }
5465
5466 case Instruction::Xor:
5467 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5468 // If the RHS of the xor is a signmask, then this is just an add.
5469 // Instcombine turns add of signmask into xor as a strength reduction step.
5470 if (RHSC->getValue().isSignMask())
5471 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5472 // Binary `xor` is a bit-wise `add`.
5473 if (V->getType()->isIntegerTy(1))
5474 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5475 return BinaryOp(Op);
5476
5477 case Instruction::LShr:
5478 // Turn logical shift right of a constant into a unsigned divide.
5479 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5480 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5481
5482 // If the shift count is not less than the bitwidth, the result of
5483 // the shift is undefined. Don't try to analyze it, because the
5484 // resolution chosen here may differ from the resolution chosen in
5485 // other parts of the compiler.
5486 if (SA->getValue().ult(BitWidth)) {
5487 Constant *X =
5488 ConstantInt::get(SA->getContext(),
5489 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5490 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5491 }
5492 }
5493 return BinaryOp(Op);
5494
5495 case Instruction::ExtractValue: {
5496 auto *EVI = cast<ExtractValueInst>(Op);
5497 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5498 break;
5499
5500 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5501 if (!WO)
5502 break;
5503
5504 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5505 bool Signed = WO->isSigned();
5506 // TODO: Should add nuw/nsw flags for mul as well.
5507 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5508 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5509
5510 // Now that we know that all uses of the arithmetic-result component of
5511 // CI are guarded by the overflow check, we can go ahead and pretend
5512 // that the arithmetic is non-overflowing.
5513 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5514 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5515 }
5516
5517 default:
5518 break;
5519 }
5520
5521 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5522 // semantics as a Sub, return a binary sub expression.
5523 if (auto *II = dyn_cast<IntrinsicInst>(V))
5524 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5525 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5526
5527 return std::nullopt;
5528}
5529
5530/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5531/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5532/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5533/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5534/// follows one of the following patterns:
5535/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5536/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5537/// If the SCEV expression of \p Op conforms with one of the expected patterns
5538/// we return the type of the truncation operation, and indicate whether the
5539/// truncated type should be treated as signed/unsigned by setting
5540/// \p Signed to true/false, respectively.
5541static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5542 bool &Signed, ScalarEvolution &SE) {
5543 // The case where Op == SymbolicPHI (that is, with no type conversions on
5544 // the way) is handled by the regular add recurrence creating logic and
5545 // would have already been triggered in createAddRecForPHI. Reaching it here
5546 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5547 // because one of the other operands of the SCEVAddExpr updating this PHI is
5548 // not invariant).
5549 //
5550 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5551 // this case predicates that allow us to prove that Op == SymbolicPHI will
5552 // be added.
5553 if (Op == SymbolicPHI)
5554 return nullptr;
5555
5556 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5557 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5558 if (SourceBits != NewBits)
5559 return nullptr;
5560
5561 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5562 Signed = true;
5563 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5564 }
5565 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5566 Signed = false;
5567 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5568 }
5569 return nullptr;
5570}
5571
5572static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5573 if (!PN->getType()->isIntegerTy())
5574 return nullptr;
5575 const Loop *L = LI.getLoopFor(PN->getParent());
5576 if (!L || L->getHeader() != PN->getParent())
5577 return nullptr;
5578 return L;
5579}
5580
5581// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5582// computation that updates the phi follows the following pattern:
5583// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5584// which corresponds to a phi->trunc->sext/zext->add->phi update chain.
5585// If so, try to see if it can be rewritten as an AddRecExpr under some
5586// Predicates. If successful, return them as a pair. Also cache the results
5587// of the analysis.
5588//
5589// Example usage scenario:
5590// Say the Rewriter is called for the following SCEV:
5591// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5592// where:
5593// %X = phi i64 (%Start, %BEValue)
5594// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5595// and call this function with %SymbolicPHI = %X.
5596//
5597// The analysis will find that the value coming around the backedge has
5598// the following SCEV:
5599// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5600// Upon concluding that this matches the desired pattern, the function
5601// will return the pair {NewAddRec, SmallPredsVec} where:
5602// NewAddRec = {%Start,+,%Step}
5603// SmallPredsVec = {P1, P2, P3} as follows:
5604// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5605// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5606// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5607// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5608// under the predicates {P1,P2,P3}.
5609// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5610// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5611//
5612// TODO's:
5613//
5614// 1) Extend the Induction descriptor to also support inductions that involve
5615// casts: When needed (namely, when we are called in the context of the
5616// vectorizer induction analysis), a Set of cast instructions will be
5617// populated by this method, and provided back to isInductionPHI. This is
5618// needed to allow the vectorizer to properly record them to be ignored by
5619// the cost model and to avoid vectorizing them (otherwise these casts,
5620// which are redundant under the runtime overflow checks, will be
5621// vectorized, which can be costly).
5622//
5623// 2) Support additional induction/PHISCEV patterns: We also want to support
5624// inductions where the sext-trunc / zext-trunc operations (partly) occur
5625// after the induction update operation (the induction increment):
5626//
5627// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5628// which corresponds to a phi->add->trunc->sext/zext->phi update chain.
5629//
5630// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5631// which corresponds to a phi->trunc->add->sext/zext->phi update chain.
5632//
5633// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5634std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5635ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5637
5638 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5639 // return an AddRec expression under some predicate.
5640
5641 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5642 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5643 assert(L && "Expecting an integer loop header phi");
5644
5645 // The loop may have multiple entrances or multiple exits; we can analyze
5646 // this phi as an addrec if it has a unique entry value and a unique
5647 // backedge value.
5648 Value *BEValueV = nullptr, *StartValueV = nullptr;
5649 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5650 Value *V = PN->getIncomingValue(i);
5651 if (L->contains(PN->getIncomingBlock(i))) {
5652 if (!BEValueV) {
5653 BEValueV = V;
5654 } else if (BEValueV != V) {
5655 BEValueV = nullptr;
5656 break;
5657 }
5658 } else if (!StartValueV) {
5659 StartValueV = V;
5660 } else if (StartValueV != V) {
5661 StartValueV = nullptr;
5662 break;
5663 }
5664 }
5665 if (!BEValueV || !StartValueV)
5666 return std::nullopt;
5667
5668 const SCEV *BEValue = getSCEV(BEValueV);
5669
5670 // If the value coming around the backedge is an add with the symbolic
5671 // value we just inserted, possibly with casts that we can ignore under
5672 // an appropriate runtime guard, then we found a simple induction variable!
5673 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5674 if (!Add)
5675 return std::nullopt;
5676
5677 // If there is a single occurrence of the symbolic value, possibly
5678 // casted, replace it with a recurrence.
5679 unsigned FoundIndex = Add->getNumOperands();
5680 Type *TruncTy = nullptr;
5681 bool Signed;
5682 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5683 if ((TruncTy =
5684 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5685 if (FoundIndex == e) {
5686 FoundIndex = i;
5687 break;
5688 }
5689
5690 if (FoundIndex == Add->getNumOperands())
5691 return std::nullopt;
5692
5693 // Create an add with everything but the specified operand.
5695 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5696 if (i != FoundIndex)
5697 Ops.push_back(Add->getOperand(i));
5698 const SCEV *Accum = getAddExpr(Ops);
5699
5700 // The runtime checks will not be valid if the step amount is
5701 // varying inside the loop.
5702 if (!isLoopInvariant(Accum, L))
5703 return std::nullopt;
5704
5705 // *** Part2: Create the predicates
5706
5707 // Analysis was successful: we have a phi-with-cast pattern for which we
5708 // can return an AddRec expression under the following predicates:
5709 //
5710 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5711 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5712 // P2: An Equal predicate that guarantees that
5713 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5714 // P3: An Equal predicate that guarantees that
5715 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5716 //
5717 // As we next prove, the above predicates guarantee that:
5718 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5719 //
5720 //
5721 // More formally, we want to prove that:
5722 // Expr(i+1) = Start + (i+1) * Accum
5723 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5724 //
5725 // Given that:
5726 // 1) Expr(0) = Start
5727 // 2) Expr(1) = Start + Accum
5728 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5729 // 3) Induction hypothesis (step i):
5730 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5731 //
5732 // Proof:
5733 // Expr(i+1) =
5734 // = Start + (i+1)*Accum
5735 // = (Start + i*Accum) + Accum
5736 // = Expr(i) + Accum
5737 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5738 // :: from step i
5739 //
5740 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5741 //
5742 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5743 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5744 // + Accum :: from P3
5745 //
5746 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5747 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5748 //
5749 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5750 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5751 //
5752 // By induction, the same applies to all iterations 1<=i<n:
5753 //
5754
5755 // Create a truncated addrec for which we will add a no overflow check (P1).
5756 const SCEV *StartVal = getSCEV(StartValueV);
5757 const SCEV *PHISCEV =
5758 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5759 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5760
5761 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5762 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5763 // will be constant.
5764 //
5765 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5766 // add P1.
5767 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5771 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5772 Predicates.push_back(AddRecPred);
5773 }
5774
5775 // Create the Equal Predicates P2,P3:
5776
5777 // It is possible that the predicates P2 and/or P3 are computable at
5778 // compile time due to StartVal and/or Accum being constants.
5779 // If either one is, then we can check that now and escape if either P2
5780 // or P3 is false.
5781
5782 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5783 // for each of StartVal and Accum
5784 auto getExtendedExpr = [&](const SCEV *Expr,
5785 bool CreateSignExtend) -> const SCEV * {
5786 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5787 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5788 const SCEV *ExtendedExpr =
5789 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5790 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5791 return ExtendedExpr;
5792 };
5793
5794 // Given:
5795 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5796 // = getExtendedExpr(Expr)
5797 // Determine whether the predicate P: Expr == ExtendedExpr
5798 // is known to be false at compile time
5799 auto PredIsKnownFalse = [&](const SCEV *Expr,
5800 const SCEV *ExtendedExpr) -> bool {
5801 return Expr != ExtendedExpr &&
5802 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5803 };
5804
5805 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5806 if (PredIsKnownFalse(StartVal, StartExtended)) {
5807 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5808 return std::nullopt;
5809 }
5810
5811 // The Step is always Signed (because the overflow checks are either
5812 // NSSW or NUSW)
5813 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5814 if (PredIsKnownFalse(Accum, AccumExtended)) {
5815 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5816 return std::nullopt;
5817 }
5818
5819 auto AppendPredicate = [&](const SCEV *Expr,
5820 const SCEV *ExtendedExpr) -> void {
5821 if (Expr != ExtendedExpr &&
5822 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5823 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5824 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5825 Predicates.push_back(Pred);
5826 }
5827 };
5828
5829 AppendPredicate(StartVal, StartExtended);
5830 AppendPredicate(Accum, AccumExtended);
5831
5832 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5833 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5834 // into NewAR if it will also add the runtime overflow checks specified in
5835 // Predicates.
5836 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5837
5838 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5839 std::make_pair(NewAR, Predicates);
5840 // Remember the result of the analysis for this SCEV at this location.
5841 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5842 return PredRewrite;
5843}
5844
5845std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5847 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5848 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5849 if (!L)
5850 return std::nullopt;
5851
5852 // Check to see if we already analyzed this PHI.
5853 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5854 if (I != PredicatedSCEVRewrites.end()) {
5855 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5856 I->second;
5857 // Analysis was done before and failed to create an AddRec:
5858 if (Rewrite.first == SymbolicPHI)
5859 return std::nullopt;
5860 // Analysis was done before and succeeded to create an AddRec under
5861 // a predicate:
5862 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5863 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5864 return Rewrite;
5865 }
5866
5867 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5868 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5869
5870 // Record in the cache that the analysis failed
5871 if (!Rewrite) {
5873 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5874 return std::nullopt;
5875 }
5876
5877 return Rewrite;
5878}
5879
5880// FIXME: This utility is currently required because the Rewriter currently
5881// does not rewrite this expression:
5882// {0, +, (sext ix (trunc iy to ix) to iy)}
5883// into {0, +, %step},
5884// even when the following Equal predicate exists:
5885// "%step == (sext ix (trunc iy to ix) to iy)".
5887 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5888 if (AR1 == AR2)
5889 return true;
5890
5891 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5892 if (Expr1 != Expr2 &&
5893 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5894 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5895 return false;
5896 return true;
5897 };
5898
5899 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5900 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5901 return false;
5902 return true;
5903}
5904
5905/// A helper function for createAddRecFromPHI to handle simple cases.
5906///
5907/// This function tries to find an AddRec expression for the simplest (yet most
5908/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5909/// If it fails, createAddRecFromPHI will use a more general, but slow,
5910/// technique for finding the AddRec expression.
5911const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5912 Value *BEValueV,
5913 Value *StartValueV) {
5914 const Loop *L = LI.getLoopFor(PN->getParent());
5915 assert(L && L->getHeader() == PN->getParent());
5916 assert(BEValueV && StartValueV);
5917
5918 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5919 if (!BO)
5920 return nullptr;
5921
5922 if (BO->Opcode != Instruction::Add)
5923 return nullptr;
5924
5925 const SCEV *Accum = nullptr;
5926 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5927 Accum = getSCEV(BO->RHS);
5928 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5929 Accum = getSCEV(BO->LHS);
5930
5931 if (!Accum)
5932 return nullptr;
5933
5935 if (BO->IsNUW)
5936 Flags = setFlags(Flags, SCEV::FlagNUW);
5937 if (BO->IsNSW)
5938 Flags = setFlags(Flags, SCEV::FlagNSW);
5939
5940 const SCEV *StartVal = getSCEV(StartValueV);
5941 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5942 insertValueToMap(PN, PHISCEV);
5943
5944 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5945 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5946 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
5947 }
5948
5949 // We can add Flags to the post-inc expression only if we
5950 // know that it is *undefined behavior* for BEValueV to
5951 // overflow.
5952 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5953 assert(isLoopInvariant(Accum, L) &&
5954 "Accum is defined outside L, but is not invariant?");
5955 if (isAddRecNeverPoison(BEInst, L))
5956 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5957 }
5958
5959 return PHISCEV;
5960}
5961
5962const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5963 const Loop *L = LI.getLoopFor(PN->getParent());
5964 if (!L || L->getHeader() != PN->getParent())
5965 return nullptr;
5966
5967 // The loop may have multiple entrances or multiple exits; we can analyze
5968 // this phi as an addrec if it has a unique entry value and a unique
5969 // backedge value.
5970 Value *BEValueV = nullptr, *StartValueV = nullptr;
5971 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5972 Value *V = PN->getIncomingValue(i);
5973 if (L->contains(PN->getIncomingBlock(i))) {
5974 if (!BEValueV) {
5975 BEValueV = V;
5976 } else if (BEValueV != V) {
5977 BEValueV = nullptr;
5978 break;
5979 }
5980 } else if (!StartValueV) {
5981 StartValueV = V;
5982 } else if (StartValueV != V) {
5983 StartValueV = nullptr;
5984 break;
5985 }
5986 }
5987 if (!BEValueV || !StartValueV)
5988 return nullptr;
5989
5990 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5991 "PHI node already processed?");
5992
5993 // First, try to find AddRec expression without creating a fictitious symbolic
5994 // value for PN.
5995 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5996 return S;
5997
5998 // Handle PHI node value symbolically.
5999 const SCEV *SymbolicName = getUnknown(PN);
6000 insertValueToMap(PN, SymbolicName);
6001
6002 // Using this symbolic name for the PHI, analyze the value coming around
6003 // the back-edge.
6004 const SCEV *BEValue = getSCEV(BEValueV);
6005
6006 // NOTE: If BEValue is loop invariant, we know that the PHI node just
6007 // has a special value for the first iteration of the loop.
6008
6009 // If the value coming around the backedge is an add with the symbolic
6010 // value we just inserted, then we found a simple induction variable!
6011 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
6012 // If there is a single occurrence of the symbolic value, replace it
6013 // with a recurrence.
6014 unsigned FoundIndex = Add->getNumOperands();
6015 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6016 if (Add->getOperand(i) == SymbolicName)
6017 if (FoundIndex == e) {
6018 FoundIndex = i;
6019 break;
6020 }
6021
6022 if (FoundIndex != Add->getNumOperands()) {
6023 // Create an add with everything but the specified operand.
6025 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6026 if (i != FoundIndex)
6027 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6028 L, *this));
6029 const SCEV *Accum = getAddExpr(Ops);
6030
6031 // This is not a valid addrec if the step amount is varying each
6032 // loop iteration, but is not itself an addrec in this loop.
6033 if (isLoopInvariant(Accum, L) ||
6034 (isa<SCEVAddRecExpr>(Accum) &&
6035 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6037
6038 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6039 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6040 if (BO->IsNUW)
6041 Flags = setFlags(Flags, SCEV::FlagNUW);
6042 if (BO->IsNSW)
6043 Flags = setFlags(Flags, SCEV::FlagNSW);
6044 }
6045 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6046 if (GEP->getOperand(0) == PN) {
6047 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6048 // If the increment has any nowrap flags, then we know the address
6049 // space cannot be wrapped around.
6050 if (NW != GEPNoWrapFlags::none())
6051 Flags = setFlags(Flags, SCEV::FlagNW);
6052 // If the GEP is nuw or nusw with non-negative offset, we know that
6053 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6054 // offset is treated as signed, while the base is unsigned.
6055 if (NW.hasNoUnsignedWrap() ||
6057 Flags = setFlags(Flags, SCEV::FlagNUW);
6058 }
6059
6060 // We cannot transfer nuw and nsw flags from subtraction
6061 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6062 // for instance.
6063 }
6064
6065 const SCEV *StartVal = getSCEV(StartValueV);
6066 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6067
6068 // Okay, for the entire analysis of this edge we assumed the PHI
6069 // to be symbolic. We now need to go back and purge all of the
6070 // entries for the scalars that use the symbolic expression.
6071 forgetMemoizedResults({SymbolicName});
6072 insertValueToMap(PN, PHISCEV);
6073
6074 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
6076 const_cast<SCEVAddRecExpr *>(AR),
6077 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
6078 }
6079
6080 // We can add Flags to the post-inc expression only if we
6081 // know that it is *undefined behavior* for BEValueV to
6082 // overflow.
6083 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6084 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6085 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6086
6087 return PHISCEV;
6088 }
6089 }
6090 } else {
6091 // Otherwise, this could be a loop like this:
6092 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6093 // In this case, j = {1,+,1} and BEValue is j.
6094 // Because the other in-value of i (0) fits the evolution of BEValue
6095 // i really is an addrec evolution.
6096 //
6097 // We can generalize this saying that i is the shifted value of BEValue
6098 // by one iteration:
6099 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6100
6101 // Do not allow refinement in rewriting of BEValue.
6102 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6103 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6104 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6105 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6106 const SCEV *StartVal = getSCEV(StartValueV);
6107 if (Start == StartVal) {
6108 // Okay, for the entire analysis of this edge we assumed the PHI
6109 // to be symbolic. We now need to go back and purge all of the
6110 // entries for the scalars that use the symbolic expression.
6111 forgetMemoizedResults({SymbolicName});
6112 insertValueToMap(PN, Shifted);
6113 return Shifted;
6114 }
6115 }
6116 }
6117
6118 // Remove the temporary PHI node SCEV that has been inserted while intending
6119 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6120 // as it will prevent later (possibly simpler) SCEV expressions to be added
6121 // to the ValueExprMap.
6122 eraseValueFromMap(PN);
6123
6124 return nullptr;
6125}
6126
6127// Try to match a control flow sequence that branches out at BI and merges back
6128// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6129// match.
6131 Value *&C, Value *&LHS, Value *&RHS) {
6132 C = BI->getCondition();
6133
6134 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6135 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6136
6137 Use &LeftUse = Merge->getOperandUse(0);
6138 Use &RightUse = Merge->getOperandUse(1);
6139
6140 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6141 LHS = LeftUse;
6142 RHS = RightUse;
6143 return true;
6144 }
6145
6146 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6147 LHS = RightUse;
6148 RHS = LeftUse;
6149 return true;
6150 }
6151
6152 return false;
6153}
6154
6156 Value *&Cond, Value *&LHS,
6157 Value *&RHS) {
6158 auto IsReachable =
6159 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6160 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6161 // Try to match
6162 //
6163 // br %cond, label %left, label %right
6164 // left:
6165 // br label %merge
6166 // right:
6167 // br label %merge
6168 // merge:
6169 // V = phi [ %x, %left ], [ %y, %right ]
6170 //
6171 // as "select %cond, %x, %y"
6172
6173 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6174 assert(IDom && "At least the entry block should dominate PN");
6175
6176 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6177 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6178 }
6179 return false;
6180}
6181
6182const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6183 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6184 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6187 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6188
6189 return nullptr;
6190}
6191
6193 BinaryOperator *CommonInst = nullptr;
6194 // Check if instructions are identical.
6195 for (Value *Incoming : PN->incoming_values()) {
6196 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6197 if (!IncomingInst)
6198 return nullptr;
6199 if (CommonInst) {
6200 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6201 return nullptr; // Not identical, give up
6202 } else {
6203 // Remember binary operator
6204 CommonInst = IncomingInst;
6205 }
6206 }
6207 return CommonInst;
6208}
6209
6210/// Returns SCEV for the first operand of a phi if all phi operands have
6211/// identical opcodes and operands
6212/// eg.
6213/// a: %add = %a + %b
6214/// br %c
6215/// b: %add1 = %a + %b
6216/// br %c
6217/// c: %phi = phi [%add, a], [%add1, b]
6218/// scev(%phi) => scev(%add)
6219const SCEV *
6220ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6221 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6222 if (!CommonInst)
6223 return nullptr;
6224
6225 // Check if SCEV exprs for instructions are identical.
6226 const SCEV *CommonSCEV = getSCEV(CommonInst);
6227 bool SCEVExprsIdentical =
6229 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6230 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6231}
6232
6233const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6234 if (const SCEV *S = createAddRecFromPHI(PN))
6235 return S;
6236
6237 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6238 // phi node for X.
6239 if (Value *V = simplifyInstruction(
6240 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6241 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6242 return getSCEV(V);
6243
6244 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6245 return S;
6246
6247 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6248 return S;
6249
6250 // If it's not a loop phi, we can't handle it yet.
6251 return getUnknown(PN);
6252}
6253
6254bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6255 SCEVTypes RootKind) {
6256 struct FindClosure {
6257 const SCEV *OperandToFind;
6258 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6259 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6260
6261 bool Found = false;
6262
6263 bool canRecurseInto(SCEVTypes Kind) const {
6264 // We can only recurse into the SCEV expression of the same effective type
6265 // as the type of our root SCEV expression, and into zero-extensions.
6266 return RootKind == Kind || NonSequentialRootKind == Kind ||
6267 scZeroExtend == Kind;
6268 };
6269
6270 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6271 : OperandToFind(OperandToFind), RootKind(RootKind),
6272 NonSequentialRootKind(
6274 RootKind)) {}
6275
6276 bool follow(const SCEV *S) {
6277 Found = S == OperandToFind;
6278
6279 return !isDone() && canRecurseInto(S->getSCEVType());
6280 }
6281
6282 bool isDone() const { return Found; }
6283 };
6284
6285 FindClosure FC(OperandToFind, RootKind);
6286 visitAll(Root, FC);
6287 return FC.Found;
6288}
6289
6290std::optional<const SCEV *>
6291ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6292 ICmpInst *Cond,
6293 Value *TrueVal,
6294 Value *FalseVal) {
6295 // Try to match some simple smax or umax patterns.
6296 auto *ICI = Cond;
6297
6298 Value *LHS = ICI->getOperand(0);
6299 Value *RHS = ICI->getOperand(1);
6300
6301 switch (ICI->getPredicate()) {
6302 case ICmpInst::ICMP_SLT:
6303 case ICmpInst::ICMP_SLE:
6304 case ICmpInst::ICMP_ULT:
6305 case ICmpInst::ICMP_ULE:
6306 std::swap(LHS, RHS);
6307 [[fallthrough]];
6308 case ICmpInst::ICMP_SGT:
6309 case ICmpInst::ICMP_SGE:
6310 case ICmpInst::ICMP_UGT:
6311 case ICmpInst::ICMP_UGE:
6312 // a > b ? a+x : b+x -> max(a, b)+x
6313 // a > b ? b+x : a+x -> min(a, b)+x
6315 bool Signed = ICI->isSigned();
6316 const SCEV *LA = getSCEV(TrueVal);
6317 const SCEV *RA = getSCEV(FalseVal);
6318 const SCEV *LS = getSCEV(LHS);
6319 const SCEV *RS = getSCEV(RHS);
6320 if (LA->getType()->isPointerTy()) {
6321 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6322 // Need to make sure we can't produce weird expressions involving
6323 // negated pointers.
6324 if (LA == LS && RA == RS)
6325 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6326 if (LA == RS && RA == LS)
6327 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6328 }
6329 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6330 if (Op->getType()->isPointerTy()) {
6333 return Op;
6334 }
6335 if (Signed)
6336 Op = getNoopOrSignExtend(Op, Ty);
6337 else
6338 Op = getNoopOrZeroExtend(Op, Ty);
6339 return Op;
6340 };
6341 LS = CoerceOperand(LS);
6342 RS = CoerceOperand(RS);
6344 break;
6345 const SCEV *LDiff = getMinusSCEV(LA, LS);
6346 const SCEV *RDiff = getMinusSCEV(RA, RS);
6347 if (LDiff == RDiff)
6348 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6349 LDiff);
6350 LDiff = getMinusSCEV(LA, RS);
6351 RDiff = getMinusSCEV(RA, LS);
6352 if (LDiff == RDiff)
6353 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6354 LDiff);
6355 }
6356 break;
6357 case ICmpInst::ICMP_NE:
6358 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6359 std::swap(TrueVal, FalseVal);
6360 [[fallthrough]];
6361 case ICmpInst::ICMP_EQ:
6362 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6365 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6366 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6367 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6368 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6369 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6370 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6371 return getAddExpr(getUMaxExpr(X, C), Y);
6372 }
6373 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6374 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6375 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6376 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6378 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6379 const SCEV *X = getSCEV(LHS);
6380 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6381 X = ZExt->getOperand();
6382 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6383 const SCEV *FalseValExpr = getSCEV(FalseVal);
6384 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6385 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6386 /*Sequential=*/true);
6387 }
6388 }
6389 break;
6390 default:
6391 break;
6392 }
6393
6394 return std::nullopt;
6395}
6396
6397static std::optional<const SCEV *>
6399 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6400 assert(CondExpr->getType()->isIntegerTy(1) &&
6401 TrueExpr->getType() == FalseExpr->getType() &&
6402 TrueExpr->getType()->isIntegerTy(1) &&
6403 "Unexpected operands of a select.");
6404
6405 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6406 // --> C + (umin_seq cond, x - C)
6407 //
6408 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6409 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6410 // --> C + (umin_seq ~cond, x - C)
6411
6412 // FIXME: while we can't legally model the case where both of the hands
6413 // are fully variable, we only require that the *difference* is constant.
6414 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6415 return std::nullopt;
6416
6417 const SCEV *X, *C;
6418 if (isa<SCEVConstant>(TrueExpr)) {
6419 CondExpr = SE->getNotSCEV(CondExpr);
6420 X = FalseExpr;
6421 C = TrueExpr;
6422 } else {
6423 X = TrueExpr;
6424 C = FalseExpr;
6425 }
6426 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6427 /*Sequential=*/true));
6428}
6429
6430static std::optional<const SCEV *>
6432 Value *FalseVal) {
6433 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6434 return std::nullopt;
6435
6436 const auto *SECond = SE->getSCEV(Cond);
6437 const auto *SETrue = SE->getSCEV(TrueVal);
6438 const auto *SEFalse = SE->getSCEV(FalseVal);
6439 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6440}
6441
6442const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6443 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6444 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6445 assert(TrueVal->getType() == FalseVal->getType() &&
6446 V->getType() == TrueVal->getType() &&
6447 "Types of select hands and of the result must match.");
6448
6449 // For now, only deal with i1-typed `select`s.
6450 if (!V->getType()->isIntegerTy(1))
6451 return getUnknown(V);
6452
6453 if (std::optional<const SCEV *> S =
6454 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6455 return *S;
6456
6457 return getUnknown(V);
6458}
6459
6460const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6461 Value *TrueVal,
6462 Value *FalseVal) {
6463 // Handle "constant" branch or select. This can occur for instance when a
6464 // loop pass transforms an inner loop and moves on to process the outer loop.
6465 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6466 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6467
6468 if (auto *I = dyn_cast<Instruction>(V)) {
6469 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6470 if (std::optional<const SCEV *> S =
6471 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6472 TrueVal, FalseVal))
6473 return *S;
6474 }
6475 }
6476
6477 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6478}
6479
6480/// Expand GEP instructions into add and multiply operations. This allows them
6481/// to be analyzed by regular SCEV code.
6482const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6483 assert(GEP->getSourceElementType()->isSized() &&
6484 "GEP source element type must be sized");
6485
6486 SmallVector<SCEVUse, 4> IndexExprs;
6487 for (Value *Index : GEP->indices())
6488 IndexExprs.push_back(getSCEV(Index));
6489 return getGEPExpr(GEP, IndexExprs);
6490}
6491
6492APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6493 const Instruction *CtxI) {
6494 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6495 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6496 return TrailingZeros >= BitWidth
6498 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6499 };
6500 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6501 // The result is GCD of all operands results.
6502 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6503 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6505 Res, getConstantMultiple(N->getOperand(I), CtxI));
6506 return Res;
6507 };
6508
6509 switch (S->getSCEVType()) {
6510 case scConstant:
6511 return cast<SCEVConstant>(S)->getAPInt();
6512 case scPtrToAddr:
6513 case scPtrToInt:
6514 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6515 case scUDivExpr:
6516 case scVScale:
6517 return APInt(BitWidth, 1);
6518 case scTruncate: {
6519 // Only multiples that are a power of 2 will hold after truncation.
6520 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6521 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6522 return GetShiftedByZeros(TZ);
6523 }
6524 case scZeroExtend: {
6525 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6526 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6527 }
6528 case scSignExtend: {
6529 // Only multiples that are a power of 2 will hold after sext.
6530 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6531 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6532 return GetShiftedByZeros(TZ);
6533 }
6534 case scMulExpr: {
6535 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6536 if (M->hasNoUnsignedWrap()) {
6537 // The result is the product of all operand results.
6538 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6539 for (const SCEV *Operand : M->operands().drop_front())
6540 Res = Res * getConstantMultiple(Operand, CtxI);
6541 return Res;
6542 }
6543
6544 // If there are no wrap guarentees, find the trailing zeros, which is the
6545 // sum of trailing zeros for all its operands.
6546 uint32_t TZ = 0;
6547 for (const SCEV *Operand : M->operands())
6548 TZ += getMinTrailingZeros(Operand, CtxI);
6549 return GetShiftedByZeros(TZ);
6550 }
6551 case scAddExpr:
6552 case scAddRecExpr: {
6553 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6554 if (N->hasNoUnsignedWrap())
6555 return GetGCDMultiple(N);
6556 // Find the trailing bits, which is the minimum of its operands.
6557 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6558 for (const SCEV *Operand : N->operands().drop_front())
6559 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6560 return GetShiftedByZeros(TZ);
6561 }
6562 case scUMaxExpr:
6563 case scSMaxExpr:
6564 case scUMinExpr:
6565 case scSMinExpr:
6567 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6568 case scUnknown: {
6569 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6570 // the point their underlying IR instruction has been defined. If CtxI was
6571 // not provided, use:
6572 // * the first instruction in the entry block if it is an argument
6573 // * the instruction itself otherwise.
6574 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6575 if (!CtxI) {
6576 if (isa<Argument>(U->getValue()))
6577 CtxI = &*F.getEntryBlock().begin();
6578 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6579 CtxI = I;
6580 }
6581 unsigned Known =
6582 computeKnownBits(U->getValue(),
6583 SimplifyQuery(getDataLayout(), &DT, &AC, CtxI)
6584 .allowEphemerals(true))
6585 .countMinTrailingZeros();
6586 return GetShiftedByZeros(Known);
6587 }
6588 case scCouldNotCompute:
6589 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6590 }
6591 llvm_unreachable("Unknown SCEV kind!");
6592}
6593
6595 const Instruction *CtxI) {
6596 // Skip looking up and updating the cache if there is a context instruction,
6597 // as the result will only be valid in the specified context.
6598 if (CtxI)
6599 return getConstantMultipleImpl(S, CtxI);
6600
6601 auto I = ConstantMultipleCache.find(S);
6602 if (I != ConstantMultipleCache.end())
6603 return I->second;
6604
6605 APInt Result = getConstantMultipleImpl(S, CtxI);
6606 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6607 assert(InsertPair.second && "Should insert a new key");
6608 return InsertPair.first->second;
6609}
6610
6612 APInt Multiple = getConstantMultiple(S);
6613 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6614}
6615
6617 const Instruction *CtxI) {
6618 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6619 (unsigned)getTypeSizeInBits(S->getType()));
6620}
6621
6622/// Helper method to assign a range to V from metadata present in the IR.
6623static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6625 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6626 return getConstantRangeFromMetadata(*MD);
6627 if (const auto *CB = dyn_cast<CallBase>(V))
6628 if (std::optional<ConstantRange> Range = CB->getRange())
6629 return Range;
6630 }
6631 if (auto *A = dyn_cast<Argument>(V))
6632 if (std::optional<ConstantRange> Range = A->getRange())
6633 return Range;
6634
6635 return std::nullopt;
6636}
6637
6639 SCEV::NoWrapFlags Flags) {
6640 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6641 AddRec->setNoWrapFlags(Flags);
6642 UnsignedRanges.erase(AddRec);
6643 SignedRanges.erase(AddRec);
6644 ConstantMultipleCache.erase(AddRec);
6645 }
6646}
6647
6648ConstantRange ScalarEvolution::
6649getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6650 const DataLayout &DL = getDataLayout();
6651
6652 unsigned BitWidth = getTypeSizeInBits(U->getType());
6653 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6654
6655 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6656 // use information about the trip count to improve our available range. Note
6657 // that the trip count independent cases are already handled by known bits.
6658 // WARNING: The definition of recurrence used here is subtly different than
6659 // the one used by AddRec (and thus most of this file). Step is allowed to
6660 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6661 // and other addrecs in the same loop (for non-affine addrecs). The code
6662 // below intentionally handles the case where step is not loop invariant.
6663 auto *P = dyn_cast<PHINode>(U->getValue());
6664 if (!P)
6665 return FullSet;
6666
6667 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6668 // even the values that are not available in these blocks may come from them,
6669 // and this leads to false-positive recurrence test.
6670 for (auto *Pred : predecessors(P->getParent()))
6671 if (!DT.isReachableFromEntry(Pred))
6672 return FullSet;
6673
6674 BinaryOperator *BO;
6675 Value *Start, *Step;
6676 if (!matchSimpleRecurrence(P, BO, Start, Step))
6677 return FullSet;
6678
6679 // If we found a recurrence in reachable code, we must be in a loop. Note
6680 // that BO might be in some subloop of L, and that's completely okay.
6681 auto *L = LI.getLoopFor(P->getParent());
6682 assert(L && L->getHeader() == P->getParent());
6683 if (!L->contains(BO->getParent()))
6684 // NOTE: This bailout should be an assert instead. However, asserting
6685 // the condition here exposes a case where LoopFusion is querying SCEV
6686 // with malformed loop information during the midst of the transform.
6687 // There doesn't appear to be an obvious fix, so for the moment bailout
6688 // until the caller issue can be fixed. PR49566 tracks the bug.
6689 return FullSet;
6690
6691 // TODO: Extend to other opcodes such as mul, and div
6692 switch (BO->getOpcode()) {
6693 default:
6694 return FullSet;
6695 case Instruction::AShr:
6696 case Instruction::LShr:
6697 case Instruction::Shl:
6698 break;
6699 };
6700
6701 if (BO->getOperand(0) != P)
6702 // TODO: Handle the power function forms some day.
6703 return FullSet;
6704
6705 unsigned TC = getSmallConstantMaxTripCount(L);
6706 if (!TC || TC >= BitWidth)
6707 return FullSet;
6708
6709 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6710 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6711 assert(KnownStart.getBitWidth() == BitWidth &&
6712 KnownStep.getBitWidth() == BitWidth);
6713
6714 // Compute total shift amount, being careful of overflow and bitwidths.
6715 auto MaxShiftAmt = KnownStep.getMaxValue();
6716 APInt TCAP(BitWidth, TC-1);
6717 bool Overflow = false;
6718 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6719 if (Overflow)
6720 return FullSet;
6721
6722 switch (BO->getOpcode()) {
6723 default:
6724 llvm_unreachable("filtered out above");
6725 case Instruction::AShr: {
6726 // For each ashr, three cases:
6727 // shift = 0 => unchanged value
6728 // saturation => 0 or -1
6729 // other => a value closer to zero (of the same sign)
6730 // Thus, the end value is closer to zero than the start.
6731 auto KnownEnd = KnownBits::ashr(KnownStart,
6732 KnownBits::makeConstant(TotalShift));
6733 if (KnownStart.isNonNegative())
6734 // Analogous to lshr (simply not yet canonicalized)
6735 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6736 KnownStart.getMaxValue() + 1);
6737 if (KnownStart.isNegative())
6738 // End >=u Start && End <=s Start
6739 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6740 KnownEnd.getMaxValue() + 1);
6741 break;
6742 }
6743 case Instruction::LShr: {
6744 // For each lshr, three cases:
6745 // shift = 0 => unchanged value
6746 // saturation => 0
6747 // other => a smaller positive number
6748 // Thus, the low end of the unsigned range is the last value produced.
6749 auto KnownEnd = KnownBits::lshr(KnownStart,
6750 KnownBits::makeConstant(TotalShift));
6751 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6752 KnownStart.getMaxValue() + 1);
6753 }
6754 case Instruction::Shl: {
6755 // Iff no bits are shifted out, value increases on every shift.
6756 auto KnownEnd = KnownBits::shl(KnownStart,
6757 KnownBits::makeConstant(TotalShift));
6758 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6759 return ConstantRange(KnownStart.getMinValue(),
6760 KnownEnd.getMaxValue() + 1);
6761 break;
6762 }
6763 };
6764 return FullSet;
6765}
6766
6767// The goal of this function is to check if recursively visiting the operands
6768// of this PHI might lead to an infinite loop. If we do see such a loop,
6769// there's no good way to break it, so we avoid analyzing such cases.
6770//
6771// getRangeRef previously used a visited set to avoid infinite loops, but this
6772// caused other issues: the result was dependent on the order of getRangeRef
6773// calls, and the interaction with createSCEVIter could cause a stack overflow
6774// in some cases (see issue #148253).
6775//
6776// FIXME: The way this is implemented is overly conservative; this checks
6777// for a few obviously safe patterns, but anything that doesn't lead to
6778// recursion is fine.
6780 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6782 return true;
6783
6784 if (all_of(PHI->operands(),
6785 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6786 return true;
6787
6788 return false;
6789}
6790
6791const ConstantRange &
6792ScalarEvolution::getRangeRefIter(const SCEV *S,
6793 ScalarEvolution::RangeSignHint SignHint) {
6794 DenseMap<const SCEV *, ConstantRange> &Cache =
6795 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6796 : SignedRanges;
6797 SmallVector<SCEVUse> WorkList;
6798 SmallPtrSet<const SCEV *, 8> Seen;
6799
6800 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6801 // SCEVUnknown PHI node.
6802 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6803 if (!Seen.insert(Expr).second)
6804 return;
6805 if (Cache.contains(Expr))
6806 return;
6807 switch (Expr->getSCEVType()) {
6808 case scUnknown:
6810 break;
6811 [[fallthrough]];
6812 case scConstant:
6813 case scVScale:
6814 case scTruncate:
6815 case scZeroExtend:
6816 case scSignExtend:
6817 case scPtrToAddr:
6818 case scPtrToInt:
6819 case scAddExpr:
6820 case scMulExpr:
6821 case scUDivExpr:
6822 case scAddRecExpr:
6823 case scUMaxExpr:
6824 case scSMaxExpr:
6825 case scUMinExpr:
6826 case scSMinExpr:
6828 WorkList.push_back(Expr);
6829 break;
6830 case scCouldNotCompute:
6831 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6832 }
6833 };
6834 AddToWorklist(S);
6835
6836 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6837 for (unsigned I = 0; I != WorkList.size(); ++I) {
6838 const SCEV *P = WorkList[I];
6839 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6840 // If it is not a `SCEVUnknown`, just recurse into operands.
6841 if (!UnknownS) {
6842 for (const SCEV *Op : P->operands())
6843 AddToWorklist(Op);
6844 continue;
6845 }
6846 // `SCEVUnknown`'s require special treatment.
6847 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6848 if (!RangeRefPHIAllowedOperands(DT, P))
6849 continue;
6850 for (auto &Op : reverse(P->operands()))
6851 AddToWorklist(getSCEV(Op));
6852 }
6853 }
6854
6855 if (!WorkList.empty()) {
6856 // Use getRangeRef to compute ranges for items in the worklist in reverse
6857 // order. This will force ranges for earlier operands to be computed before
6858 // their users in most cases.
6859 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6860 getRangeRef(P, SignHint);
6861 }
6862 }
6863
6864 return getRangeRef(S, SignHint, 0);
6865}
6866
6867/// Determine the range for a particular SCEV. If SignHint is
6868/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6869/// with a "cleaner" unsigned (resp. signed) representation.
6870const ConstantRange &ScalarEvolution::getRangeRef(
6871 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6872 DenseMap<const SCEV *, ConstantRange> &Cache =
6873 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6874 : SignedRanges;
6876 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6878
6879 // See if we've computed this range already.
6880 auto I = Cache.find(S);
6881 if (I != Cache.end())
6882 return I->second;
6883
6884 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6885 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6886
6887 // Switch to iteratively computing the range for S, if it is part of a deeply
6888 // nested expression.
6890 return getRangeRefIter(S, SignHint);
6891
6892 unsigned BitWidth = getTypeSizeInBits(S->getType());
6893 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6894 using OBO = OverflowingBinaryOperator;
6895
6896 // If the value has known zeros, the maximum value will have those known zeros
6897 // as well.
6898 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6899 APInt Multiple = getNonZeroConstantMultiple(S);
6900 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6901 if (!Remainder.isZero())
6902 ConservativeResult =
6903 ConstantRange(APInt::getMinValue(BitWidth),
6904 APInt::getMaxValue(BitWidth) - Remainder + 1);
6905 }
6906 else {
6907 uint32_t TZ = getMinTrailingZeros(S);
6908 if (TZ != 0) {
6909 ConservativeResult = ConstantRange(
6911 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6912 }
6913 }
6914
6915 switch (S->getSCEVType()) {
6916 case scConstant:
6917 llvm_unreachable("Already handled above.");
6918 case scVScale:
6919 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6920 case scTruncate: {
6921 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6922 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6923 return setRange(
6924 Trunc, SignHint,
6925 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6926 }
6927 case scZeroExtend: {
6928 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6929 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6930 return setRange(
6931 ZExt, SignHint,
6932 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6933 }
6934 case scSignExtend: {
6935 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6936 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6937 return setRange(
6938 SExt, SignHint,
6939 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6940 }
6941 case scPtrToAddr:
6942 case scPtrToInt: {
6943 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6944 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6945 return setRange(Cast, SignHint, X);
6946 }
6947 case scAddExpr: {
6948 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6949 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6950 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6951 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6952 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6953 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6954 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6955 ConservativeResult =
6956 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6957 }
6958 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6959 unsigned WrapType = OBO::AnyWrap;
6960 if (Add->hasNoSignedWrap())
6961 WrapType |= OBO::NoSignedWrap;
6962 if (Add->hasNoUnsignedWrap())
6963 WrapType |= OBO::NoUnsignedWrap;
6964 for (const SCEV *Op : drop_begin(Add->operands()))
6965 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6966 RangeType);
6967 return setRange(Add, SignHint,
6968 ConservativeResult.intersectWith(X, RangeType));
6969 }
6970 case scMulExpr: {
6971 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6972 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6973 for (const SCEV *Op : drop_begin(Mul->operands()))
6974 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6975 return setRange(Mul, SignHint,
6976 ConservativeResult.intersectWith(X, RangeType));
6977 }
6978 case scUDivExpr: {
6979 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6980 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6981 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6982 return setRange(UDiv, SignHint,
6983 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6984 }
6985 case scAddRecExpr: {
6986 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6987 // If there's no unsigned wrap, the value will never be less than its
6988 // initial value.
6989 if (AddRec->hasNoUnsignedWrap()) {
6990 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6991 if (!UnsignedMinValue.isZero())
6992 ConservativeResult = ConservativeResult.intersectWith(
6993 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6994 }
6995
6996 // If there's no signed wrap, and all the operands except initial value have
6997 // the same sign or zero, the value won't ever be:
6998 // 1: smaller than initial value if operands are non negative,
6999 // 2: bigger than initial value if operands are non positive.
7000 // For both cases, value can not cross signed min/max boundary.
7001 if (AddRec->hasNoSignedWrap()) {
7002 bool AllNonNeg = true;
7003 bool AllNonPos = true;
7004 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
7005 if (!isKnownNonNegative(AddRec->getOperand(i)))
7006 AllNonNeg = false;
7007 if (!isKnownNonPositive(AddRec->getOperand(i)))
7008 AllNonPos = false;
7009 }
7010 if (AllNonNeg)
7011 ConservativeResult = ConservativeResult.intersectWith(
7014 RangeType);
7015 else if (AllNonPos)
7016 ConservativeResult = ConservativeResult.intersectWith(
7018 getSignedRangeMax(AddRec->getStart()) +
7019 1),
7020 RangeType);
7021 }
7022
7023 // TODO: non-affine addrec
7024 if (AddRec->isAffine()) {
7025 const SCEV *MaxBEScev =
7027 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7028 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7029
7030 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7031 // MaxBECount's active bits are all <= AddRec's bit width.
7032 if (MaxBECount.getBitWidth() > BitWidth &&
7033 MaxBECount.getActiveBits() <= BitWidth)
7034 MaxBECount = MaxBECount.trunc(BitWidth);
7035 else if (MaxBECount.getBitWidth() < BitWidth)
7036 MaxBECount = MaxBECount.zext(BitWidth);
7037
7038 if (MaxBECount.getBitWidth() == BitWidth) {
7039 auto RangeFromAffine = getRangeForAffineAR(
7040 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7041 ConservativeResult =
7042 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7043
7044 auto RangeFromFactoring = getRangeViaFactoring(
7045 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7046 ConservativeResult =
7047 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7048 }
7049 }
7050
7051 // Now try symbolic BE count and more powerful methods.
7053 const SCEV *SymbolicMaxBECount =
7055 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7056 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7057 AddRec->hasNoSelfWrap()) {
7058 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7059 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7060 ConservativeResult =
7061 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7062 }
7063 }
7064 }
7065
7066 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7067 }
7068 case scUMaxExpr:
7069 case scSMaxExpr:
7070 case scUMinExpr:
7071 case scSMinExpr:
7072 case scSequentialUMinExpr: {
7074 switch (S->getSCEVType()) {
7075 case scUMaxExpr:
7076 ID = Intrinsic::umax;
7077 break;
7078 case scSMaxExpr:
7079 ID = Intrinsic::smax;
7080 break;
7081 case scUMinExpr:
7083 ID = Intrinsic::umin;
7084 break;
7085 case scSMinExpr:
7086 ID = Intrinsic::smin;
7087 break;
7088 default:
7089 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7090 }
7091
7092 const auto *NAry = cast<SCEVNAryExpr>(S);
7093 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7094 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7095 X = X.intrinsic(
7096 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7097 return setRange(S, SignHint,
7098 ConservativeResult.intersectWith(X, RangeType));
7099 }
7100 case scUnknown: {
7101 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7102 Value *V = U->getValue();
7103
7104 // Check if the IR explicitly contains !range metadata.
7105 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7106 if (MDRange)
7107 ConservativeResult =
7108 ConservativeResult.intersectWith(*MDRange, RangeType);
7109
7110 // Use facts about recurrences in the underlying IR. Note that add
7111 // recurrences are AddRecExprs and thus don't hit this path. This
7112 // primarily handles shift recurrences.
7113 auto CR = getRangeForUnknownRecurrence(U);
7114 ConservativeResult = ConservativeResult.intersectWith(CR);
7115
7116 // See if ValueTracking can give us a useful range.
7117 const DataLayout &DL = getDataLayout();
7118 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7119 if (Known.getBitWidth() != BitWidth)
7120 Known = Known.zextOrTrunc(BitWidth);
7121
7122 // ValueTracking may be able to compute a tighter result for the number of
7123 // sign bits than for the value of those sign bits.
7124 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7125 if (U->getType()->isPointerTy()) {
7126 // If the pointer size is larger than the index size type, this can cause
7127 // NS to be larger than BitWidth. So compensate for this.
7128 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7129 int ptrIdxDiff = ptrSize - BitWidth;
7130 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7131 NS -= ptrIdxDiff;
7132 }
7133
7134 if (NS > 1) {
7135 // If we know any of the sign bits, we know all of the sign bits.
7136 if (!Known.Zero.getHiBits(NS).isZero())
7137 Known.Zero.setHighBits(NS);
7138 if (!Known.One.getHiBits(NS).isZero())
7139 Known.One.setHighBits(NS);
7140 }
7141
7142 if (Known.getMinValue() != Known.getMaxValue() + 1)
7143 ConservativeResult = ConservativeResult.intersectWith(
7144 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7145 RangeType);
7146 if (NS > 1)
7147 ConservativeResult = ConservativeResult.intersectWith(
7148 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7149 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7150 RangeType);
7151
7152 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7153 // Strengthen the range if the underlying IR value is a
7154 // global/alloca/heap allocation using the size of the object.
7155 bool CanBeNull, CanBeFreed;
7156 uint64_t DerefBytes =
7157 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7158 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7159 // The highest address the object can start is DerefBytes bytes before
7160 // the end (unsigned max value). If this value is not a multiple of the
7161 // alignment, the last possible start value is the next lowest multiple
7162 // of the alignment. Note: The computations below cannot overflow,
7163 // because if they would there's no possible start address for the
7164 // object.
7165 APInt MaxVal =
7166 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7167 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7168 uint64_t Rem = MaxVal.urem(Align);
7169 MaxVal -= APInt(BitWidth, Rem);
7170 APInt MinVal = APInt::getZero(BitWidth);
7171 if (llvm::isKnownNonZero(V, DL))
7172 MinVal = Align;
7173 ConservativeResult = ConservativeResult.intersectWith(
7174 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7175 }
7176 }
7177
7178 // A range of Phi is a subset of union of all ranges of its input.
7179 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7180 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7181 // AddRecs; return the range for the corresponding AddRec.
7182 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7183 return getRangeRef(AR, SignHint, Depth + 1);
7184
7185 // Make sure that we do not run over cycled Phis.
7186 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7187 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7188
7189 for (const auto &Op : Phi->operands()) {
7190 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7191 RangeFromOps = RangeFromOps.unionWith(OpRange);
7192 // No point to continue if we already have a full set.
7193 if (RangeFromOps.isFullSet())
7194 break;
7195 }
7196 ConservativeResult =
7197 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7198 }
7199 }
7200
7201 // vscale can't be equal to zero
7202 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7203 if (II->getIntrinsicID() == Intrinsic::vscale) {
7204 ConstantRange Disallowed = APInt::getZero(BitWidth);
7205 ConservativeResult = ConservativeResult.difference(Disallowed);
7206 }
7207
7208 return setRange(U, SignHint, std::move(ConservativeResult));
7209 }
7210 case scCouldNotCompute:
7211 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7212 }
7213
7214 return setRange(S, SignHint, std::move(ConservativeResult));
7215}
7216
7217// Given a StartRange, Step and MaxBECount for an expression compute a range of
7218// values that the expression can take. Initially, the expression has a value
7219// from StartRange and then is changed by Step up to MaxBECount times. Signed
7220// argument defines if we treat Step as signed or unsigned.
7222 const ConstantRange &StartRange,
7223 const APInt &MaxBECount,
7224 bool Signed) {
7225 unsigned BitWidth = Step.getBitWidth();
7226 assert(BitWidth == StartRange.getBitWidth() &&
7227 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7228 // If either Step or MaxBECount is 0, then the expression won't change, and we
7229 // just need to return the initial range.
7230 if (Step == 0 || MaxBECount == 0)
7231 return StartRange;
7232
7233 // If we don't know anything about the initial value (i.e. StartRange is
7234 // FullRange), then we don't know anything about the final range either.
7235 // Return FullRange.
7236 if (StartRange.isFullSet())
7237 return ConstantRange::getFull(BitWidth);
7238
7239 // If Step is signed and negative, then we use its absolute value, but we also
7240 // note that we're moving in the opposite direction.
7241 bool Descending = Signed && Step.isNegative();
7242
7243 if (Signed)
7244 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7245 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7246 // This equations hold true due to the well-defined wrap-around behavior of
7247 // APInt.
7248 Step = Step.abs();
7249
7250 // Check if Offset is more than full span of BitWidth. If it is, the
7251 // expression is guaranteed to overflow.
7252 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7253 return ConstantRange::getFull(BitWidth);
7254
7255 // Offset is by how much the expression can change. Checks above guarantee no
7256 // overflow here.
7257 APInt Offset = Step * MaxBECount;
7258
7259 // Minimum value of the final range will match the minimal value of StartRange
7260 // if the expression is increasing and will be decreased by Offset otherwise.
7261 // Maximum value of the final range will match the maximal value of StartRange
7262 // if the expression is decreasing and will be increased by Offset otherwise.
7263 APInt StartLower = StartRange.getLower();
7264 APInt StartUpper = StartRange.getUpper() - 1;
7265 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7266 : (StartUpper + std::move(Offset));
7267
7268 // It's possible that the new minimum/maximum value will fall into the initial
7269 // range (due to wrap around). This means that the expression can take any
7270 // value in this bitwidth, and we have to return full range.
7271 if (StartRange.contains(MovedBoundary))
7272 return ConstantRange::getFull(BitWidth);
7273
7274 APInt NewLower =
7275 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7276 APInt NewUpper =
7277 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7278 NewUpper += 1;
7279
7280 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7281 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7282}
7283
7284ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7285 const SCEV *Step,
7286 const APInt &MaxBECount) {
7287 assert(getTypeSizeInBits(Start->getType()) ==
7288 getTypeSizeInBits(Step->getType()) &&
7289 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7290 "mismatched bit widths");
7291
7292 // First, consider step signed.
7293 ConstantRange StartSRange = getSignedRange(Start);
7294 ConstantRange StepSRange = getSignedRange(Step);
7295
7296 // If Step can be both positive and negative, we need to find ranges for the
7297 // maximum absolute step values in both directions and union them.
7298 ConstantRange SR = getRangeForAffineARHelper(
7299 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7301 StartSRange, MaxBECount,
7302 /* Signed = */ true));
7303
7304 // Next, consider step unsigned.
7305 ConstantRange UR = getRangeForAffineARHelper(
7306 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7307 /* Signed = */ false);
7308
7309 // Finally, intersect signed and unsigned ranges.
7311}
7312
7313ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7314 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7315 ScalarEvolution::RangeSignHint SignHint) {
7316 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7317 assert(AddRec->hasNoSelfWrap() &&
7318 "This only works for non-self-wrapping AddRecs!");
7319 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7320 const SCEV *Step = AddRec->getStepRecurrence(*this);
7321 // Only deal with constant step to save compile time.
7322 if (!isa<SCEVConstant>(Step))
7323 return ConstantRange::getFull(BitWidth);
7324 // Let's make sure that we can prove that we do not self-wrap during
7325 // MaxBECount iterations. We need this because MaxBECount is a maximum
7326 // iteration count estimate, and we might infer nw from some exit for which we
7327 // do not know max exit count (or any other side reasoning).
7328 // TODO: Turn into assert at some point.
7329 if (getTypeSizeInBits(MaxBECount->getType()) >
7330 getTypeSizeInBits(AddRec->getType()))
7331 return ConstantRange::getFull(BitWidth);
7332 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7333 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7334 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7335 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7336 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7337 MaxItersWithoutWrap))
7338 return ConstantRange::getFull(BitWidth);
7339
7340 ICmpInst::Predicate LEPred =
7342 ICmpInst::Predicate GEPred =
7344 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7345
7346 // We know that there is no self-wrap. Let's take Start and End values and
7347 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7348 // the iteration. They either lie inside the range [Min(Start, End),
7349 // Max(Start, End)] or outside it:
7350 //
7351 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7352 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7353 //
7354 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7355 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7356 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7357 // Start <= End and step is positive, or Start >= End and step is negative.
7358 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7359 ConstantRange StartRange = getRangeRef(Start, SignHint);
7360 ConstantRange EndRange = getRangeRef(End, SignHint);
7361 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7362 // If they already cover full iteration space, we will know nothing useful
7363 // even if we prove what we want to prove.
7364 if (RangeBetween.isFullSet())
7365 return RangeBetween;
7366 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7367 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7368 : RangeBetween.isWrappedSet();
7369 if (IsWrappedSet)
7370 return ConstantRange::getFull(BitWidth);
7371
7372 if (isKnownPositive(Step) &&
7373 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7374 return RangeBetween;
7375 if (isKnownNegative(Step) &&
7376 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7377 return RangeBetween;
7378 return ConstantRange::getFull(BitWidth);
7379}
7380
7381ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7382 const SCEV *Step,
7383 const APInt &MaxBECount) {
7384 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7385 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7386
7387 unsigned BitWidth = MaxBECount.getBitWidth();
7388 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7389 getTypeSizeInBits(Step->getType()) == BitWidth &&
7390 "mismatched bit widths");
7391
7392 struct SelectPattern {
7393 Value *Condition = nullptr;
7394 APInt TrueValue;
7395 APInt FalseValue;
7396
7397 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7398 const SCEV *S) {
7399 std::optional<unsigned> CastOp;
7400 APInt Offset(BitWidth, 0);
7401
7403 "Should be!");
7404
7405 // Peel off a constant offset. In the future we could consider being
7406 // smarter here and handle {Start+Step,+,Step} too.
7407 const APInt *Off;
7408 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7409 Offset = *Off;
7410
7411 // Peel off a cast operation
7412 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7413 CastOp = SCast->getSCEVType();
7414 S = SCast->getOperand();
7415 }
7416
7417 using namespace llvm::PatternMatch;
7418
7419 auto *SU = dyn_cast<SCEVUnknown>(S);
7420 const APInt *TrueVal, *FalseVal;
7421 if (!SU ||
7422 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7423 m_APInt(FalseVal)))) {
7424 Condition = nullptr;
7425 return;
7426 }
7427
7428 TrueValue = *TrueVal;
7429 FalseValue = *FalseVal;
7430
7431 // Re-apply the cast we peeled off earlier
7432 if (CastOp)
7433 switch (*CastOp) {
7434 default:
7435 llvm_unreachable("Unknown SCEV cast type!");
7436
7437 case scTruncate:
7438 TrueValue = TrueValue.trunc(BitWidth);
7439 FalseValue = FalseValue.trunc(BitWidth);
7440 break;
7441 case scZeroExtend:
7442 TrueValue = TrueValue.zext(BitWidth);
7443 FalseValue = FalseValue.zext(BitWidth);
7444 break;
7445 case scSignExtend:
7446 TrueValue = TrueValue.sext(BitWidth);
7447 FalseValue = FalseValue.sext(BitWidth);
7448 break;
7449 }
7450
7451 // Re-apply the constant offset we peeled off earlier
7452 TrueValue += Offset;
7453 FalseValue += Offset;
7454 }
7455
7456 bool isRecognized() { return Condition != nullptr; }
7457 };
7458
7459 SelectPattern StartPattern(*this, BitWidth, Start);
7460 if (!StartPattern.isRecognized())
7461 return ConstantRange::getFull(BitWidth);
7462
7463 SelectPattern StepPattern(*this, BitWidth, Step);
7464 if (!StepPattern.isRecognized())
7465 return ConstantRange::getFull(BitWidth);
7466
7467 if (StartPattern.Condition != StepPattern.Condition) {
7468 // We don't handle this case today; but we could, by considering four
7469 // possibilities below instead of two. I'm not sure if there are cases where
7470 // that will help over what getRange already does, though.
7471 return ConstantRange::getFull(BitWidth);
7472 }
7473
7474 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7475 // construct arbitrary general SCEV expressions here. This function is called
7476 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7477 // say) can end up caching a suboptimal value.
7478
7479 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7480 // C2352 and C2512 (otherwise it isn't needed).
7481
7482 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7483 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7484 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7485 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7486
7487 ConstantRange TrueRange =
7488 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7489 ConstantRange FalseRange =
7490 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7491
7492 return TrueRange.unionWith(FalseRange);
7493}
7494
7495SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7496 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7497 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7498
7499 // Return early if there are no flags to propagate to the SCEV.
7501 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BinOp);
7502 PDI && PDI->isDisjoint()) {
7504 } else {
7505 if (BinOp->hasNoUnsignedWrap())
7507 if (BinOp->hasNoSignedWrap())
7509 }
7510 if (Flags == SCEV::FlagAnyWrap)
7511 return SCEV::FlagAnyWrap;
7512
7513 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7514}
7515
7516const Instruction *
7517ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7518 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7519 return &*AddRec->getLoop()->getHeader()->begin();
7520 if (auto *U = dyn_cast<SCEVUnknown>(S))
7521 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7522 return I;
7523 return nullptr;
7524}
7525
7526const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7527 bool &Precise) {
7528 Precise = true;
7529 // Do a bounded search of the def relation of the requested SCEVs.
7530 SmallPtrSet<const SCEV *, 16> Visited;
7531 SmallVector<SCEVUse> Worklist;
7532 auto pushOp = [&](const SCEV *S) {
7533 if (!Visited.insert(S).second)
7534 return;
7535 // Threshold of 30 here is arbitrary.
7536 if (Visited.size() > 30) {
7537 Precise = false;
7538 return;
7539 }
7540 Worklist.push_back(S);
7541 };
7542
7543 for (SCEVUse S : Ops)
7544 pushOp(S);
7545
7546 const Instruction *Bound = nullptr;
7547 while (!Worklist.empty()) {
7548 SCEVUse S = Worklist.pop_back_val();
7549 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7550 if (!Bound || DT.dominates(Bound, DefI))
7551 Bound = DefI;
7552 } else {
7553 for (SCEVUse Op : S->operands())
7554 pushOp(Op);
7555 }
7556 }
7557 return Bound ? Bound : &*F.getEntryBlock().begin();
7558}
7559
7560const Instruction *
7561ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7562 bool Discard;
7563 return getDefiningScopeBound(Ops, Discard);
7564}
7565
7566bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7567 const Instruction *B) {
7568 if (A->getParent() == B->getParent() &&
7570 B->getIterator()))
7571 return true;
7572
7573 auto *BLoop = LI.getLoopFor(B->getParent());
7574 if (BLoop && BLoop->getHeader() == B->getParent() &&
7575 BLoop->getLoopPreheader() == A->getParent() &&
7577 A->getParent()->end()) &&
7578 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7579 B->getIterator()))
7580 return true;
7581 return false;
7582}
7583
7584bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7585 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7586 visitAll(Op, PC);
7587 return PC.MaybePoison.empty();
7588}
7589
7590bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7591 return !SCEVExprContains(Op, [this](const SCEV *S) {
7592 const SCEV *Op1;
7593 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7594 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7595 // is a non-zero constant, we have to assume the UDiv may be UB.
7596 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7597 });
7598}
7599
7600bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7601 // Only proceed if we can prove that I does not yield poison.
7603 return false;
7604
7605 // At this point we know that if I is executed, then it does not wrap
7606 // according to at least one of NSW or NUW. If I is not executed, then we do
7607 // not know if the calculation that I represents would wrap. Multiple
7608 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7609 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7610 // derived from other instructions that map to the same SCEV. We cannot make
7611 // that guarantee for cases where I is not executed. So we need to find a
7612 // upper bound on the defining scope for the SCEV, and prove that I is
7613 // executed every time we enter that scope. When the bounding scope is a
7614 // loop (the common case), this is equivalent to proving I executes on every
7615 // iteration of that loop.
7616 SmallVector<SCEVUse> SCEVOps;
7617 for (const Use &Op : I->operands()) {
7618 // I could be an extractvalue from a call to an overflow intrinsic.
7619 // TODO: We can do better here in some cases.
7620 if (isSCEVable(Op->getType()))
7621 SCEVOps.push_back(getSCEV(Op));
7622 }
7623 auto *DefI = getDefiningScopeBound(SCEVOps);
7624 return isGuaranteedToTransferExecutionTo(DefI, I);
7625}
7626
7627bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7628 // If we know that \c I can never be poison period, then that's enough.
7629 if (isSCEVExprNeverPoison(I))
7630 return true;
7631
7632 // If the loop only has one exit, then we know that, if the loop is entered,
7633 // any instruction dominating that exit will be executed. If any such
7634 // instruction would result in UB, the addrec cannot be poison.
7635 //
7636 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7637 // also handles uses outside the loop header (they just need to dominate the
7638 // single exit).
7639
7640 auto *ExitingBB = L->getExitingBlock();
7641 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7642 return false;
7643
7644 SmallPtrSet<const Value *, 16> KnownPoison;
7646
7647 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7648 // things that are known to be poison under that assumption go on the
7649 // Worklist.
7650 KnownPoison.insert(I);
7651 Worklist.push_back(I);
7652
7653 while (!Worklist.empty()) {
7654 const Instruction *Poison = Worklist.pop_back_val();
7655
7656 for (const Use &U : Poison->uses()) {
7657 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7658 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7659 DT.dominates(PoisonUser->getParent(), ExitingBB))
7660 return true;
7661
7662 if (propagatesPoison(U) && L->contains(PoisonUser))
7663 if (KnownPoison.insert(PoisonUser).second)
7664 Worklist.push_back(PoisonUser);
7665 }
7666 }
7667
7668 return false;
7669}
7670
7671ScalarEvolution::LoopProperties
7672ScalarEvolution::getLoopProperties(const Loop *L) {
7673 using LoopProperties = ScalarEvolution::LoopProperties;
7674
7675 auto Itr = LoopPropertiesCache.find(L);
7676 if (Itr == LoopPropertiesCache.end()) {
7677 auto HasSideEffects = [](Instruction *I) {
7678 if (auto *SI = dyn_cast<StoreInst>(I))
7679 return !SI->isSimple();
7680
7681 if (I->mayThrow())
7682 return true;
7683
7684 // Non-volatile memset / memcpy do not count as side-effect for forward
7685 // progress.
7686 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7687 return false;
7688
7689 return I->mayWriteToMemory();
7690 };
7691
7692 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7693 /*HasNoSideEffects*/ true};
7694
7695 for (auto *BB : L->getBlocks())
7696 for (auto &I : *BB) {
7698 LP.HasNoAbnormalExits = false;
7699 if (HasSideEffects(&I))
7700 LP.HasNoSideEffects = false;
7701 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7702 break; // We're already as pessimistic as we can get.
7703 }
7704
7705 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7706 assert(InsertPair.second && "We just checked!");
7707 Itr = InsertPair.first;
7708 }
7709
7710 return Itr->second;
7711}
7712
7714 // A mustprogress loop without side effects must be finite.
7715 // TODO: The check used here is very conservative. It's only *specific*
7716 // side effects which are well defined in infinite loops.
7717 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7718}
7719
7720const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7721 // Worklist item with a Value and a bool indicating whether all operands have
7722 // been visited already.
7725
7726 Stack.emplace_back(V, true);
7727 Stack.emplace_back(V, false);
7728 while (!Stack.empty()) {
7729 auto E = Stack.pop_back_val();
7730 Value *CurV = E.getPointer();
7731
7732 if (getExistingSCEV(CurV))
7733 continue;
7734
7736 const SCEV *CreatedSCEV = nullptr;
7737 // If all operands have been visited already, create the SCEV.
7738 if (E.getInt()) {
7739 CreatedSCEV = createSCEV(CurV);
7740 } else {
7741 // Otherwise get the operands we need to create SCEV's for before creating
7742 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7743 // just use it.
7744 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7745 }
7746
7747 if (CreatedSCEV) {
7748 insertValueToMap(CurV, CreatedSCEV);
7749 } else {
7750 // Queue CurV for SCEV creation, followed by its's operands which need to
7751 // be constructed first.
7752 Stack.emplace_back(CurV, true);
7753 for (Value *Op : Ops)
7754 Stack.emplace_back(Op, false);
7755 }
7756 }
7757
7758 return getExistingSCEV(V);
7759}
7760
7761const SCEV *
7762ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7763 if (!isSCEVable(V->getType()))
7764 return getUnknown(V);
7765
7766 if (Instruction *I = dyn_cast<Instruction>(V)) {
7767 // Don't attempt to analyze instructions in blocks that aren't
7768 // reachable. Such instructions don't matter, and they aren't required
7769 // to obey basic rules for definitions dominating uses which this
7770 // analysis depends on.
7771 if (!DT.isReachableFromEntry(I->getParent()))
7772 return getUnknown(PoisonValue::get(V->getType()));
7773 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7774 return getConstant(CI);
7775 else if (isa<GlobalAlias>(V))
7776 return getUnknown(V);
7777 else if (!isa<ConstantExpr>(V))
7778 return getUnknown(V);
7779
7781 if (auto BO =
7783 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7784 switch (BO->Opcode) {
7785 case Instruction::Add:
7786 case Instruction::Mul: {
7787 // For additions and multiplications, traverse add/mul chains for which we
7788 // can potentially create a single SCEV, to reduce the number of
7789 // get{Add,Mul}Expr calls.
7790 do {
7791 if (BO->Op) {
7792 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7793 Ops.push_back(BO->Op);
7794 break;
7795 }
7796 }
7797 Ops.push_back(BO->RHS);
7798 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7800 if (!NewBO ||
7801 (BO->Opcode == Instruction::Add &&
7802 (NewBO->Opcode != Instruction::Add &&
7803 NewBO->Opcode != Instruction::Sub)) ||
7804 (BO->Opcode == Instruction::Mul &&
7805 NewBO->Opcode != Instruction::Mul)) {
7806 Ops.push_back(BO->LHS);
7807 break;
7808 }
7809 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7810 // requires a SCEV for the LHS.
7811 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7812 auto *I = dyn_cast<Instruction>(BO->Op);
7813 if (I && programUndefinedIfPoison(I)) {
7814 Ops.push_back(BO->LHS);
7815 break;
7816 }
7817 }
7818 BO = NewBO;
7819 } while (true);
7820 return nullptr;
7821 }
7822 case Instruction::Sub:
7823 case Instruction::UDiv:
7824 case Instruction::URem:
7825 break;
7826 case Instruction::AShr:
7827 case Instruction::Shl:
7828 case Instruction::Xor:
7829 if (!IsConstArg)
7830 return nullptr;
7831 break;
7832 case Instruction::And:
7833 case Instruction::Or:
7834 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7835 return nullptr;
7836 break;
7837 case Instruction::LShr:
7838 return getUnknown(V);
7839 default:
7840 llvm_unreachable("Unhandled binop");
7841 break;
7842 }
7843
7844 Ops.push_back(BO->LHS);
7845 Ops.push_back(BO->RHS);
7846 return nullptr;
7847 }
7848
7849 switch (U->getOpcode()) {
7850 case Instruction::Trunc:
7851 case Instruction::ZExt:
7852 case Instruction::SExt:
7853 case Instruction::PtrToAddr:
7854 case Instruction::PtrToInt:
7855 Ops.push_back(U->getOperand(0));
7856 return nullptr;
7857
7858 case Instruction::BitCast:
7859 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7860 Ops.push_back(U->getOperand(0));
7861 return nullptr;
7862 }
7863 return getUnknown(V);
7864
7865 case Instruction::SDiv:
7866 case Instruction::SRem:
7867 Ops.push_back(U->getOperand(0));
7868 Ops.push_back(U->getOperand(1));
7869 return nullptr;
7870
7871 case Instruction::GetElementPtr:
7872 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7873 "GEP source element type must be sized");
7874 llvm::append_range(Ops, U->operands());
7875 return nullptr;
7876
7877 case Instruction::IntToPtr:
7878 return getUnknown(V);
7879
7880 case Instruction::PHI:
7881 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7882 // relevant nodes for each of them.
7883 //
7884 // The first is just to call simplifyInstruction, and get something back
7885 // that isn't a PHI.
7886 if (Value *V = simplifyInstruction(
7887 cast<PHINode>(U),
7888 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7889 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7890 assert(V);
7891 Ops.push_back(V);
7892 return nullptr;
7893 }
7894 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7895 // operands which all perform the same operation, but haven't been
7896 // CSE'ed for whatever reason.
7897 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7898 assert(BO);
7899 Ops.push_back(BO);
7900 return nullptr;
7901 }
7902 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7903 // is equivalent to a select, and analyzes it like a select.
7904 {
7905 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7907 assert(Cond);
7908 assert(LHS);
7909 assert(RHS);
7910 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7911 Ops.push_back(CondICmp->getOperand(0));
7912 Ops.push_back(CondICmp->getOperand(1));
7913 }
7914 Ops.push_back(Cond);
7915 Ops.push_back(LHS);
7916 Ops.push_back(RHS);
7917 return nullptr;
7918 }
7919 }
7920 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7921 // so just construct it recursively.
7922 //
7923 // In addition to getNodeForPHI, also construct nodes which might be needed
7924 // by getRangeRef.
7926 for (Value *V : cast<PHINode>(U)->operands())
7927 Ops.push_back(V);
7928 return nullptr;
7929 }
7930 return nullptr;
7931
7932 case Instruction::Select: {
7933 // Check if U is a select that can be simplified to a SCEVUnknown.
7934 auto CanSimplifyToUnknown = [this, U]() {
7935 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7936 return false;
7937
7938 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7939 if (!ICI)
7940 return false;
7941 Value *LHS = ICI->getOperand(0);
7942 Value *RHS = ICI->getOperand(1);
7943 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7944 ICI->getPredicate() == CmpInst::ICMP_NE) {
7946 return true;
7947 } else if (getTypeSizeInBits(LHS->getType()) >
7948 getTypeSizeInBits(U->getType()))
7949 return true;
7950 return false;
7951 };
7952 if (CanSimplifyToUnknown())
7953 return getUnknown(U);
7954
7955 llvm::append_range(Ops, U->operands());
7956 return nullptr;
7957 break;
7958 }
7959 case Instruction::Call:
7960 case Instruction::Invoke:
7961 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7962 Ops.push_back(RV);
7963 return nullptr;
7964 }
7965
7966 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7967 switch (II->getIntrinsicID()) {
7968 case Intrinsic::abs:
7969 Ops.push_back(II->getArgOperand(0));
7970 return nullptr;
7971 case Intrinsic::umax:
7972 case Intrinsic::umin:
7973 case Intrinsic::smax:
7974 case Intrinsic::smin:
7975 case Intrinsic::usub_sat:
7976 case Intrinsic::uadd_sat:
7977 Ops.push_back(II->getArgOperand(0));
7978 Ops.push_back(II->getArgOperand(1));
7979 return nullptr;
7980 case Intrinsic::start_loop_iterations:
7981 case Intrinsic::annotation:
7982 case Intrinsic::ptr_annotation:
7983 Ops.push_back(II->getArgOperand(0));
7984 return nullptr;
7985 default:
7986 break;
7987 }
7988 }
7989 break;
7990 }
7991
7992 return nullptr;
7993}
7994
7995const SCEV *ScalarEvolution::createSCEV(Value *V) {
7996 if (!isSCEVable(V->getType()))
7997 return getUnknown(V);
7998
7999 if (Instruction *I = dyn_cast<Instruction>(V)) {
8000 // Don't attempt to analyze instructions in blocks that aren't
8001 // reachable. Such instructions don't matter, and they aren't required
8002 // to obey basic rules for definitions dominating uses which this
8003 // analysis depends on.
8004 if (!DT.isReachableFromEntry(I->getParent()))
8005 return getUnknown(PoisonValue::get(V->getType()));
8006 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
8007 return getConstant(CI);
8008 else if (isa<GlobalAlias>(V))
8009 return getUnknown(V);
8010 else if (!isa<ConstantExpr>(V))
8011 return getUnknown(V);
8012
8013 const SCEV *LHS;
8014 const SCEV *RHS;
8015
8017 if (auto BO =
8019 switch (BO->Opcode) {
8020 case Instruction::Add: {
8021 // The simple thing to do would be to just call getSCEV on both operands
8022 // and call getAddExpr with the result. However if we're looking at a
8023 // bunch of things all added together, this can be quite inefficient,
8024 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8025 // Instead, gather up all the operands and make a single getAddExpr call.
8026 // LLVM IR canonical form means we need only traverse the left operands.
8028 do {
8029 if (BO->Op) {
8030 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8031 AddOps.push_back(OpSCEV);
8032 break;
8033 }
8034
8035 // If a NUW or NSW flag can be applied to the SCEV for this
8036 // addition, then compute the SCEV for this addition by itself
8037 // with a separate call to getAddExpr. We need to do that
8038 // instead of pushing the operands of the addition onto AddOps,
8039 // since the flags are only known to apply to this particular
8040 // addition - they may not apply to other additions that can be
8041 // formed with operands from AddOps.
8042 const SCEV *RHS = getSCEV(BO->RHS);
8043 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8044 if (Flags != SCEV::FlagAnyWrap) {
8045 const SCEV *LHS = getSCEV(BO->LHS);
8046 if (BO->Opcode == Instruction::Sub)
8047 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8048 else
8049 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8050 break;
8051 }
8052 }
8053
8054 if (BO->Opcode == Instruction::Sub)
8055 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8056 else
8057 AddOps.push_back(getSCEV(BO->RHS));
8058
8059 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8061 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8062 NewBO->Opcode != Instruction::Sub)) {
8063 AddOps.push_back(getSCEV(BO->LHS));
8064 break;
8065 }
8066 BO = NewBO;
8067 } while (true);
8068
8069 return getAddExpr(AddOps);
8070 }
8071
8072 case Instruction::Mul: {
8074 do {
8075 if (BO->Op) {
8076 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8077 MulOps.push_back(OpSCEV);
8078 break;
8079 }
8080
8081 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8082 if (Flags != SCEV::FlagAnyWrap) {
8083 LHS = getSCEV(BO->LHS);
8084 RHS = getSCEV(BO->RHS);
8085 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8086 break;
8087 }
8088 }
8089
8090 MulOps.push_back(getSCEV(BO->RHS));
8091 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8093 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8094 MulOps.push_back(getSCEV(BO->LHS));
8095 break;
8096 }
8097 BO = NewBO;
8098 } while (true);
8099
8100 return getMulExpr(MulOps);
8101 }
8102 case Instruction::UDiv:
8103 LHS = getSCEV(BO->LHS);
8104 RHS = getSCEV(BO->RHS);
8105 return getUDivExpr(LHS, RHS);
8106 case Instruction::URem:
8107 LHS = getSCEV(BO->LHS);
8108 RHS = getSCEV(BO->RHS);
8109 return getURemExpr(LHS, RHS);
8110 case Instruction::Sub: {
8112 if (BO->Op)
8113 Flags = getNoWrapFlagsFromUB(BO->Op);
8114 LHS = getSCEV(BO->LHS);
8115 RHS = getSCEV(BO->RHS);
8116 return getMinusSCEV(LHS, RHS, Flags);
8117 }
8118 case Instruction::And:
8119 // For an expression like x&255 that merely masks off the high bits,
8120 // use zext(trunc(x)) as the SCEV expression.
8121 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8122 if (CI->isZero())
8123 return getSCEV(BO->RHS);
8124 if (CI->isMinusOne())
8125 return getSCEV(BO->LHS);
8126 const APInt &A = CI->getValue();
8127
8128 // Instcombine's ShrinkDemandedConstant may strip bits out of
8129 // constants, obscuring what would otherwise be a low-bits mask.
8130 // Use computeKnownBits to compute what ShrinkDemandedConstant
8131 // knew about to reconstruct a low-bits mask value.
8132 unsigned LZ = A.countl_zero();
8133 unsigned TZ = A.countr_zero();
8134 unsigned BitWidth = A.getBitWidth();
8135 KnownBits Known(BitWidth);
8136 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8137
8138 APInt EffectiveMask =
8139 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8140 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8141 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8142 const SCEV *LHS = getSCEV(BO->LHS);
8143 const SCEV *ShiftedLHS = nullptr;
8144 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8145 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8146 // For an expression like (x * 8) & 8, simplify the multiply.
8147 unsigned MulZeros = OpC->getAPInt().countr_zero();
8148 unsigned GCD = std::min(MulZeros, TZ);
8149 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8151 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8152 append_range(MulOps, LHSMul->operands().drop_front());
8153 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8154 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8155 }
8156 }
8157 if (!ShiftedLHS)
8158 ShiftedLHS = getUDivExpr(LHS, MulCount);
8159 return getMulExpr(
8161 getTruncateExpr(ShiftedLHS,
8162 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8163 BO->LHS->getType()),
8164 MulCount);
8165 }
8166 }
8167 // Binary `and` is a bit-wise `umin`.
8168 if (BO->LHS->getType()->isIntegerTy(1)) {
8169 LHS = getSCEV(BO->LHS);
8170 RHS = getSCEV(BO->RHS);
8171 return getUMinExpr(LHS, RHS);
8172 }
8173 break;
8174
8175 case Instruction::Or:
8176 // Binary `or` is a bit-wise `umax`.
8177 if (BO->LHS->getType()->isIntegerTy(1)) {
8178 LHS = getSCEV(BO->LHS);
8179 RHS = getSCEV(BO->RHS);
8180 return getUMaxExpr(LHS, RHS);
8181 }
8182 break;
8183
8184 case Instruction::Xor:
8185 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8186 // If the RHS of xor is -1, then this is a not operation.
8187 if (CI->isMinusOne())
8188 return getNotSCEV(getSCEV(BO->LHS));
8189
8190 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8191 // This is a variant of the check for xor with -1, and it handles
8192 // the case where instcombine has trimmed non-demanded bits out
8193 // of an xor with -1.
8194 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8195 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8196 if (LBO->getOpcode() == Instruction::And &&
8197 LCI->getValue() == CI->getValue())
8198 if (const SCEVZeroExtendExpr *Z =
8200 Type *UTy = BO->LHS->getType();
8201 const SCEV *Z0 = Z->getOperand();
8202 Type *Z0Ty = Z0->getType();
8203 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8204
8205 // If C is a low-bits mask, the zero extend is serving to
8206 // mask off the high bits. Complement the operand and
8207 // re-apply the zext.
8208 if (CI->getValue().isMask(Z0TySize))
8209 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8210
8211 // If C is a single bit, it may be in the sign-bit position
8212 // before the zero-extend. In this case, represent the xor
8213 // using an add, which is equivalent, and re-apply the zext.
8214 APInt Trunc = CI->getValue().trunc(Z0TySize);
8215 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8216 Trunc.isSignMask())
8217 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8218 UTy);
8219 }
8220 }
8221 break;
8222
8223 case Instruction::Shl:
8224 // Turn shift left of a constant amount into a multiply.
8225 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8226 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8227
8228 // If the shift count is not less than the bitwidth, the result of
8229 // the shift is undefined. Don't try to analyze it, because the
8230 // resolution chosen here may differ from the resolution chosen in
8231 // other parts of the compiler.
8232 if (SA->getValue().uge(BitWidth))
8233 break;
8234
8235 // We can safely preserve the nuw flag in all cases. It's also safe to
8236 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8237 // requires special handling. It can be preserved as long as we're not
8238 // left shifting by bitwidth - 1.
8239 auto Flags = SCEV::FlagAnyWrap;
8240 if (BO->Op) {
8241 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8242 if (any(MulFlags & SCEV::FlagNSW) &&
8243 (any(MulFlags & SCEV::FlagNUW) ||
8244 SA->getValue().ult(BitWidth - 1)))
8246 if (any(MulFlags & SCEV::FlagNUW))
8248 }
8249
8250 ConstantInt *X = ConstantInt::get(
8251 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8252 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8253 }
8254 break;
8255
8256 case Instruction::AShr:
8257 // AShr X, C, where C is a constant.
8258 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8259 if (!CI)
8260 break;
8261
8262 Type *OuterTy = BO->LHS->getType();
8263 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8264 // If the shift count is not less than the bitwidth, the result of
8265 // the shift is undefined. Don't try to analyze it, because the
8266 // resolution chosen here may differ from the resolution chosen in
8267 // other parts of the compiler.
8268 if (CI->getValue().uge(BitWidth))
8269 break;
8270
8271 if (CI->isZero())
8272 return getSCEV(BO->LHS); // shift by zero --> noop
8273
8274 uint64_t AShrAmt = CI->getZExtValue();
8275 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8276
8277 Operator *L = dyn_cast<Operator>(BO->LHS);
8278 const SCEV *AddTruncateExpr = nullptr;
8279 ConstantInt *ShlAmtCI = nullptr;
8280 const SCEV *AddConstant = nullptr;
8281
8282 if (L && L->getOpcode() == Instruction::Add) {
8283 // X = Shl A, n
8284 // Y = Add X, c
8285 // Z = AShr Y, m
8286 // n, c and m are constants.
8287
8288 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8289 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8290 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8291 if (AddOperandCI) {
8292 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8293 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8294 // since we truncate to TruncTy, the AddConstant should be of the
8295 // same type, so create a new Constant with type same as TruncTy.
8296 // Also, the Add constant should be shifted right by AShr amount.
8297 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8298 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8299 // we model the expression as sext(add(trunc(A), c << n)), since the
8300 // sext(trunc) part is already handled below, we create a
8301 // AddExpr(TruncExp) which will be used later.
8302 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8303 }
8304 }
8305 } else if (L && L->getOpcode() == Instruction::Shl) {
8306 // X = Shl A, n
8307 // Y = AShr X, m
8308 // Both n and m are constant.
8309
8310 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8311 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8312 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8313 }
8314
8315 if (AddTruncateExpr && ShlAmtCI) {
8316 // We can merge the two given cases into a single SCEV statement,
8317 // in case n = m, the mul expression will be 2^0, so it gets resolved to
8318 // a simpler case. The following code handles the two cases:
8319 //
8320 // 1) For a two-shift sext-inreg, i.e. n = m,
8321 // use sext(trunc(x)) as the SCEV expression.
8322 //
8323 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8324 // expression. We already checked that ShlAmt < BitWidth, so
8325 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8326 // ShlAmt - AShrAmt < Amt.
8327 const APInt &ShlAmt = ShlAmtCI->getValue();
8328 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8329 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8330 ShlAmtCI->getZExtValue() - AShrAmt);
8331 const SCEV *CompositeExpr =
8332 getMulExpr(AddTruncateExpr, getConstant(Mul));
8333 if (L->getOpcode() != Instruction::Shl)
8334 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8335
8336 return getSignExtendExpr(CompositeExpr, OuterTy);
8337 }
8338 }
8339 break;
8340 }
8341 }
8342
8343 switch (U->getOpcode()) {
8344 case Instruction::Trunc:
8345 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8346
8347 case Instruction::ZExt:
8348 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8349
8350 case Instruction::SExt:
8351 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8353 // The NSW flag of a subtract does not always survive the conversion to
8354 // A + (-1)*B. By pushing sign extension onto its operands we are much
8355 // more likely to preserve NSW and allow later AddRec optimisations.
8356 //
8357 // NOTE: This is effectively duplicating this logic from getSignExtend:
8358 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8359 // but by that point the NSW information has potentially been lost.
8360 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8361 Type *Ty = U->getType();
8362 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8363 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8364 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8365 }
8366 }
8367 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8368
8369 case Instruction::BitCast:
8370 // BitCasts are no-op casts so we just eliminate the cast.
8371 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8372 return getSCEV(U->getOperand(0));
8373 break;
8374
8375 case Instruction::PtrToAddr: {
8376 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8377 if (isa<SCEVCouldNotCompute>(IntOp))
8378 return getUnknown(V);
8379 return IntOp;
8380 }
8381
8382 case Instruction::PtrToInt: {
8383 // Pointer to integer cast is straight-forward, so do model it.
8384 const SCEV *Op = getSCEV(U->getOperand(0));
8385 Type *DstIntTy = U->getType();
8386 // But only if effective SCEV (integer) type is wide enough to represent
8387 // all possible pointer values.
8388 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8389 if (isa<SCEVCouldNotCompute>(IntOp))
8390 return getUnknown(V);
8391 return IntOp;
8392 }
8393 case Instruction::IntToPtr:
8394 // Just don't deal with inttoptr casts.
8395 return getUnknown(V);
8396
8397 case Instruction::SDiv:
8398 // If both operands are non-negative, this is just an udiv.
8399 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8400 isKnownNonNegative(getSCEV(U->getOperand(1))))
8401 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8402 break;
8403
8404 case Instruction::SRem:
8405 // If both operands are non-negative, this is just an urem.
8406 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8407 isKnownNonNegative(getSCEV(U->getOperand(1))))
8408 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8409 break;
8410
8411 case Instruction::GetElementPtr:
8412 return createNodeForGEP(cast<GEPOperator>(U));
8413
8414 case Instruction::PHI:
8415 return createNodeForPHI(cast<PHINode>(U));
8416
8417 case Instruction::Select:
8418 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8419 U->getOperand(2));
8420
8421 case Instruction::Call:
8422 case Instruction::Invoke:
8423 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8424 return getSCEV(RV);
8425
8426 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8427 switch (II->getIntrinsicID()) {
8428 case Intrinsic::abs:
8429 return getAbsExpr(
8430 getSCEV(II->getArgOperand(0)),
8431 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8432 case Intrinsic::umax:
8433 LHS = getSCEV(II->getArgOperand(0));
8434 RHS = getSCEV(II->getArgOperand(1));
8435 return getUMaxExpr(LHS, RHS);
8436 case Intrinsic::umin:
8437 LHS = getSCEV(II->getArgOperand(0));
8438 RHS = getSCEV(II->getArgOperand(1));
8439 return getUMinExpr(LHS, RHS);
8440 case Intrinsic::smax:
8441 LHS = getSCEV(II->getArgOperand(0));
8442 RHS = getSCEV(II->getArgOperand(1));
8443 return getSMaxExpr(LHS, RHS);
8444 case Intrinsic::smin:
8445 LHS = getSCEV(II->getArgOperand(0));
8446 RHS = getSCEV(II->getArgOperand(1));
8447 return getSMinExpr(LHS, RHS);
8448 case Intrinsic::usub_sat: {
8449 const SCEV *X = getSCEV(II->getArgOperand(0));
8450 const SCEV *Y = getSCEV(II->getArgOperand(1));
8451 const SCEV *ClampedY = getUMinExpr(X, Y);
8452 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8453 }
8454 case Intrinsic::uadd_sat: {
8455 const SCEV *X = getSCEV(II->getArgOperand(0));
8456 const SCEV *Y = getSCEV(II->getArgOperand(1));
8457 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8458 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8459 }
8460 case Intrinsic::start_loop_iterations:
8461 case Intrinsic::annotation:
8462 case Intrinsic::ptr_annotation:
8463 // A start_loop_iterations or llvm.annotation or llvm.ptr.annotation is
8464 // just equivalent to the first operand for SCEV purposes.
8465 return getSCEV(II->getArgOperand(0));
8466 case Intrinsic::vscale:
8467 return getVScale(II->getType());
8468 default:
8469 break;
8470 }
8471 }
8472 break;
8473 }
8474
8475 return getUnknown(V);
8476}
8477
8478//===----------------------------------------------------------------------===//
8479// Iteration Count Computation Code
8480//
8481
8483 if (isa<SCEVCouldNotCompute>(ExitCount))
8484 return getCouldNotCompute();
8485
8486 auto *ExitCountType = ExitCount->getType();
8487 assert(ExitCountType->isIntegerTy());
8488 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8489 1 + ExitCountType->getScalarSizeInBits());
8490 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8491}
8492
8494 Type *EvalTy,
8495 const Loop *L) {
8496 if (isa<SCEVCouldNotCompute>(ExitCount))
8497 return getCouldNotCompute();
8498
8499 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8500 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8501
8502 auto CanAddOneWithoutOverflow = [&]() {
8503 ConstantRange ExitCountRange =
8504 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8505 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8506 return true;
8507
8508 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8509 getMinusOne(ExitCount->getType()));
8510 };
8511
8512 // If we need to zero extend the backedge count, check if we can add one to
8513 // it prior to zero extending without overflow. Provided this is safe, it
8514 // allows better simplification of the +1.
8515 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8516 return getZeroExtendExpr(
8517 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8518
8519 // Get the total trip count from the count by adding 1. This may wrap.
8520 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8521}
8522
8523static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8524 if (!ExitCount)
8525 return 0;
8526
8527 ConstantInt *ExitConst = ExitCount->getValue();
8528
8529 // Guard against huge trip counts.
8530 if (ExitConst->getValue().getActiveBits() > 32)
8531 return 0;
8532
8533 // In case of integer overflow, this returns 0, which is correct.
8534 return ((unsigned)ExitConst->getZExtValue()) + 1;
8535}
8536
8538 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8539 return getConstantTripCount(ExitCount);
8540}
8541
8542unsigned
8544 const BasicBlock *ExitingBlock) {
8545 assert(ExitingBlock && "Must pass a non-null exiting block!");
8546 assert(L->isLoopExiting(ExitingBlock) &&
8547 "Exiting block must actually branch out of the loop!");
8548 const SCEVConstant *ExitCount =
8549 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8550 return getConstantTripCount(ExitCount);
8551}
8552
8554 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8555
8556 const auto *MaxExitCount =
8557 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8559 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8560}
8561
8563 SmallVector<BasicBlock *, 8> ExitingBlocks;
8564 L->getExitingBlocks(ExitingBlocks);
8565
8566 std::optional<unsigned> Res;
8567 for (auto *ExitingBB : ExitingBlocks) {
8568 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8569 if (!Res)
8570 Res = Multiple;
8571 Res = std::gcd(*Res, Multiple);
8572 }
8573 return Res.value_or(1);
8574}
8575
8577 const SCEV *ExitCount) {
8578 if (isa<SCEVCouldNotCompute>(ExitCount))
8579 return 1;
8580
8581 // Get the trip count
8582 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8583
8584 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8585 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8586 // the greatest power of 2 divisor less than 2^32.
8587 return Multiple.getActiveBits() > 32
8588 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8589 : (unsigned)Multiple.getZExtValue();
8590}
8591
8592/// Returns the largest constant divisor of the trip count of this loop as a
8593/// normal unsigned value, if possible. This means that the actual trip count is
8594/// always a multiple of the returned value (don't forget the trip count could
8595/// very well be zero as well!).
8596///
8597/// Returns 1 if the trip count is unknown or not guaranteed to be the
8598/// multiple of a constant. If the largest divisor is very large (>= 2^32),
8599/// the greatest power-of-two divisor below 2^32 is returned instead. To
8600/// query a constant trip count itself, use getSmallConstantTripCount.
8601///
8602/// As explained in the comments for getSmallConstantTripCount, this assumes
8603/// that control exits the loop via ExitingBlock.
8604unsigned
8606 const BasicBlock *ExitingBlock) {
8607 assert(ExitingBlock && "Must pass a non-null exiting block!");
8608 assert(L->isLoopExiting(ExitingBlock) &&
8609 "Exiting block must actually branch out of the loop!");
8610 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8611 return getSmallConstantTripMultiple(L, ExitCount);
8612}
8613
8615 const BasicBlock *ExitingBlock,
8616 ExitCountKind Kind) {
8617 switch (Kind) {
8618 case Exact:
8619 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8620 case SymbolicMaximum:
8621 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8622 case ConstantMaximum:
8623 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8624 };
8625 llvm_unreachable("Invalid ExitCountKind!");
8626}
8627
8629 const Loop *L, const BasicBlock *ExitingBlock,
8631 switch (Kind) {
8632 case Exact:
8633 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8634 Predicates);
8635 case SymbolicMaximum:
8636 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8637 Predicates);
8638 case ConstantMaximum:
8639 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8640 Predicates);
8641 };
8642 llvm_unreachable("Invalid ExitCountKind!");
8643}
8644
8647 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8648}
8649
8651 ExitCountKind Kind) {
8652 switch (Kind) {
8653 case Exact:
8654 return getBackedgeTakenInfo(L).getExact(L, this);
8655 case ConstantMaximum:
8656 return getBackedgeTakenInfo(L).getConstantMax(this);
8657 case SymbolicMaximum:
8658 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8659 };
8660 llvm_unreachable("Invalid ExitCountKind!");
8661}
8662
8665 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8666}
8667
8670 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8671}
8672
8674 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8675}
8676
8677/// Push PHI nodes in the header of the given loop onto the given Worklist.
8678static void PushLoopPHIs(const Loop *L,
8681 BasicBlock *Header = L->getHeader();
8682
8683 // Push all Loop-header PHIs onto the Worklist stack.
8684 for (PHINode &PN : Header->phis())
8685 if (Visited.insert(&PN).second)
8686 Worklist.push_back(&PN);
8687}
8688
8689ScalarEvolution::BackedgeTakenInfo &
8690ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8691 auto &BTI = getBackedgeTakenInfo(L);
8692 if (BTI.hasFullInfo())
8693 return BTI;
8694
8695 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8696
8697 if (!Pair.second)
8698 return Pair.first->second;
8699
8700 BackedgeTakenInfo Result =
8701 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8702
8703 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8704}
8705
8706ScalarEvolution::BackedgeTakenInfo &
8707ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8708 // Initially insert an invalid entry for this loop. If the insertion
8709 // succeeds, proceed to actually compute a backedge-taken count and
8710 // update the value. The temporary CouldNotCompute value tells SCEV
8711 // code elsewhere that it shouldn't attempt to request a new
8712 // backedge-taken count, which could result in infinite recursion.
8713 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8714 BackedgeTakenCounts.try_emplace(L);
8715 if (!Pair.second)
8716 return Pair.first->second;
8717
8718 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8719 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8720 // must be cleared in this scope.
8721 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8722
8723 // Now that we know more about the trip count for this loop, forget any
8724 // existing SCEV values for PHI nodes in this loop since they are only
8725 // conservative estimates made without the benefit of trip count
8726 // information. This invalidation is not necessary for correctness, and is
8727 // only done to produce more precise results.
8728 if (Result.hasAnyInfo()) {
8729 // Invalidate any expression using an addrec in this loop.
8730 SmallVector<SCEVUse, 8> ToForget;
8731 auto LoopUsersIt = LoopUsers.find(L);
8732 if (LoopUsersIt != LoopUsers.end())
8733 append_range(ToForget, LoopUsersIt->second);
8734 forgetMemoizedResults(ToForget);
8735
8736 // Invalidate constant-evolved loop header phis.
8737 for (PHINode &PN : L->getHeader()->phis())
8738 ConstantEvolutionLoopExitValue.erase(&PN);
8739 }
8740
8741 // Re-lookup the insert position, since the call to
8742 // computeBackedgeTakenCount above could result in a
8743 // recusive call to getBackedgeTakenInfo (on a different
8744 // loop), which would invalidate the iterator computed
8745 // earlier.
8746 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8747}
8748
8750 // This method is intended to forget all info about loops. It should
8751 // invalidate caches as if the following happened:
8752 // - The trip counts of all loops have changed arbitrarily
8753 // - Every llvm::Value has been updated in place to produce a different
8754 // result.
8755 BackedgeTakenCounts.clear();
8756 PredicatedBackedgeTakenCounts.clear();
8757 BECountUsers.clear();
8758 LoopPropertiesCache.clear();
8759 ConstantEvolutionLoopExitValue.clear();
8760 ValueExprMap.clear();
8761 ValuesAtScopes.clear();
8762 ValuesAtScopesUsers.clear();
8763 LoopDispositions.clear();
8764 BlockDispositions.clear();
8765 UnsignedRanges.clear();
8766 SignedRanges.clear();
8767 ExprValueMap.clear();
8768 HasRecMap.clear();
8769 ConstantMultipleCache.clear();
8770 PredicatedSCEVRewrites.clear();
8771 FoldCache.clear();
8772 FoldCacheUser.clear();
8773}
8774void ScalarEvolution::visitAndClearUsers(
8777 SmallVectorImpl<SCEVUse> &ToForget) {
8778 while (!Worklist.empty()) {
8779 Instruction *I = Worklist.pop_back_val();
8780 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8781 continue;
8782
8784 ValueExprMap.find_as(static_cast<Value *>(I));
8785 if (It != ValueExprMap.end()) {
8786 ToForget.push_back(It->second);
8787 eraseValueFromMap(It->first);
8788 if (PHINode *PN = dyn_cast<PHINode>(I))
8789 ConstantEvolutionLoopExitValue.erase(PN);
8790 }
8791
8792 PushDefUseChildren(I, Worklist, Visited);
8793 }
8794}
8795
8797 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8800 SmallVector<SCEVUse, 16> ToForget;
8801
8802 // Iterate over all the loops and sub-loops to drop SCEV information.
8803 while (!LoopWorklist.empty()) {
8804 auto *CurrL = LoopWorklist.pop_back_val();
8805
8806 // Drop any stored trip count value.
8807 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8808 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8809
8810 // Drop information about predicated SCEV rewrites for this loop.
8811 PredicatedSCEVRewrites.remove_if(
8812 [&](const auto &Entry) { return Entry.first.second == CurrL; });
8813
8814 auto LoopUsersItr = LoopUsers.find(CurrL);
8815 if (LoopUsersItr != LoopUsers.end())
8816 llvm::append_range(ToForget, LoopUsersItr->second);
8817
8818 // Drop information about expressions based on loop-header PHIs.
8819 PushLoopPHIs(CurrL, Worklist, Visited);
8820 visitAndClearUsers(Worklist, Visited, ToForget);
8821
8822 LoopPropertiesCache.erase(CurrL);
8823 // Forget all contained loops too, to avoid dangling entries in the
8824 // ValuesAtScopes map.
8825 LoopWorklist.append(CurrL->begin(), CurrL->end());
8826 }
8827 forgetMemoizedResults(ToForget);
8828}
8829
8831 forgetLoop(L->getOutermostLoop());
8832}
8833
8836 if (!I) return;
8837
8838 // Drop information about expressions based on loop-header PHIs.
8841 SmallVector<SCEVUse, 8> ToForget;
8842 Worklist.push_back(I);
8843 Visited.insert(I);
8844 visitAndClearUsers(Worklist, Visited, ToForget);
8845
8846 forgetMemoizedResults(ToForget);
8847}
8848
8850 if (!isSCEVable(V->getType()))
8851 return;
8852
8853 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8854 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8855 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8856 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8857 if (const SCEV *S = getExistingSCEV(V)) {
8858 struct InvalidationRootCollector {
8859 Loop *L;
8861
8862 InvalidationRootCollector(Loop *L) : L(L) {}
8863
8864 bool follow(const SCEV *S) {
8865 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8866 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8867 if (L->contains(I))
8868 Roots.push_back(S);
8869 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8870 if (L->contains(AddRec->getLoop()))
8871 Roots.push_back(S);
8872 }
8873 return true;
8874 }
8875 bool isDone() const { return false; }
8876 };
8877
8878 InvalidationRootCollector C(L);
8879 visitAll(S, C);
8880 forgetMemoizedResults(C.Roots);
8881 }
8882
8883 // Also perform the normal invalidation.
8884 forgetValue(V);
8885}
8886
8887void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8888
8890 // Unless a specific value is passed to invalidation, completely clear both
8891 // caches.
8892 if (!V) {
8893 BlockDispositions.clear();
8894 LoopDispositions.clear();
8895 return;
8896 }
8897
8898 if (!isSCEVable(V->getType()))
8899 return;
8900
8901 const SCEV *S = getExistingSCEV(V);
8902 if (!S)
8903 return;
8904
8905 // Invalidate the block and loop dispositions cached for S. Dispositions of
8906 // S's users may change if S's disposition changes (i.e. a user may change to
8907 // loop-invariant, if S changes to loop invariant), so also invalidate
8908 // dispositions of S's users recursively.
8909 SmallVector<SCEVUse, 8> Worklist = {S};
8911 while (!Worklist.empty()) {
8912 const SCEV *Curr = Worklist.pop_back_val();
8913 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8914 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8915 if (!LoopDispoRemoved && !BlockDispoRemoved)
8916 continue;
8917 auto Users = SCEVUsers.find(Curr);
8918 if (Users != SCEVUsers.end())
8919 for (const auto *User : Users->second)
8920 if (Seen.insert(User).second)
8921 Worklist.push_back(User);
8922 }
8923}
8924
8925/// Get the exact loop backedge taken count considering all loop exits. A
8926/// computable result can only be returned for loops with all exiting blocks
8927/// dominating the latch. howFarToZero assumes that the limit of each loop test
8928/// is never skipped. This is a valid assumption as long as the loop exits via
8929/// that test. For precise results, it is the caller's responsibility to specify
8930/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8931const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8932 const Loop *L, ScalarEvolution *SE,
8934 // If any exits were not computable, the loop is not computable.
8935 if (!isComplete() || ExitNotTaken.empty())
8936 return SE->getCouldNotCompute();
8937
8938 const BasicBlock *Latch = L->getLoopLatch();
8939 // All exiting blocks we have collected must dominate the only backedge.
8940 if (!Latch)
8941 return SE->getCouldNotCompute();
8942
8943 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8944 // count is simply a minimum out of all these calculated exit counts.
8946 for (const auto &ENT : ExitNotTaken) {
8947 const SCEV *BECount = ENT.ExactNotTaken;
8948 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8949 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8950 "We should only have known counts for exiting blocks that dominate "
8951 "latch!");
8952
8953 Ops.push_back(BECount);
8954
8955 if (Preds)
8956 append_range(*Preds, ENT.Predicates);
8957
8958 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8959 "Predicate should be always true!");
8960 }
8961
8962 // If an earlier exit exits on the first iteration (exit count zero), then
8963 // a later poison exit count should not propagate into the result. These are
8964 // exactly the semantics provided by umin_seq.
8965 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8966}
8967
8968const ScalarEvolution::ExitNotTakenInfo *
8969ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8970 const BasicBlock *ExitingBlock,
8971 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8972 for (const auto &ENT : ExitNotTaken)
8973 if (ENT.ExitingBlock == ExitingBlock) {
8974 if (ENT.hasAlwaysTruePredicate())
8975 return &ENT;
8976 else if (Predicates) {
8977 append_range(*Predicates, ENT.Predicates);
8978 return &ENT;
8979 }
8980 }
8981
8982 return nullptr;
8983}
8984
8985/// getConstantMax - Get the constant max backedge taken count for the loop.
8986const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8987 ScalarEvolution *SE,
8988 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8989 if (!getConstantMax())
8990 return SE->getCouldNotCompute();
8991
8992 for (const auto &ENT : ExitNotTaken)
8993 if (!ENT.hasAlwaysTruePredicate()) {
8994 if (!Predicates)
8995 return SE->getCouldNotCompute();
8996 append_range(*Predicates, ENT.Predicates);
8997 }
8998
8999 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
9000 isa<SCEVConstant>(getConstantMax())) &&
9001 "No point in having a non-constant max backedge taken count!");
9002 return getConstantMax();
9003}
9004
9005const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
9006 const Loop *L, ScalarEvolution *SE,
9007 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
9008 if (!SymbolicMax) {
9009 // Form an expression for the maximum exit count possible for this loop. We
9010 // merge the max and exact information to approximate a version of
9011 // getConstantMaxBackedgeTakenCount which isn't restricted to just
9012 // constants.
9013 SmallVector<SCEVUse, 4> ExitCounts;
9014
9015 for (const auto &ENT : ExitNotTaken) {
9016 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
9017 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
9018 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
9019 "We should only have known counts for exiting blocks that "
9020 "dominate latch!");
9021 ExitCounts.push_back(ExitCount);
9022 if (Predicates)
9023 append_range(*Predicates, ENT.Predicates);
9024
9025 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9026 "Predicate should be always true!");
9027 }
9028 }
9029 if (ExitCounts.empty())
9030 SymbolicMax = SE->getCouldNotCompute();
9031 else
9032 SymbolicMax =
9033 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9034 }
9035 return SymbolicMax;
9036}
9037
9038bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9039 ScalarEvolution *SE) const {
9040 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9041 return !ENT.hasAlwaysTruePredicate();
9042 };
9043 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9044}
9045
9048
9050 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9051 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9055 // If we prove the max count is zero, so is the symbolic bound. This happens
9056 // in practice due to differences in a) how context sensitive we've chosen
9057 // to be and b) how we reason about bounds implied by UB.
9058 if (ConstantMaxNotTaken->isZero()) {
9059 this->ExactNotTaken = E = ConstantMaxNotTaken;
9060 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9061 }
9062
9065 "Exact is not allowed to be less precise than Constant Max");
9068 "Exact is not allowed to be less precise than Symbolic Max");
9071 "Symbolic Max is not allowed to be less precise than Constant Max");
9074 "No point in having a non-constant max backedge taken count!");
9076 for (const auto PredList : PredLists)
9077 for (const auto *P : PredList) {
9078 if (SeenPreds.contains(P))
9079 continue;
9080 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9081 SeenPreds.insert(P);
9082 Predicates.push_back(P);
9083 }
9084 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9085 "Backedge count should be int");
9087 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9088 "Max backedge count should be int");
9089}
9090
9098
9099/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9100/// computable exit into a persistent ExitNotTakenInfo array.
9101ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9103 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9104 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9105 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9106
9107 ExitNotTaken.reserve(ExitCounts.size());
9108 std::transform(ExitCounts.begin(), ExitCounts.end(),
9109 std::back_inserter(ExitNotTaken),
9110 [&](const EdgeExitInfo &EEI) {
9111 BasicBlock *ExitBB = EEI.first;
9112 const ExitLimit &EL = EEI.second;
9113 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9114 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9115 EL.Predicates);
9116 });
9117 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9118 isa<SCEVConstant>(ConstantMax)) &&
9119 "No point in having a non-constant max backedge taken count!");
9120}
9121
9122/// Compute the number of times the backedge of the specified loop will execute.
9123ScalarEvolution::BackedgeTakenInfo
9124ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9125 bool AllowPredicates) {
9126 SmallVector<BasicBlock *, 8> ExitingBlocks;
9127 L->getExitingBlocks(ExitingBlocks);
9128
9129 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9130
9132 bool CouldComputeBECount = true;
9133 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9134 const SCEV *MustExitMaxBECount = nullptr;
9135 const SCEV *MayExitMaxBECount = nullptr;
9136 bool MustExitMaxOrZero = false;
9137 bool IsOnlyExit = ExitingBlocks.size() == 1;
9138
9139 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9140 // and compute maxBECount.
9141 // Do a union of all the predicates here.
9142 for (BasicBlock *ExitBB : ExitingBlocks) {
9143 // We canonicalize untaken exits to br (constant), ignore them so that
9144 // proving an exit untaken doesn't negatively impact our ability to reason
9145 // about the loop as whole.
9146 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9147 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9148 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9149 if (ExitIfTrue == CI->isZero())
9150 continue;
9151 }
9152
9153 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9154
9155 assert((AllowPredicates || EL.Predicates.empty()) &&
9156 "Predicated exit limit when predicates are not allowed!");
9157
9158 // 1. For each exit that can be computed, add an entry to ExitCounts.
9159 // CouldComputeBECount is true only if all exits can be computed.
9160 if (EL.ExactNotTaken != getCouldNotCompute())
9161 ++NumExitCountsComputed;
9162 else
9163 // We couldn't compute an exact value for this exit, so
9164 // we won't be able to compute an exact value for the loop.
9165 CouldComputeBECount = false;
9166 // Remember exit count if either exact or symbolic is known. Because
9167 // Exact always implies symbolic, only check symbolic.
9168 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9169 ExitCounts.emplace_back(ExitBB, EL);
9170 else {
9171 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9172 "Exact is known but symbolic isn't?");
9173 ++NumExitCountsNotComputed;
9174 }
9175
9176 // 2. Derive the loop's MaxBECount from each exit's max number of
9177 // non-exiting iterations. Partition the loop exits into two kinds:
9178 // LoopMustExits and LoopMayExits.
9179 //
9180 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9181 // is a LoopMayExit. If any computable LoopMustExit is found, then
9182 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9183 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9184 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9185 // any
9186 // computable EL.ConstantMaxNotTaken.
9187 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9188 DT.dominates(ExitBB, Latch)) {
9189 if (!MustExitMaxBECount) {
9190 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9191 MustExitMaxOrZero = EL.MaxOrZero;
9192 } else {
9193 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9194 EL.ConstantMaxNotTaken);
9195 }
9196 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9197 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9198 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9199 else {
9200 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9201 EL.ConstantMaxNotTaken);
9202 }
9203 }
9204 }
9205 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9206 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9207 // The loop backedge will be taken the maximum or zero times if there's
9208 // a single exit that must be taken the maximum or zero times.
9209 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9210
9211 // Remember which SCEVs are used in exit limits for invalidation purposes.
9212 // We only care about non-constant SCEVs here, so we can ignore
9213 // EL.ConstantMaxNotTaken
9214 // and MaxBECount, which must be SCEVConstant.
9215 for (const auto &Pair : ExitCounts) {
9216 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9217 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9218 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9219 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9220 {L, AllowPredicates});
9221 }
9222 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9223 MaxBECount, MaxOrZero);
9224}
9225
9226ScalarEvolution::ExitLimit
9227ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9228 bool IsOnlyExit, bool AllowPredicates) {
9229 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9230 // If our exiting block does not dominate the latch, then its connection with
9231 // loop's exit limit may be far from trivial.
9232 const BasicBlock *Latch = L->getLoopLatch();
9233 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9234 return getCouldNotCompute();
9235
9236 Instruction *Term = ExitingBlock->getTerminator();
9237 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9238 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9239 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9240 "It should have one successor in loop and one exit block!");
9241 // Proceed to the next level to examine the exit condition expression.
9242 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9243 /*ControlsOnlyExit=*/IsOnlyExit,
9244 AllowPredicates);
9245 }
9246
9247 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9248 // For switch, make sure that there is a single exit from the loop.
9249 BasicBlock *Exit = nullptr;
9250 for (auto *SBB : successors(ExitingBlock))
9251 if (!L->contains(SBB)) {
9252 if (Exit) // Multiple exit successors.
9253 return getCouldNotCompute();
9254 Exit = SBB;
9255 }
9256 assert(Exit && "Exiting block must have at least one exit");
9257 return computeExitLimitFromSingleExitSwitch(
9258 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9259 }
9260
9261 return getCouldNotCompute();
9262}
9263
9265 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9266 bool AllowPredicates) {
9267 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9268 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9269 ControlsOnlyExit, AllowPredicates);
9270}
9271
9272std::optional<ScalarEvolution::ExitLimit>
9273ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9274 bool ExitIfTrue, bool ControlsOnlyExit,
9275 bool AllowPredicates) {
9276 (void)this->L;
9277 (void)this->ExitIfTrue;
9278 (void)this->AllowPredicates;
9279
9280 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9281 this->AllowPredicates == AllowPredicates &&
9282 "Variance in assumed invariant key components!");
9283 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9284 if (Itr == TripCountMap.end())
9285 return std::nullopt;
9286 return Itr->second;
9287}
9288
9289void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9290 bool ExitIfTrue,
9291 bool ControlsOnlyExit,
9292 bool AllowPredicates,
9293 const ExitLimit &EL) {
9294 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9295 this->AllowPredicates == AllowPredicates &&
9296 "Variance in assumed invariant key components!");
9297
9298 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9299 assert(InsertResult.second && "Expected successful insertion!");
9300 (void)InsertResult;
9301 (void)ExitIfTrue;
9302}
9303
9304ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9305 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9306 bool ControlsOnlyExit, bool AllowPredicates) {
9307
9308 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9309 AllowPredicates))
9310 return *MaybeEL;
9311
9312 ExitLimit EL = computeExitLimitFromCondImpl(
9313 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9314 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9315 return EL;
9316}
9317
9318ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9319 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9320 bool ControlsOnlyExit, bool AllowPredicates) {
9321 // Handle BinOp conditions (And, Or).
9322 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9323 Cache, L, ExitCond, ExitIfTrue, AllowPredicates))
9324 return *LimitFromBinOp;
9325
9326 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9327 // Proceed to the next level to examine the icmp.
9328 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9329 ExitLimit EL =
9330 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9331 if (EL.hasFullInfo() || !AllowPredicates)
9332 return EL;
9333
9334 // Try again, but use SCEV predicates this time.
9335 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9336 ControlsOnlyExit,
9337 /*AllowPredicates=*/true);
9338 }
9339
9340 // Check for a constant condition. These are normally stripped out by
9341 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9342 // preserve the CFG and is temporarily leaving constant conditions
9343 // in place.
9344 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9345 if (ExitIfTrue == !CI->getZExtValue())
9346 // The backedge is always taken.
9347 return getCouldNotCompute();
9348 // The backedge is never taken.
9349 return getZero(CI->getType());
9350 }
9351
9352 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9353 // with a constant step, we can form an equivalent icmp predicate and figure
9354 // out how many iterations will be taken before we exit.
9355 const WithOverflowInst *WO;
9356 const APInt *C;
9357 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9358 match(WO->getRHS(), m_APInt(C))) {
9359 ConstantRange NWR =
9361 WO->getNoWrapKind());
9362 CmpInst::Predicate Pred;
9363 APInt NewRHSC, Offset;
9364 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9365 if (!ExitIfTrue)
9366 Pred = ICmpInst::getInversePredicate(Pred);
9367 auto *LHS = getSCEV(WO->getLHS());
9368 if (Offset != 0)
9370 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9371 ControlsOnlyExit, AllowPredicates);
9372 if (EL.hasAnyInfo())
9373 return EL;
9374 }
9375
9376 // If it's not an integer or pointer comparison then compute it the hard way.
9377 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9378}
9379
9380std::optional<ScalarEvolution::ExitLimit>
9381ScalarEvolution::computeExitLimitFromCondFromBinOp(ExitLimitCacheTy &Cache,
9382 const Loop *L,
9383 Value *ExitCond,
9384 bool ExitIfTrue,
9385 bool AllowPredicates) {
9386 // Check if the controlling expression for this loop is an And or Or.
9387 Value *Op0, *Op1;
9388 bool IsAnd;
9389 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9390 IsAnd = true;
9391 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9392 IsAnd = false;
9393 else
9394 return std::nullopt;
9395
9396 // A sub-condition of a non-trivial binop never solely controls the exit,
9397 // whether we exit always depends on both conditions.
9398 ExitLimit EL0 = computeExitLimitFromCondCached(
9399 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9400 ExitLimit EL1 = computeExitLimitFromCondCached(
9401 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9402
9403 // EitherMayExit is true in these two cases:
9404 // br (and Op0 Op1), loop, exit
9405 // br (or Op0 Op1), exit, loop
9406 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9407
9408 const SCEV *BECount = getCouldNotCompute();
9409 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9410 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9411 if (EitherMayExit) {
9412 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9413 // Both conditions must be same for the loop to continue executing.
9414 // Choose the less conservative count.
9415 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9416 EL1.ExactNotTaken != getCouldNotCompute()) {
9417 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9418 UseSequentialUMin);
9419 }
9420 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9421 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9422 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9423 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9424 else
9425 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9426 EL1.ConstantMaxNotTaken);
9427 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9428 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9429 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9430 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9431 else
9432 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9433 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9434 } else {
9435 // Both conditions must be same at the same time for the loop to exit.
9436 // For now, be conservative.
9437 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9438 BECount = EL0.ExactNotTaken;
9439 }
9440
9441 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9442 // to be more aggressive when computing BECount than when computing
9443 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9444 // and
9445 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9446 // EL1.ConstantMaxNotTaken to not.
9447 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9448 !isa<SCEVCouldNotCompute>(BECount))
9449 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9450 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9451 SymbolicMaxBECount =
9452 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9453 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9454 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9455}
9456
9457ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9458 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9459 bool AllowPredicates) {
9460 // If the condition was exit on true, convert the condition to exit on false
9461 CmpPredicate Pred;
9462 if (!ExitIfTrue)
9463 Pred = ExitCond->getCmpPredicate();
9464 else
9465 Pred = ExitCond->getInverseCmpPredicate();
9466 const ICmpInst::Predicate OriginalPred = Pred;
9467
9468 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9469 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9470
9471 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9472 AllowPredicates);
9473 if (EL.hasAnyInfo())
9474 return EL;
9475
9476 auto *ExhaustiveCount =
9477 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9478
9479 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9480 return ExhaustiveCount;
9481
9482 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9483 ExitCond->getOperand(1), L, OriginalPred);
9484}
9485ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9486 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9487 bool ControlsOnlyExit, bool AllowPredicates) {
9488
9489 // Try to evaluate any dependencies out of the loop.
9490 LHS = getSCEVAtScope(LHS, L);
9491 RHS = getSCEVAtScope(RHS, L);
9492
9493 // At this point, we would like to compute how many iterations of the
9494 // loop the predicate will return true for these inputs.
9495 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9496 // If there is a loop-invariant, force it into the RHS.
9497 std::swap(LHS, RHS);
9499 }
9500
9501 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9503 // Simplify the operands before analyzing them.
9504 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9505
9506 // If we have a comparison of a chrec against a constant, try to use value
9507 // ranges to answer this query.
9508 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9509 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9510 if (AddRec->getLoop() == L) {
9511 // Form the constant range.
9512 ConstantRange CompRange =
9513 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9514
9515 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9516 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9517 }
9518
9519 // If this loop must exit based on this condition (or execute undefined
9520 // behaviour), see if we can improve wrap flags. This is essentially
9521 // a must execute style proof.
9522 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9523 // If we can prove the test sequence produced must repeat the same values
9524 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9525 // because if it did, we'd have an infinite (undefined) loop.
9526 // TODO: We can peel off any functions which are invertible *in L*. Loop
9527 // invariant terms are effectively constants for our purposes here.
9528 SCEVUse InnerLHS = LHS;
9529 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9530 InnerLHS = ZExt->getOperand();
9531 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9532 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9533 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9534 /*OrNegative=*/true)) {
9535 auto Flags = AR->getNoWrapFlags();
9536 Flags = setFlags(Flags, SCEV::FlagNW);
9537 SmallVector<SCEVUse> Operands{AR->operands()};
9538 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9539 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9540 }
9541
9542 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9543 // From no-self-wrap, this follows trivially from the fact that every
9544 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9545 // last value before (un)signed wrap. Since we know that last value
9546 // didn't exit, nor will any smaller one.
9547 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9548 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9549 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9550 AR && AR->getLoop() == L && AR->isAffine() &&
9551 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9552 isKnownPositive(AR->getStepRecurrence(*this))) {
9553 auto Flags = AR->getNoWrapFlags();
9554 Flags = setFlags(Flags, WrapType);
9555 SmallVector<SCEVUse> Operands{AR->operands()};
9556 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9557 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9558 }
9559 }
9560 }
9561
9562 switch (Pred) {
9563 case ICmpInst::ICMP_NE: { // while (X != Y)
9564 // Convert to: while (X-Y != 0)
9565 if (LHS->getType()->isPointerTy()) {
9568 return LHS;
9569 }
9570 if (RHS->getType()->isPointerTy()) {
9573 return RHS;
9574 }
9575 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9576 AllowPredicates);
9577 if (EL.hasAnyInfo())
9578 return EL;
9579 break;
9580 }
9581 case ICmpInst::ICMP_EQ: { // while (X == Y)
9582 // Convert to: while (X-Y == 0)
9583 if (LHS->getType()->isPointerTy()) {
9586 return LHS;
9587 }
9588 if (RHS->getType()->isPointerTy()) {
9591 return RHS;
9592 }
9593 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9594 if (EL.hasAnyInfo()) return EL;
9595 break;
9596 }
9597 case ICmpInst::ICMP_SLE:
9598 case ICmpInst::ICMP_ULE:
9599 // Since the loop is finite, an invariant RHS cannot include the boundary
9600 // value, otherwise it would loop forever.
9601 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9602 !isLoopInvariant(RHS, L)) {
9603 // Otherwise, perform the addition in a wider type, to avoid overflow.
9604 // If the LHS is an addrec with the appropriate nowrap flag, the
9605 // extension will be sunk into it and the exit count can be analyzed.
9606 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9607 if (!OldType)
9608 break;
9609 // Prefer doubling the bitwidth over adding a single bit to make it more
9610 // likely that we use a legal type.
9611 auto *NewType =
9612 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9613 if (ICmpInst::isSigned(Pred)) {
9614 LHS = getSignExtendExpr(LHS, NewType);
9615 RHS = getSignExtendExpr(RHS, NewType);
9616 } else {
9617 LHS = getZeroExtendExpr(LHS, NewType);
9618 RHS = getZeroExtendExpr(RHS, NewType);
9619 }
9620 }
9622 [[fallthrough]];
9623 case ICmpInst::ICMP_SLT:
9624 case ICmpInst::ICMP_ULT: { // while (X < Y)
9625 bool IsSigned = ICmpInst::isSigned(Pred);
9626 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9627 AllowPredicates);
9628 if (EL.hasAnyInfo())
9629 return EL;
9630 break;
9631 }
9632 case ICmpInst::ICMP_SGE:
9633 case ICmpInst::ICMP_UGE:
9634 // Since the loop is finite, an invariant RHS cannot include the boundary
9635 // value, otherwise it would loop forever.
9636 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9637 !isLoopInvariant(RHS, L))
9638 break;
9640 [[fallthrough]];
9641 case ICmpInst::ICMP_SGT:
9642 case ICmpInst::ICMP_UGT: { // while (X > Y)
9643 bool IsSigned = ICmpInst::isSigned(Pred);
9644 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9645 AllowPredicates);
9646 if (EL.hasAnyInfo())
9647 return EL;
9648 break;
9649 }
9650 default:
9651 break;
9652 }
9653
9654 return getCouldNotCompute();
9655}
9656
9657ScalarEvolution::ExitLimit
9658ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9659 SwitchInst *Switch,
9660 BasicBlock *ExitingBlock,
9661 bool ControlsOnlyExit) {
9662 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9663
9664 // Give up if the exit is the default dest of a switch.
9665 if (Switch->getDefaultDest() == ExitingBlock)
9666 return getCouldNotCompute();
9667
9668 assert(L->contains(Switch->getDefaultDest()) &&
9669 "Default case must not exit the loop!");
9670 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9671 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9672
9673 // while (X != Y) --> while (X-Y != 0)
9674 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9675 if (EL.hasAnyInfo())
9676 return EL;
9677
9678 return getCouldNotCompute();
9679}
9680
9681static ConstantInt *
9683 ScalarEvolution &SE) {
9684 const SCEV *InVal = SE.getConstant(C);
9685 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9687 "Evaluation of SCEV at constant didn't fold correctly?");
9688 return cast<SCEVConstant>(Val)->getValue();
9689}
9690
9691ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9692 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9693 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9694 if (!RHS)
9695 return getCouldNotCompute();
9696
9697 const BasicBlock *Latch = L->getLoopLatch();
9698 if (!Latch)
9699 return getCouldNotCompute();
9700
9701 const BasicBlock *Predecessor = L->getLoopPredecessor();
9702 if (!Predecessor)
9703 return getCouldNotCompute();
9704
9705 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9706 // Return LHS in OutLHS and shift_op in OutOpCode.
9707 auto MatchPositiveShift =
9708 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9709
9710 using namespace PatternMatch;
9711
9712 ConstantInt *ShiftAmt;
9713 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9714 OutOpCode = Instruction::LShr;
9715 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9716 OutOpCode = Instruction::AShr;
9717 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9718 OutOpCode = Instruction::Shl;
9719 else
9720 return false;
9721
9722 return ShiftAmt->getValue().isStrictlyPositive();
9723 };
9724
9725 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9726 //
9727 // loop:
9728 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9729 // %iv.shifted = lshr i32 %iv, <positive constant>
9730 //
9731 // Return true on a successful match. Return the corresponding PHI node (%iv
9732 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9733 auto MatchShiftRecurrence =
9734 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9735 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9736
9737 {
9739 Value *V;
9740
9741 // If we encounter a shift instruction, "peel off" the shift operation,
9742 // and remember that we did so. Later when we inspect %iv's backedge
9743 // value, we will make sure that the backedge value uses the same
9744 // operation.
9745 //
9746 // Note: the peeled shift operation does not have to be the same
9747 // instruction as the one feeding into the PHI's backedge value. We only
9748 // really care about it being the same *kind* of shift instruction --
9749 // that's all that is required for our later inferences to hold.
9750 if (MatchPositiveShift(LHS, V, OpC)) {
9751 PostShiftOpCode = OpC;
9752 LHS = V;
9753 }
9754 }
9755
9756 PNOut = dyn_cast<PHINode>(LHS);
9757 if (!PNOut || PNOut->getParent() != L->getHeader())
9758 return false;
9759
9760 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9761 Value *OpLHS;
9762
9763 return
9764 // The backedge value for the PHI node must be a shift by a positive
9765 // amount
9766 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9767
9768 // of the PHI node itself
9769 OpLHS == PNOut &&
9770
9771 // and the kind of shift should match the kind of shift we peeled
9772 // off, if any.
9773 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9774 };
9775
9776 PHINode *PN;
9778 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9779 return getCouldNotCompute();
9780
9781 const DataLayout &DL = getDataLayout();
9782
9783 // The key rationale for this optimization is that for some kinds of shift
9784 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9785 // within a finite number of iterations. If the condition guarding the
9786 // backedge (in the sense that the backedge is taken if the condition is true)
9787 // is false for the value the shift recurrence stabilizes to, then we know
9788 // that the backedge is taken only a finite number of times.
9789
9790 ConstantInt *StableValue = nullptr;
9791 switch (OpCode) {
9792 default:
9793 llvm_unreachable("Impossible case!");
9794
9795 case Instruction::AShr: {
9796 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9797 // bitwidth(K) iterations.
9798 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9799 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9800 Predecessor->getTerminator(), &DT);
9801 auto *Ty = cast<IntegerType>(RHS->getType());
9802 if (Known.isNonNegative())
9803 StableValue = ConstantInt::get(Ty, 0);
9804 else if (Known.isNegative())
9805 StableValue = ConstantInt::get(Ty, -1, true);
9806 else
9807 return getCouldNotCompute();
9808
9809 break;
9810 }
9811 case Instruction::LShr:
9812 case Instruction::Shl:
9813 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9814 // stabilize to 0 in at most bitwidth(K) iterations.
9815 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9816 break;
9817 }
9818
9819 auto *Result =
9820 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9821 assert(Result->getType()->isIntegerTy(1) &&
9822 "Otherwise cannot be an operand to a branch instruction");
9823
9824 if (Result->isNullValue()) {
9825 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9826 const SCEV *UpperBound =
9828 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9829 }
9830
9831 return getCouldNotCompute();
9832}
9833
9834/// Return true if we can constant fold an instruction of the specified type,
9835/// assuming that all operands were constants.
9836static bool CanConstantFold(const Instruction *I) {
 // NOTE(review): in this copy the `return true;` below has no guarding
 // condition (the embedded numbering jumps from 9836 to 9840, so lines appear
 // to be missing). As written every instruction is reported as foldable and
 // the call check after it is unreachable; restore the guard from upstream
 // before relying on this text.
9840 return true;
9841
 // A call is foldable only if its callee is a known Function for which
 // canConstantFoldCallTo succeeds.
9842 if (const CallInst *CI = dyn_cast<CallInst>(I))
9843 if (const Function *F = CI->getCalledFunction())
9844 return canConstantFoldCallTo(CI, F);
 // Anything reaching this point that is not a call with a known callee is
 // rejected.
9845 return false;
9846}
9847
9848/// Determine whether this instruction can constant evolve within this loop
9849/// assuming its operands can all constant evolve.
9850static bool canConstantEvolve(Instruction *I, const Loop *L) {
9851 // An instruction outside of the loop can't be derived from a loop PHI.
9852 if (!L->contains(I)) return false;
9853
9854 if (isa<PHINode>(I)) {
9855 // We don't currently keep track of the control flow needed to evaluate
9856 // PHIs, so we cannot handle PHIs inside of loops.
9857 return L->getHeader() == I->getParent();
9858 }
9859
9860 // If we won't be able to constant fold this expression even if the operands
9861 // are constants, bail early.
9862 return CanConstantFold(I);
9863}
9864
9865/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9866/// recursing through each instruction operand until reaching a loop header phi.
///
/// Returns the single PHI from which every non-constant operand of UseInst
/// (transitively) evolves. Returns nullptr if some operand cannot constant
/// evolve in L, if no PHI is reached, or if operands evolve from different
/// PHIs. Results for visited operand instructions are recorded in PHIMap.
/// Depth is the current recursion depth; the caller in this file passes 0.
9867static PHINode *
9870 unsigned Depth) {
 // NOTE(review): the rest of the signature (embedded lines 9868-9869) and
 // whatever guards the `return nullptr;` below (line 9871, presumably a
 // recursion-depth limit) are missing from this copy. As written, the
 // function returns nullptr immediately.
9872 return nullptr;
9873
9874 // Otherwise, we can evaluate this instruction if all of its operands are
9875 // constant or derived from a PHI node themselves.
9876 PHINode *PHI = nullptr;
9877 for (Value *Op : UseInst->operands()) {
9878 if (isa<Constant>(Op)) continue;
9879
 // NOTE(review): OpInst's declaration (embedded line 9880) is missing here;
 // it is presumably Op viewed as an Instruction (null for non-instructions).
9881 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9882
9883 PHINode *P = dyn_cast<PHINode>(OpInst);
9884 if (!P)
9885 // If this operand is already visited, reuse the prior result.
9886 // We may have P != PHI if this is the deepest point at which the
9887 // inconsistent paths meet.
9888 P = PHIMap.lookup(OpInst);
9889 if (!P) {
9890 // Recurse and memoize the results, whether a phi is found or not.
9891 // This recursive call invalidates pointers into PHIMap.
9892 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9893 PHIMap[OpInst] = P;
9894 }
9895 if (!P)
9896 return nullptr; // Not evolving from PHI
9897 if (PHI && PHI != P)
9898 return nullptr; // Evolving from multiple different PHIs.
9899 PHI = P;
9900 }
9901 // This is an expression evolving from a constant PHI!
9902 return PHI;
9903}
9904
9905/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9906/// in the loop that V is derived from. We allow arbitrary operations along the
9907/// way, but the operands of an operation must either be constants or a value
9908/// derived from a constant PHI. If this expression does not fit with these
9909/// constraints, return null.
///
/// "Arbitrary" is bounded by canConstantEvolve: every instruction on the way
/// must be inside L, any PHI reached must be in L's header, and every other
/// instruction must pass CanConstantFold.
 // NOTE(review): the signature and the declaration of I (embedded lines
 // 9910-9911) are missing from this copy.
9912 if (!I || !canConstantEvolve(I, L)) return nullptr;
9913
 // V is itself a header PHI, so it trivially evolves from itself.
9914 if (PHINode *PN = dyn_cast<PHINode>(I))
9915 return PN;
9916
9917 // Record non-constant instructions contained by the loop.
 // NOTE(review): the PHIMap declaration (embedded line 9918) is missing here;
 // it is the memo table handed to getConstantEvolvingPHIOperands.
9919 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9920}
9921
9922/// EvaluateExpression - Given an expression that passes the
9923/// getConstantEvolvingPHI predicate, evaluate its value assuming the values
9924/// already recorded in Vals (the current loop-header PHI values plus anything
9925/// memoized earlier). If we can't fold this expression, return null.
///
/// Vals doubles as a memo table: the result for every instruction operand
/// visited, including nullptr on failure, is stored back into it.
 // NOTE(review): the leading signature lines (embedded 9926-9927) and the
 // declaration of I (line 9932, presumably V viewed as an Instruction) are
 // missing from this copy.
9928 const DataLayout &DL,
9929 const TargetLibraryInfo *TLI) {
9930 // Convenient constant check, but redundant for recursive calls.
9931 if (Constant *C = dyn_cast<Constant>(V)) return C;
9933 if (!I) return nullptr;
9934
 // Already known (seeded PHI value or memoized result)?
9935 if (Constant *C = Vals.lookup(I)) return C;
9936
9937 // An instruction inside the loop depends on a value outside the loop that we
9938 // weren't given a mapping for, or a value such as a call inside the loop.
9939 if (!canConstantEvolve(I, L)) return nullptr;
9940
9941 // An unmapped PHI can be due to a branch or another loop inside this loop,
9942 // or due to this not being the initial iteration through a loop where we
9943 // couldn't compute the evolution of this particular PHI last time.
9944 if (isa<PHINode>(I)) return nullptr;
9945
9946 std::vector<Constant*> Operands(I->getNumOperands());
9947
 // Resolve every operand to a Constant: non-instruction operands must already
 // be constants; instruction operands are evaluated recursively.
9948 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9949 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9950 if (!Operand) {
9951 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9952 if (!Operands[i]) return nullptr;
9953 continue;
9954 }
9955 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
 // Memoize before checking, so failures are recorded too.
9956 Vals[Operand] = C;
9957 if (!C) return nullptr;
9958 Operands[i] = C;
9959 }
9960
 // All operands are constants now; fold the instruction itself.
9961 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9962 /*AllowNonDeterministic=*/false);
9963}
9964
9965
9966// If every incoming value to PN except the one for BB is a specific Constant,
9967// return that, else return nullptr.
// All non-BB incoming values must be the very same Constant. A non-constant
// value or two distinct constants yields nullptr, as does a PHI whose only
// incoming edges come from BB.
 // NOTE(review): the signature line (embedded 9968) is missing from this copy;
 // PN and BB are presumably its parameters (callers pass a header PHI and the
 // loop latch).
9969 Constant *IncomingVal = nullptr;
9970
9971 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9972 if (PN->getIncomingBlock(i) == BB)
9973 continue;
9974
9975 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9976 if (!CurrentVal)
9977 return nullptr;
9978
 // Remember the first constant seen; any later, different one fails.
9979 if (IncomingVal != CurrentVal) {
9980 if (IncomingVal)
9981 return nullptr;
9982 IncomingVal = CurrentVal;
9983 }
9984 }
9985
9986 return IncomingVal;
9987}
9988
9989/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9990/// in the header of its containing loop, we know the loop executes a
9991/// constant number of times, and the PHI node is just a recurrence
9992/// involving constants, fold it.
9993Constant *
9994ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9995 const APInt &BEs,
9996 const Loop *L) {
9997 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9998 if (!Inserted)
9999 return I->second;
10000
10002 return nullptr; // Not going to evaluate it.
10003
10004 Constant *&RetVal = I->second;
10005
10006 DenseMap<Instruction *, Constant *> CurrentIterVals;
10007 BasicBlock *Header = L->getHeader();
10008 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10009
10010 BasicBlock *Latch = L->getLoopLatch();
10011 if (!Latch)
10012 return nullptr;
10013
10014 for (PHINode &PHI : Header->phis()) {
10015 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10016 CurrentIterVals[&PHI] = StartCST;
10017 }
10018 if (!CurrentIterVals.count(PN))
10019 return RetVal = nullptr;
10020
10021 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10022
10023 // Execute the loop symbolically to determine the exit value.
10024 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10025 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10026
10027 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10028 unsigned IterationNum = 0;
10029 const DataLayout &DL = getDataLayout();
10030 for (; ; ++IterationNum) {
10031 if (IterationNum == NumIterations)
10032 return RetVal = CurrentIterVals[PN]; // Got exit value!
10033
10034 // Compute the value of the PHIs for the next iteration.
10035 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10036 DenseMap<Instruction *, Constant *> NextIterVals;
10037 Constant *NextPHI =
10038 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10039 if (!NextPHI)
10040 return nullptr; // Couldn't evaluate!
10041 NextIterVals[PN] = NextPHI;
10042
10043 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10044
10045 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10046 // cease to be able to evaluate one of them or if they stop evolving,
10047 // because that doesn't necessarily prevent us from computing PN.
10049 for (const auto &I : CurrentIterVals) {
10050 PHINode *PHI = dyn_cast<PHINode>(I.first);
10051 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10052 PHIsToCompute.emplace_back(PHI, I.second);
10053 }
10054 // We use two distinct loops because EvaluateExpression may invalidate any
10055 // iterators into CurrentIterVals.
10056 for (const auto &I : PHIsToCompute) {
10057 PHINode *PHI = I.first;
10058 Constant *&NextPHI = NextIterVals[PHI];
10059 if (!NextPHI) { // Not already computed.
10060 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10061 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10062 }
10063 if (NextPHI != I.second)
10064 StoppedEvolving = false;
10065 }
10066
10067 // If all entries in CurrentIterVals == NextIterVals then we can stop
10068 // iterating, the loop can't continue to change.
10069 if (StoppedEvolving)
10070 return RetVal = CurrentIterVals[PN];
10071
10072 CurrentIterVals.swap(NextIterVals);
10073 }
10074}
10075
10076const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10077 Value *Cond,
10078 bool ExitWhen) {
10079 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10080 if (!PN) return getCouldNotCompute();
10081
10082 // If the loop is canonicalized, the PHI will have exactly two entries.
10083 // That's the only form we support here.
10084 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10085
10086 DenseMap<Instruction *, Constant *> CurrentIterVals;
10087 BasicBlock *Header = L->getHeader();
10088 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10089
10090 BasicBlock *Latch = L->getLoopLatch();
10091 assert(Latch && "Should follow from NumIncomingValues == 2!");
10092
10093 for (PHINode &PHI : Header->phis()) {
10094 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10095 CurrentIterVals[&PHI] = StartCST;
10096 }
10097 if (!CurrentIterVals.count(PN))
10098 return getCouldNotCompute();
10099
10100 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10101 // the loop symbolically to determine when the condition gets a value of
10102 // "ExitWhen".
10103 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10104 const DataLayout &DL = getDataLayout();
10105 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10106 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10107 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10108
10109 // Couldn't symbolically evaluate.
10110 if (!CondVal) return getCouldNotCompute();
10111
10112 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10113 ++NumBruteForceTripCountsComputed;
10114 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10115 }
10116
10117 // Update all the PHI nodes for the next iteration.
10118 DenseMap<Instruction *, Constant *> NextIterVals;
10119
10120 // Create a list of which PHIs we need to compute. We want to do this before
10121 // calling EvaluateExpression on them because that may invalidate iterators
10122 // into CurrentIterVals.
10123 SmallVector<PHINode *, 8> PHIsToCompute;
10124 for (const auto &I : CurrentIterVals) {
10125 PHINode *PHI = dyn_cast<PHINode>(I.first);
10126 if (!PHI || PHI->getParent() != Header) continue;
10127 PHIsToCompute.push_back(PHI);
10128 }
10129 for (PHINode *PHI : PHIsToCompute) {
10130 Constant *&NextPHI = NextIterVals[PHI];
10131 if (NextPHI) continue; // Already computed!
10132
10133 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10134 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10135 }
10136 CurrentIterVals.swap(NextIterVals);
10137 }
10138
10139 // Too many iterations were needed to evaluate.
10140 return getCouldNotCompute();
10141}
10142
10143const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10145 ValuesAtScopes[V];
10146 // Check to see if we've folded this expression at this loop before.
10147 for (auto &LS : Values)
10148 if (LS.first == L)
10149 return LS.second ? LS.second : V;
10150
10151 Values.emplace_back(L, nullptr);
10152
10153 // Otherwise compute it.
10154 const SCEV *C = computeSCEVAtScope(V, L);
10155 for (auto &LS : reverse(ValuesAtScopes[V]))
10156 if (LS.first == L) {
10157 LS.second = C;
10158 if (!isa<SCEVConstant>(C))
10159 ValuesAtScopesUsers[C].push_back({L, V});
10160 break;
10161 }
10162 return C;
10163}
10164
10165/// This builds up a Constant using the ConstantExpr interface. That way, we
10166/// will return Constants for objects which aren't represented by a
10167/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10168/// Returns nullptr if the SCEV isn't representable as a Constant.
 // NOTE(review): this copy is missing several source lines: the signature
 // (embedded 10169), the scUnknown return (10178; as shown, scUnknown falls
 // through into the scPtrToAddr case), the P2I/ST/OpC declarations (10180,
 // 10187, 10194, 10203) and one case label (10230). Restore from upstream.
10170 switch (V->getSCEVType()) {
 // Recurrences, vscale and CouldNotCompute have no constant form.
10171 case scCouldNotCompute:
10172 case scAddRecExpr:
10173 case scVScale:
10174 return nullptr;
10175 case scConstant:
10176 return cast<SCEVConstant>(V)->getValue();
10177 case scUnknown:
 // Casts: build the operand recursively, then apply the matching
 // ConstantExpr cast to the expression's type.
10179 case scPtrToAddr: {
10181 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10182 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10183
10184 return nullptr;
10185 }
10186 case scPtrToInt: {
10188 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10189 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10190
10191 return nullptr;
10192 }
10193 case scTruncate: {
10195 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10196 return ConstantExpr::getTrunc(CastOp, ST->getType());
10197 return nullptr;
10198 }
10199 case scAddExpr: {
 // Fold the operands left to right. Integer operands are summed with
 // getAdd. At most one operand may be a pointer and it must come last
 // (asserted below); the accumulated integer sum is then applied to it as a
 // byte offset via getPtrAdd.
10200 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10201 Constant *C = nullptr;
10202 for (const SCEV *Op : SA->operands()) {
10204 if (!OpC)
10205 return nullptr;
10206 if (!C) {
10207 C = OpC;
10208 continue;
10209 }
10210 assert(!C->getType()->isPointerTy() &&
10211 "Can only have one pointer, and it must be last");
10212 if (OpC->getType()->isPointerTy()) {
10213 // The offsets have been converted to bytes. We can add bytes using
10214 // an i8 GEP.
10215 C = ConstantExpr::getPtrAdd(OpC, C);
10216 } else {
10217 C = ConstantExpr::getAdd(C, OpC);
10218 }
10219 }
10220 return C;
10221 }
 // These kinds are not rebuilt as constant expressions here.
10222 case scMulExpr:
10223 case scSignExtend:
10224 case scZeroExtend:
10225 case scUDivExpr:
10226 case scSMaxExpr:
10227 case scUMaxExpr:
10228 case scSMinExpr:
10229 case scUMinExpr:
10231 return nullptr;
10232 }
10233 llvm_unreachable("Unknown SCEV kind!");
10234}
10235
10236const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10237 SmallVectorImpl<SCEVUse> &NewOps) {
10238 switch (S->getSCEVType()) {
10239 case scTruncate:
10240 case scZeroExtend:
10241 case scSignExtend:
10242 case scPtrToAddr:
10243 case scPtrToInt:
10244 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10245 case scAddRecExpr: {
10246 auto *AddRec = cast<SCEVAddRecExpr>(S);
10247 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10248 }
10249 case scAddExpr:
10250 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10251 case scMulExpr:
10252 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10253 case scUDivExpr:
10254 return getUDivExpr(NewOps[0], NewOps[1]);
10255 case scUMaxExpr:
10256 case scSMaxExpr:
10257 case scUMinExpr:
10258 case scSMinExpr:
10259 return getMinMaxExpr(S->getSCEVType(), NewOps);
10261 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10262 case scConstant:
10263 case scVScale:
10264 case scUnknown:
10265 return S;
10266 case scCouldNotCompute:
10267 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10268 }
10269 llvm_unreachable("Unknown SCEV kind!");
10270}
10271
10272const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10273 switch (V->getSCEVType()) {
10274 case scConstant:
10275 case scVScale:
10276 return V;
10277 case scAddRecExpr: {
10278 // If this is a loop recurrence for a loop that does not contain L, then we
10279 // are dealing with the final value computed by the loop.
10280 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10281 // First, attempt to evaluate each operand.
10282 // Avoid performing the look-up in the common case where the specified
10283 // expression has no loop-variant portions.
10284 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10285 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10286 if (OpAtScope == AddRec->getOperand(i))
10287 continue;
10288
10289 // Okay, at least one of these operands is loop variant but might be
10290 // foldable. Build a new instance of the folded commutative expression.
10292 NewOps.reserve(AddRec->getNumOperands());
10293 append_range(NewOps, AddRec->operands().take_front(i));
10294 NewOps.push_back(OpAtScope);
10295 for (++i; i != e; ++i)
10296 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10297
10298 const SCEV *FoldedRec = getAddRecExpr(
10299 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10300 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10301 // The addrec may be folded to a nonrecurrence, for example, if the
10302 // induction variable is multiplied by zero after constant folding. Go
10303 // ahead and return the folded value.
10304 if (!AddRec)
10305 return FoldedRec;
10306 break;
10307 }
10308
10309 // If the scope is outside the addrec's loop, evaluate it by using the
10310 // loop exit value of the addrec.
10311 if (!AddRec->getLoop()->contains(L)) {
10312 // To evaluate this recurrence, we need to know how many times the AddRec
10313 // loop iterates. Compute this now.
10314 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10315 if (BackedgeTakenCount == getCouldNotCompute())
10316 return AddRec;
10317
10318 // Then, evaluate the AddRec.
10319 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10320 }
10321
10322 return AddRec;
10323 }
10324 case scTruncate:
10325 case scZeroExtend:
10326 case scSignExtend:
10327 case scPtrToAddr:
10328 case scPtrToInt:
10329 case scAddExpr:
10330 case scMulExpr:
10331 case scUDivExpr:
10332 case scUMaxExpr:
10333 case scSMaxExpr:
10334 case scUMinExpr:
10335 case scSMinExpr:
10336 case scSequentialUMinExpr: {
10337 ArrayRef<SCEVUse> Ops = V->operands();
10338 // Avoid performing the look-up in the common case where the specified
10339 // expression has no loop-variant portions.
10340 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10341 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10342 if (OpAtScope != Ops[i].getPointer()) {
10343 // Okay, at least one of these operands is loop variant but might be
10344 // foldable. Build a new instance of the folded commutative expression.
10346 NewOps.reserve(Ops.size());
10347 append_range(NewOps, Ops.take_front(i));
10348 NewOps.push_back(OpAtScope);
10349
10350 for (++i; i != e; ++i) {
10351 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10352 NewOps.push_back(OpAtScope);
10353 }
10354
10355 return getWithOperands(V, NewOps);
10356 }
10357 }
10358 // If we got here, all operands are loop invariant.
10359 return V;
10360 }
10361 case scUnknown: {
10362 // If this instruction is evolved from a constant-evolving PHI, compute the
10363 // exit value from the loop without using SCEVs.
10364 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10366 if (!I)
10367 return V; // This is some other type of SCEVUnknown, just return it.
10368
10369 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10370 const Loop *CurrLoop = this->LI[I->getParent()];
10371 // Looking for loop exit value.
10372 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10373 PN->getParent() == CurrLoop->getHeader()) {
10374 // Okay, there is no closed form solution for the PHI node. Check
10375 // to see if the loop that contains it has a known backedge-taken
10376 // count. If so, we may be able to force computation of the exit
10377 // value.
10378 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10379 // This trivial case can show up in some degenerate cases where
10380 // the incoming IR has not yet been fully simplified.
10381 if (BackedgeTakenCount->isZero()) {
10382 Value *InitValue = nullptr;
10383 bool MultipleInitValues = false;
10384 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10385 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10386 if (!InitValue)
10387 InitValue = PN->getIncomingValue(i);
10388 else if (InitValue != PN->getIncomingValue(i)) {
10389 MultipleInitValues = true;
10390 break;
10391 }
10392 }
10393 }
10394 if (!MultipleInitValues && InitValue)
10395 return getSCEV(InitValue);
10396 }
10397 // Do we have a loop invariant value flowing around the backedge
10398 // for a loop which must execute the backedge?
10399 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10400 isKnownNonZero(BackedgeTakenCount) &&
10401 PN->getNumIncomingValues() == 2) {
10402
10403 unsigned InLoopPred =
10404 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10405 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10406 if (CurrLoop->isLoopInvariant(BackedgeVal))
10407 return getSCEV(BackedgeVal);
10408 }
10409 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10410 // Okay, we know how many times the containing loop executes. If
10411 // this is a constant evolving PHI node, get the final value at
10412 // the specified iteration number.
10413 Constant *RV =
10414 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10415 if (RV)
10416 return getSCEV(RV);
10417 }
10418 }
10419 }
10420
10421 // Okay, this is an expression that we cannot symbolically evaluate
10422 // into a SCEV. Check to see if it's possible to symbolically evaluate
10423 // the arguments into constants, and if so, try to constant propagate the
10424 // result. This is particularly useful for computing loop exit values.
10425 if (!CanConstantFold(I))
10426 return V; // This is some other type of SCEVUnknown, just return it.
10427
10428 SmallVector<Constant *, 4> Operands;
10429 Operands.reserve(I->getNumOperands());
10430 bool MadeImprovement = false;
10431 for (Value *Op : I->operands()) {
10432 if (Constant *C = dyn_cast<Constant>(Op)) {
10433 Operands.push_back(C);
10434 continue;
10435 }
10436
10437 // If any of the operands is non-constant and if they are
10438 // non-integer and non-pointer, don't even try to analyze them
10439 // with scev techniques.
10440 if (!isSCEVable(Op->getType()))
10441 return V;
10442
10443 const SCEV *OrigV = getSCEV(Op);
10444 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10445 MadeImprovement |= OrigV != OpV;
10446
10448 if (!C)
10449 return V;
10450 assert(C->getType() == Op->getType() && "Type mismatch");
10451 Operands.push_back(C);
10452 }
10453
10454 // Check to see if getSCEVAtScope actually made an improvement.
10455 if (!MadeImprovement)
10456 return V; // This is some other type of SCEVUnknown, just return it.
10457
10458 Constant *C = nullptr;
10459 const DataLayout &DL = getDataLayout();
10460 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10461 /*AllowNonDeterministic=*/false);
10462 if (!C)
10463 return V;
10464 return getSCEV(C);
10465 }
10466 case scCouldNotCompute:
10467 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10468 }
10469 llvm_unreachable("Unknown SCEV type!");
10470}
10471
 // Convenience overload: compute V's SCEV, then evaluate it at scope L.
 // NOTE(review): the signature line (embedded 10472) is missing from this copy.
10473 return getSCEVAtScope(getSCEV(V), L);
10474}
10475
/// Judging by the ZExt/SExt operands below, this presumably peels
/// zero-/sign-extensions off S recursively and returns the innermost operand;
/// any other expression is returned unchanged. TODO: confirm against upstream.
10476const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
 // NOTE(review): the conditions that bind ZExt and SExt (embedded lines 10477
 // and 10479) are missing from this copy. As written the first return is
 // unconditional and the remaining statements are unreachable.
10478 return stripInjectiveFunctions(ZExt->getOperand());
10480 return stripInjectiveFunctions(SExt->getOperand());
10481 return S;
10482}
10483
10484/// Finds the minimum unsigned root of the following equation:
10485///
10486/// A * X = B (mod N)
10487///
10488/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10489/// A and B isn't important.
10490///
10491/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
///
/// A must be non-zero. If B cannot be proven divisible by 2^countr_zero(A),
/// the solution is made conditional instead: the predicate
/// (B urem 2^countr_zero(A)) == 0 is appended to Predicates, provided that
/// Predicates is non-null and the predicate is not known to be false.
/// Otherwise SCEVCouldNotCompute is returned. L's loop predecessor, if any,
/// supplies context for a second trailing-zeros query.
 // NOTE(review): the middle of the signature (embedded lines 10493-10494) is
 // missing from this copy; A, B and Predicates are among its parameters.
10492static const SCEV *
10495 ScalarEvolution &SE, const Loop *L) {
10496 uint32_t BW = A.getBitWidth();
10497 assert(BW == SE.getTypeSizeInBits(B->getType()));
10498 assert(A != 0 && "A must be non-zero.");
10499
10500 // 1. D = gcd(A, N)
10501 //
10502 // The gcd of A and N may have only one prime factor: 2. The number of
10503 // trailing zeros in A is its multiplicity
10504 uint32_t Mult2 = A.countr_zero();
10505 // D = 2^Mult2
10506
10507 // 2. Check if B is divisible by D.
10508 //
10509 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10510 // is not less than multiplicity of this prime factor for D.
10511 unsigned MinTZ = SE.getMinTrailingZeros(B);
10512 // Try again with the terminator of the loop predecessor for context-specific
10513 // result, if MinTZ is too small.
10514 if (MinTZ < Mult2 && L->getLoopPredecessor())
10515 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10516 if (MinTZ < Mult2) {
10517 // Check if we can prove there's no remainder using URem.
10518 const SCEV *URem =
10519 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10520 const SCEV *Zero = SE.getZero(B->getType());
10521 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10522 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10523 if (!Predicates)
10524 return SE.getCouldNotCompute();
10525
10526 // Avoid adding a predicate that is known to be false.
10527 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10528 return SE.getCouldNotCompute();
10529 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10530 }
10531 }
10532
10533 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10534 // modulo (N / D).
10535 //
10536 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10537 // (N / D) in general. The inverse itself always fits into BW bits, though,
10538 // so we immediately truncate it.
 // Concretely: A / D is odd and is computed in BW - Mult2 bits, so its
 // inverse is taken modulo 2^(BW - Mult2) == N / D and then zero-extended
 // back to BW bits.
10539 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10540 APInt I = AD.multiplicativeInverse().zext(BW);
10541
10542 // 4. Compute the minimum unsigned root of the equation:
10543 // I * (B / D) mod (N / D)
10544 // To simplify the computation, we factor out the divide by D:
10545 // (I * B mod N) / D
10546 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10547 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10548}
10549
10550/// For a given quadratic addrec, generate coefficients of the corresponding
10551/// quadratic equation, multiplied by a common value to ensure that they are
10552/// integers.
10553/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10554/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10555/// were multiplied by, and BitWidth is the bit width of the original addrec
10556/// coefficients.
10557/// This function returns std::nullopt if the addrec coefficients are not
10558/// compile-time constants.
///
/// A, B, C and M are BitWidth + 1 bits wide (the coefficients are
/// sign-extended first), and M is always 2.
 // NOTE(review): the signature line (embedded 10560) is missing from this copy;
 // AddRec is its parameter.
10559static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10561 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10562 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10563 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10564 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10565 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10566 << *AddRec << '\n');
10567
10568 // We currently can only solve this if the coefficients are constants.
10569 if (!LC || !MC || !NC) {
10570 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10571 return std::nullopt;
10572 }
10573
10574 APInt L = LC->getAPInt();
10575 APInt M = MC->getAPInt();
10576 APInt N = NC->getAPInt();
10577 assert(!N.isZero() && "This is not a quadratic addrec");
10578
 // One extra bit keeps 2*M, 2*L and 2*M - N from overflowing.
10579 unsigned BitWidth = LC->getAPInt().getBitWidth();
10580 unsigned NewWidth = BitWidth + 1;
10581 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10582 << BitWidth << '\n');
10583 // The sign-extension (as opposed to a zero-extension) here matches the
10584 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10585 N = N.sext(NewWidth);
10586 M = M.sext(NewWidth);
10587 L = L.sext(NewWidth);
10588
10589 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10590 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10591 // L+M, L+2M+N, L+3M+3N, ...
10592 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10593 //
10594 // The equation Acc = 0 is then
10595 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10596 // In a quadratic form it becomes:
10597 // N n^2 + (2M-N) n + 2L = 0.
10598
10599 APInt A = N;
10600 APInt B = 2 * M - A;
10601 APInt C = 2 * L;
10602 APInt T = APInt(NewWidth, 2);
10603 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10604 << "x + " << C << ", coeff bw: " << NewWidth
10605 << ", multiplied by " << T << '\n');
10606 return std::make_tuple(A, B, C, T, BitWidth);
10607}
10608
10609/// Helper function to compare optional APInts:
10610/// (a) if X and Y both exist, return min(X, Y),
10611/// (b) if neither X nor Y exist, return std::nullopt,
10612/// (c) if exactly one of X and Y exists, return that value.
10613static std::optional<APInt> MinOptional(std::optional<APInt> X,
10614 std::optional<APInt> Y) {
10615 if (X && Y) {
10616 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10617 APInt XW = X->sext(W);
10618 APInt YW = Y->sext(W);
10619 return XW.slt(YW) ? *X : *Y;
10620 }
10621 if (!X && !Y)
10622 return std::nullopt;
10623 return X ? *X : *Y;
10624}
10625
10626/// Helper function to truncate an optional APInt to a given BitWidth.
10627/// When solving addrec-related equations, it is preferable to return a value
10628/// that has the same bit width as the original addrec's coefficients. If the
10629/// solution fits in the original bit width, truncate it (except for i1).
10630/// Returning a value of a different bit width may inhibit some optimizations.
10631///
10632/// In general, a solution to a quadratic equation generated from an addrec
10633/// may require BW+1 bits, where BW is the bit width of the addrec's
10634/// coefficients. The reason is that the coefficients of the quadratic
10635/// equation are BW+1 bits wide (to avoid truncation when converting from
10636/// the addrec to the equation).
10637static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10638 unsigned BitWidth) {
10639 if (!X)
10640 return std::nullopt;
10641 unsigned W = X->getBitWidth();
10643 return X->trunc(BitWidth);
10644 return X;
10645}
10646
10647/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10648/// iterations. The values L, M, N are assumed to be signed, and they
10649/// should all have the same bit widths.
10650/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10651/// where BW is the bit width of the addrec's coefficients.
10652/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10653/// returned as such, otherwise the bit width of the returned value may
10654/// be greater than BW.
10655///
10656/// This function returns std::nullopt if
10657/// (a) the addrec coefficients are not constant, or
10658/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10659/// like x^2 = 5, no integer solutions exist, in other cases an integer
10660/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
/// (c) the candidate root, evaluated back through the addrec, is not zero.
10661static std::optional<APInt>
 // NOTE(review): the parameter line (embedded 10662) is missing from this copy;
 // AddRec and SE are used as its parameters.
10663 APInt A, B, C, M;
10664 unsigned BitWidth;
10665 auto T = GetQuadraticEquation(AddRec);
10666 if (!T)
10667 return std::nullopt;
10668
10669 std::tie(A, B, C, M, BitWidth) = *T;
10670 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
 // NOTE(review): the initializer of X (embedded line 10672, presumably the
 // SolveQuadraticEquationWrap call) is missing from this copy.
10671 std::optional<APInt> X =
10673 if (!X)
10674 return std::nullopt;
10675
 // Double-check the candidate: c(X) must actually be zero.
10676 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10677 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10678 if (!V->isZero())
10679 return std::nullopt;
10680
10681 return TruncIfPossible(X, BitWidth);
10682}
10683
10684/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10685/// iterations. The values M, N are assumed to be signed, and they
10686/// should all have the same bit widths.
10687/// Find the least n such that c(n) does not belong to the given range,
10688/// while c(n-1) does.
10689///
10690/// This function returns std::nullopt if
10691/// (a) the addrec coefficients are not constant, or
10692/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10693/// bounds of the range.
10694static std::optional<APInt>
10696 const ConstantRange &Range, ScalarEvolution &SE) {
10697 assert(AddRec->getOperand(0)->isZero() &&
10698 "Starting value of addrec should be 0");
10699 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10700 << Range << ", addrec " << *AddRec << '\n');
10701 // This case is handled in getNumIterationsInRange. Here we can assume that
10702 // we start in the range.
10703 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10704 "Addrec's initial value should be in range");
10705
10706 APInt A, B, C, M;
10707 unsigned BitWidth;
10708 auto T = GetQuadraticEquation(AddRec);
10709 if (!T)
10710 return std::nullopt;
10711
10712 // Be careful about the return value: there can be two reasons for not
10713 // returning an actual number. First, if no solutions to the equations
10714 // were found, and second, if the solutions don't leave the given range.
10715 // The first case means that the actual solution is "unknown", the second
10716 // means that it's known, but not valid. If the solution is unknown, we
10717 // cannot make any conclusions.
10718 // Return a pair: the optional solution and a flag indicating if the
10719 // solution was found.
10720 auto SolveForBoundary =
10721 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10722 // Solve for signed overflow and unsigned overflow, pick the lower
10723 // solution.
10724 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10725 << Bound << " (before multiplying by " << M << ")\n");
10726 Bound *= M; // The quadratic equation multiplier.
10727
10728 std::optional<APInt> SO;
10729 if (BitWidth > 1) {
10730 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10731 "signed overflow\n");
10733 }
10734 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10735 "unsigned overflow\n");
10736 std::optional<APInt> UO =
10738
10739 auto LeavesRange = [&] (const APInt &X) {
10740 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10741 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10742 if (Range.contains(V0->getValue()))
10743 return false;
10744 // X should be at least 1, so X-1 is non-negative.
10745 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10746 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10747 if (Range.contains(V1->getValue()))
10748 return true;
10749 return false;
10750 };
10751
10752 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10753 // can be a solution, but the function failed to find it. We cannot treat it
10754 // as "no solution".
10755 if (!SO || !UO)
10756 return {std::nullopt, false};
10757
10758 // Check the smaller value first to see if it leaves the range.
10759 // At this point, both SO and UO must have values.
10760 std::optional<APInt> Min = MinOptional(SO, UO);
10761 if (LeavesRange(*Min))
10762 return { Min, true };
10763 std::optional<APInt> Max = Min == SO ? UO : SO;
10764 if (LeavesRange(*Max))
10765 return { Max, true };
10766
10767 // Solutions were found, but were eliminated, hence the "true".
10768 return {std::nullopt, true};
10769 };
10770
10771 std::tie(A, B, C, M, BitWidth) = *T;
10772 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10773 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10774 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10775 auto SL = SolveForBoundary(Lower);
10776 auto SU = SolveForBoundary(Upper);
10777 // If any of the solutions was unknown, no meaningful conclusions can
10778 // be made.
10779 if (!SL.second || !SU.second)
10780 return std::nullopt;
10781
10782 // Claim: The correct solution is not some value between Min and Max.
10783 //
10784 // Justification: Assuming that Min and Max are different values, one of
10785 // them is when the first signed overflow happens, the other is when the
10786 // first unsigned overflow happens. Crossing the range boundary is only
10787 // possible via an overflow (treating 0 as a special case of it, modeling
10788 // an overflow as crossing k*2^W for some k).
10789 //
10790 // The interesting case here is when Min was eliminated as an invalid
10791 // solution, but Max was not. The argument is that if there was another
10792 // overflow between Min and Max, it would also have been eliminated if
10793 // it was considered.
10794 //
10795 // For a given boundary, it is possible to have two overflows of the same
10796 // type (signed/unsigned) without having the other type in between: this
10797 // can happen when the vertex of the parabola is between the iterations
10798 // corresponding to the overflows. This is only possible when the two
10799 // overflows cross k*2^W for the same k. In such a case, if the second one
10800 // left the range (and was the first one to do so), the first overflow
10801 // would have to enter the range, which would mean that either we had left
10802 // the range before or that we started outside of it. Both of these cases
10803 // are contradictions.
10804 //
10805 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10806 // solution is not some value between the Max for this boundary and the
10807 // Min of the other boundary.
10808 //
10809 // Justification: Assume that we had such Max_A and Min_B corresponding
10810 // to range boundaries A and B and such that Max_A < Min_B. If there was
10811 // a solution between Max_A and Min_B, it would have to be caused by an
10812 // overflow corresponding to either A or B. It cannot correspond to B,
10813 // since Min_B is the first occurrence of such an overflow. If it
10814 // corresponded to A, it would have to be either a signed or an unsigned
10815 // overflow that is larger than both eliminated overflows for A. But
10816 // between the eliminated overflows and this overflow, the values would
10817 // cover the entire value space, thus crossing the other boundary, which
10818 // is a contradiction.
10819
10820 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10821}
10822
// Compute the exit limit of an exit test in loop L that is effectively
// "V != 0" (see the comment at the top of the body). When AllowPredicates is
// set, runtime predicates may be collected and attached to the result.
10823ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10824 const Loop *L,
10825 bool ControlsOnlyExit,
10826 bool AllowPredicates) {
10827
10828 // This is only used for loops with a "x != y" exit test. The exit condition
10829 // is now expressed as a single expression, V = x-y. So the exit test is
10830 // effectively V != 0. We know, and take advantage of the fact, that this
10831 // expression is only used in a comparison-with-zero context.
10832
10834 // If the value is a constant, no iteration analysis is needed.
10835 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10836 // If the value is already zero, the branch will execute zero times.
10837 if (C->getValue()->isZero()) return C;
10838 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10839 }
10840
10841 const SCEVAddRecExpr *AddRec =
10842 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10843
10844 if (!AddRec && AllowPredicates)
10845 // Try to make this an AddRec using runtime tests, in the first X
10846 // iterations of this loop, where X is the SCEV expression found by the
10847 // algorithm below.
10848 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10849
10850 if (!AddRec || AddRec->getLoop() != L)
10851 return getCouldNotCompute();
10852
10853 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10854 // the quadratic equation to solve it.
10855 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10856 // We can only use this value if the chrec ends up with an exact zero
10857 // value at this index. When solving for "X*X != 5", for example, we
10858 // should not accept a root of 2.
10859 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10860 const auto *R = cast<SCEVConstant>(getConstant(*S));
10861 return ExitLimit(R, R, R, false, Predicates);
10862 }
10863 return getCouldNotCompute();
10864 }
10865
10866 // Otherwise we can only handle this if it is affine.
10867 if (!AddRec->isAffine())
10868 return getCouldNotCompute();
10869
10870 // If this is an affine expression, the execution count of this branch is
10871 // the minimum unsigned root of the following equation:
10872 //
10873 // Start + Step*N = 0 (mod 2^BW)
10874 //
10875 // equivalent to:
10876 //
10877 // Step*N = -Start (mod 2^BW)
10878 //
10879 // where BW is the common bit width of Start and Step.
10880
10881 // Get the initial value for the loop.
10882 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10883 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10884
10885 if (!isLoopInvariant(Step, L))
10886 return getCouldNotCompute();
10887
10888 LoopGuards Guards = LoopGuards::collect(L, *this);
10889 // Specialize step for this loop so we get context-sensitive facts below.
10890 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10891
10892 // For positive steps (counting up until unsigned overflow):
10893 // N = -Start/Step (as unsigned)
10894 // For negative steps (counting down to zero):
10895 // N = Start/-Step
10896 // First compute the unsigned distance from zero in the direction of Step.
10897 bool CountDown = isKnownNegative(StepWLG);
10898 if (!CountDown && !isKnownNonNegative(StepWLG))
10899 return getCouldNotCompute();
10900
10901 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10902 // Handle unitary steps, which cannot wrap around.
10903 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10904 // N = Distance (as unsigned)
10905
10906 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10907 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10908 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10909
10910 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10911 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10912 // case, and see if we can improve the bound.
10913 //
10914 // Explicitly handling this here is necessary because getUnsignedRange
10915 // isn't context-sensitive; it doesn't know that we only care about the
10916 // range inside the loop.
10917 const SCEV *Zero = getZero(Distance->getType());
10918 const SCEV *One = getOne(Distance->getType());
10919 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10920 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10921 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10922 // as "unsigned_max(Distance + 1) - 1".
10923 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10924 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10925 }
10926 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10927 Predicates);
10928 }
10929
10930 // If the condition controls loop exit (the loop exits only if the expression
10931 // is true) and the addition is no-wrap we can use unsigned divide to
10932 // compute the backedge count. In this case, the step may not divide the
10933 // distance, but we don't care because if the condition is "missed" the loop
10934 // will have undefined behavior due to wrapping.
10935 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10936 loopHasNoAbnormalExits(AddRec->getLoop())) {
10937
10938 // If the stride is zero and the start is non-zero, the loop must be
10939 // infinite. In C++, most loops are finite by assumption, in which case the
10940 // step being zero implies UB must execute if the loop is entered.
10941 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10942 !isKnownNonZero(StepWLG))
10943 return getCouldNotCompute();
10944
10945 const SCEV *Exact =
10946 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10947 const SCEV *ConstantMax = getCouldNotCompute();
10948 if (Exact != getCouldNotCompute()) {
10949 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10950 ConstantMax =
10952 }
10953 const SCEV *SymbolicMax =
10954 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10955 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10956 }
10957
10958 // Solve the general equation.
10959 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10960 if (!StepC || StepC->getValue()->isZero())
10961 return getCouldNotCompute();
10962 const SCEV *E = SolveLinEquationWithOverflow(
10963 StepC->getAPInt(), getNegativeSCEV(Start),
10964 AllowPredicates ? &Predicates : nullptr, *this, L);
10965
10966 const SCEV *M = E;
10967 if (E != getCouldNotCompute()) {
10968 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10969 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10970 }
10971 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10972 return ExitLimit(E, M, S, false, Predicates);
10973}
10974
10975ScalarEvolution::ExitLimit
10976ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10977 // Loops that look like: while (X == 0) are very strange indeed. We don't
10978 // handle them yet except for the trivial case. This could be expanded in the
10979 // future as needed.
10980
10981 // If the value is a constant, check to see if it is known to be non-zero
10982 // already. If so, the backedge will execute zero times.
10983 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10984 if (!C->getValue()->isZero())
10985 return getZero(C->getType());
10986 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10987 }
10988
10989 // We could implement others, but I really doubt anyone writes loops like
10990 // this, and if they did, they would already be constant folded.
10991 return getCouldNotCompute();
10992}
10993
10994std::pair<const BasicBlock *, const BasicBlock *>
10995ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10996 const {
10997 // If the block has a unique predecessor, then there is no path from the
10998 // predecessor to the block that does not go through the direct edge
10999 // from the predecessor to the block.
11000 if (const BasicBlock *Pred = BB->getSinglePredecessor())
11001 return {Pred, BB};
11002
11003 // A loop's header is defined to be a block that dominates the loop.
11004 // If the header has a unique predecessor outside the loop, it must be
11005 // a block that has exactly one successor that can reach the loop.
11006 if (const Loop *L = LI.getLoopFor(BB))
11007 return {L->getLoopPredecessor(), L->getHeader()};
11008
11009 return {nullptr, BB};
11010}
11011
11012/// SCEV structural equivalence is usually sufficient for testing whether two
11013/// expressions are equal, however for the purposes of looking for a condition
11014/// guarding a loop, it can be useful to be a little more general, since a
11015/// front-end may have replicated the controlling expression.
11016static bool HasSameValue(const SCEV *A, const SCEV *B) {
11017 // Quick check to see if they are the same SCEV.
11018 if (A == B) return true;
11019
11020 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11021 // Not all instructions that are "identical" compute the same value. For
11022 // instance, two distinct alloca instructions allocating the same type are
11023 // identical and do not read memory; but compute distinct values.
11024 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11025 };
11026
11027 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11028 // two different instructions with the same value. Check for this case.
11029 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11030 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11031 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11032 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11033 if (ComputesEqualValues(AI, BI))
11034 return true;
11035
11036 // Otherwise assume they may have a different value.
11037 return false;
11038}
11039
11040static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11041 const SCEV *Op0, *Op1;
11042 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11043 return false;
11044 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11045 LHS = Op1;
11046 return true;
11047 }
11048 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11049 LHS = Op0;
11050 return true;
11051 }
11052 return false;
11053}
11054
11056 SCEVUse &RHS, unsigned Depth) {
11057 bool Changed = false;
11058 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11059 // '0 != 0'.
11060 auto TrivialCase = [&](bool TriviallyTrue) {
11062 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11063 return true;
11064 };
11065 // If we hit the max recursion limit bail out.
11066 if (Depth >= 3)
11067 return false;
11068
11069 const SCEV *NewLHS, *NewRHS;
11070 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11071 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11072 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11073 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11074
11075 // (X * vscale) pred (Y * vscale) ==> X pred Y
11076 // when both multiples are NSW.
11077 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11078 // when both multiples are NUW.
11079 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11080 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11081 !ICmpInst::isSigned(Pred))) {
11082 LHS = NewLHS;
11083 RHS = NewRHS;
11084 Changed = true;
11085 }
11086 }
11087
11088 // Canonicalize a constant to the right side.
11089 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11090 // Check for both operands constant.
11091 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11092 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11093 return TrivialCase(false);
11094 return TrivialCase(true);
11095 }
11096 // Otherwise swap the operands to put the constant on the right.
11097 std::swap(LHS, RHS);
11099 Changed = true;
11100 }
11101
11102 // If we're comparing an addrec with a value which is loop-invariant in the
11103 // addrec's loop, put the addrec on the left. Also make a dominance check,
11104 // as both operands could be addrecs loop-invariant in each other's loop.
11105 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11106 const Loop *L = AR->getLoop();
11107 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11108 std::swap(LHS, RHS);
11110 Changed = true;
11111 }
11112 }
11113
11114 // If there's a constant operand, canonicalize comparisons with boundary
11115 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11116 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11117 const APInt &RA = RC->getAPInt();
11118
11119 bool SimplifiedByConstantRange = false;
11120
11121 if (!ICmpInst::isEquality(Pred)) {
11123 if (ExactCR.isFullSet())
11124 return TrivialCase(true);
11125 if (ExactCR.isEmptySet())
11126 return TrivialCase(false);
11127
11128 APInt NewRHS;
11129 CmpInst::Predicate NewPred;
11130 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11131 ICmpInst::isEquality(NewPred)) {
11132 // We were able to convert an inequality to an equality.
11133 Pred = NewPred;
11134 RHS = getConstant(NewRHS);
11135 Changed = SimplifiedByConstantRange = true;
11136 }
11137 }
11138
11139 if (!SimplifiedByConstantRange) {
11140 switch (Pred) {
11141 default:
11142 break;
11143 case ICmpInst::ICMP_EQ:
11144 case ICmpInst::ICMP_NE:
11145 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11146 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11147 Changed = true;
11148 break;
11149
11150 // The "Should have been caught earlier!" messages refer to the fact
11151 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11152 // should have fired on the corresponding cases, and canonicalized the
11153 // check to trivial case.
11154
11155 case ICmpInst::ICMP_UGE:
11156 assert(!RA.isMinValue() && "Should have been caught earlier!");
11157 Pred = ICmpInst::ICMP_UGT;
11158 RHS = getConstant(RA - 1);
11159 Changed = true;
11160 break;
11161 case ICmpInst::ICMP_ULE:
11162 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11163 Pred = ICmpInst::ICMP_ULT;
11164 RHS = getConstant(RA + 1);
11165 Changed = true;
11166 break;
11167 case ICmpInst::ICMP_SGE:
11168 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11169 Pred = ICmpInst::ICMP_SGT;
11170 RHS = getConstant(RA - 1);
11171 Changed = true;
11172 break;
11173 case ICmpInst::ICMP_SLE:
11174 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11175 Pred = ICmpInst::ICMP_SLT;
11176 RHS = getConstant(RA + 1);
11177 Changed = true;
11178 break;
11179 }
11180 }
11181 }
11182
11183 // Check for obvious equality.
11184 if (HasSameValue(LHS, RHS)) {
11185 if (ICmpInst::isTrueWhenEqual(Pred))
11186 return TrivialCase(true);
11188 return TrivialCase(false);
11189 }
11190
11191 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11192 // adding or subtracting 1 from one of the operands.
11193 switch (Pred) {
11194 case ICmpInst::ICMP_SLE:
11195 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11196 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11198 Pred = ICmpInst::ICMP_SLT;
11199 Changed = true;
11200 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11201 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11203 Pred = ICmpInst::ICMP_SLT;
11204 Changed = true;
11205 }
11206 break;
11207 case ICmpInst::ICMP_SGE:
11208 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11209 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11211 Pred = ICmpInst::ICMP_SGT;
11212 Changed = true;
11213 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11214 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11216 Pred = ICmpInst::ICMP_SGT;
11217 Changed = true;
11218 }
11219 break;
11220 case ICmpInst::ICMP_ULE:
11221 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11222 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11224 Pred = ICmpInst::ICMP_ULT;
11225 Changed = true;
11226 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11227 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11228 Pred = ICmpInst::ICMP_ULT;
11229 Changed = true;
11230 }
11231 break;
11232 case ICmpInst::ICMP_UGE:
11233 // If RHS is an op we can fold the -1, try that first.
11234 // Otherwise prefer LHS to preserve the nuw flag.
11235 if ((isa<SCEVConstant>(RHS) ||
11237 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11238 !getUnsignedRangeMin(RHS).isMinValue()) {
11239 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11240 Pred = ICmpInst::ICMP_UGT;
11241 Changed = true;
11242 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11243 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11245 Pred = ICmpInst::ICMP_UGT;
11246 Changed = true;
11247 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11248 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11249 Pred = ICmpInst::ICMP_UGT;
11250 Changed = true;
11251 }
11252 break;
11253 default:
11254 break;
11255 }
11256
11257 // TODO: More simplifications are possible here.
11258
11259 // Recursively simplify until we either hit a recursion limit or nothing
11260 // changes.
11261 if (Changed)
11262 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11263
11264 return Changed;
11265}
11266
11268 return getSignedRangeMax(S).isNegative();
11269}
11270
11274
11276 return !getSignedRangeMin(S).isNegative();
11277}
11278
11282
11284 // Query push down for cases where the unsigned range is
11285 // less than sufficient.
11286 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11287 return isKnownNonZero(SExt->getOperand(0));
11288 return getUnsignedRangeMin(S) != 0;
11289}
11290
11292 bool OrNegative) {
11293 auto NonRecursive = [OrNegative](const SCEV *S) {
11294 if (auto *C = dyn_cast<SCEVConstant>(S))
11295 return C->getAPInt().isPowerOf2() ||
11296 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11297
11298 // vscale is a power-of-two.
11299 return isa<SCEVVScale>(S);
11300 };
11301
11302 if (NonRecursive(S))
11303 return true;
11304
11305 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11306 if (!Mul)
11307 return false;
11308 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11309}
11310
11312 const SCEV *S, uint64_t M,
// Trivial divisors: nothing is reported as a multiple of 0, and every value
// is a multiple of 1.
11314 if (M == 0)
11315 return false;
11316 if (M == 1)
11317 return true;
11318
11319 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11320 // starts with a multiple of M and at every iteration step S only adds
11321 // multiples of M.
11322 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11323 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11324 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11325
11326 // For a constant, check that "S % M == 0".
11327 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11328 APInt C = Cst->getAPInt();
11329 return C.urem(M) == 0;
11330 }
11331
11332 // TODO: Handle more SCEV kinds structurally; only AddRecs and constants are handled above.
11333
11334 // Basic tests have failed.
11335 // Check "S % M == 0" at compile time and record runtime Assumptions.
11336 auto *STy = dyn_cast<IntegerType>(S->getType());
// NOTE(review): dyn_cast yields null when S is not integer-typed (e.g. a
// pointer SCEV), and STy is used unchecked below. Presumably callers only
// pass integer SCEVs -- confirm, or switch to cast<> to make that explicit.
11337 const SCEV *SmodM =
11338 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11339 const SCEV *Zero = getZero(STy);
11340
11341 // Check whether "S % M == 0" is known at compile time.
11342 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11343 return true;
11344
11345 // Check whether "S % M != 0" is known at compile time.
11346 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11347 return false;
11348
11350
11351 // Detect redundant predicates.
11352 for (auto *A : Assumptions)
11353 if (A->implies(P, *this))
11354 return true;
11355
11356 // Only record non-redundant predicates.
11357 Assumptions.push_back(P);
11358 return true;
11359}
11360
11362 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11364}
11365
11366std::pair<const SCEV *, const SCEV *>
11368 // Compute SCEV on entry of loop L.
11369 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11370 if (Start == getCouldNotCompute())
11371 return { Start, Start };
11372 // Compute post increment SCEV for loop L.
11373 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11374 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11375 return { Start, PostInc };
11376}
11377
11379 SCEVUse RHS) {
11380 // First collect all loops.
11382 getUsedLoops(LHS, LoopsUsed);
11383 getUsedLoops(RHS, LoopsUsed);
11384
11385 if (LoopsUsed.empty())
11386 return false;
11387
11388 // Domination relationship must be a linear order on collected loops.
11389#ifndef NDEBUG
11390 for (const auto *L1 : LoopsUsed)
11391 for (const auto *L2 : LoopsUsed)
11392 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11393 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11394 "Domination relationship is not a linear order");
11395#endif
11396
11397 const Loop *MDL =
11398 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11399 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11400 });
11401
11402 // Get init and post increment value for LHS.
11403 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11404 // if LHS contains unknown non-invariant SCEV then bail out.
11405 if (SplitLHS.first == getCouldNotCompute())
11406 return false;
11407 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11408 // Get init and post increment value for RHS.
11409 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11410 // if RHS contains unknown non-invariant SCEV then bail out.
11411 if (SplitRHS.first == getCouldNotCompute())
11412 return false;
11413 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11414 // It is possible that init SCEV contains an invariant load but it does
11415 // not dominate MDL and is not available at MDL loop entry, so we should
11416 // check it here.
11417 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11418 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11419 return false;
11420
11421 // It seems backedge guard check is faster than entry one so in some cases
11422 // it can speed up whole estimation by short circuit
11423 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11424 SplitRHS.second) &&
11425 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11426}
11427
11429 SCEVUse RHS) {
11430 // Canonicalize the inputs first.
11431 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11432
11433 if (isKnownViaInduction(Pred, LHS, RHS))
11434 return true;
11435
11436 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11437 return true;
11438
11439 // Otherwise see what can be done with some simple reasoning.
11440 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11441}
11442
11444 const SCEV *LHS,
11445 const SCEV *RHS) {
11446 if (isKnownPredicate(Pred, LHS, RHS))
11447 return true;
11449 return false;
11450 return std::nullopt;
11451}
11452
11454 const SCEV *RHS,
11455 const Instruction *CtxI) {
11456 // TODO: Analyze guards and assumes from Context's block.
11457 return isKnownPredicate(Pred, LHS, RHS) ||
11458 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11459}
11460
11461std::optional<bool>
11463 const SCEV *RHS, const Instruction *CtxI) {
11464 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11465 if (KnownWithoutContext)
11466 return KnownWithoutContext;
11467
11468 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11469 return true;
11471 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11472 return false;
11473 return std::nullopt;
11474}
11475
11477 const SCEVAddRecExpr *LHS,
11478 const SCEV *RHS) {
11479 const Loop *L = LHS->getLoop();
11480 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11481 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11482}
11483
11484std::optional<ScalarEvolution::MonotonicPredicateType>
11486 ICmpInst::Predicate Pred) {
11487 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11488
11489#ifndef NDEBUG
11490 // Verify an invariant: inverting the predicate should turn a monotonically
11491 // increasing change to a monotonically decreasing one, and vice versa.
11492 if (Result) {
11493 auto ResultSwapped =
11494 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11495
11496 assert(*ResultSwapped != *Result &&
11497 "monotonicity should flip as we flip the predicate");
11498 }
11499#endif
11500
11501 return Result;
11502}
11503
11504std::optional<ScalarEvolution::MonotonicPredicateType>
11505ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11506 ICmpInst::Predicate Pred) {
11507 // A zero step value for LHS means the induction variable is essentially a
11508 // loop invariant value. We don't really depend on the predicate actually
11509 // flipping from false to true (for increasing predicates, and the other way
11510 // around for decreasing predicates), all we care about is that *if* the
11511 // predicate changes then it only changes from false to true.
11512 //
11513 // A zero step value in itself is not very useful, but there may be places
11514 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11515 // as general as possible.
11516
11517 // Only handle LE/LT/GE/GT predicates.
11518 if (!ICmpInst::isRelational(Pred))
11519 return std::nullopt;
11520
11521 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11522 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11523 "Should be greater or less!");
11524
11525 // Check that AR does not wrap.
11526 if (ICmpInst::isUnsigned(Pred)) {
11527 if (!LHS->hasNoUnsignedWrap())
11528 return std::nullopt;
11530 }
11531 assert(ICmpInst::isSigned(Pred) &&
11532 "Relational predicate is either signed or unsigned!");
11533 if (!LHS->hasNoSignedWrap())
11534 return std::nullopt;
11535
11536 const SCEV *Step = LHS->getStepRecurrence(*this);
11537
11538 if (isKnownNonNegative(Step))
11540
11541 if (isKnownNonPositive(Step))
11543
11544 return std::nullopt;
11545}
11546
11547std::optional<ScalarEvolution::LoopInvariantPredicate>
11549 const SCEV *RHS, const Loop *L,
11550 const Instruction *CtxI) {
11551 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11552 if (!isLoopInvariant(RHS, L)) {
11553 if (!isLoopInvariant(LHS, L))
11554 return std::nullopt;
11555
11556 std::swap(LHS, RHS);
11558 }
11559
11560 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11561 if (!ArLHS || ArLHS->getLoop() != L)
11562 return std::nullopt;
11563
11564 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11565 if (!MonotonicType)
11566 return std::nullopt;
11567 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11568 // true as the loop iterates, and the backedge is control dependent on
11569 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11570 //
11571 // * if the predicate was false in the first iteration then the predicate
11572 // is never evaluated again, since the loop exits without taking the
11573 // backedge.
11574 // * if the predicate was true in the first iteration then it will
11575 // continue to be true for all future iterations since it is
11576 // monotonically increasing.
11577 //
11578 // For both the above possibilities, we can replace the loop varying
11579 // predicate with its value on the first iteration of the loop (which is
11580 // loop invariant).
11581 //
11582 // A similar reasoning applies for a monotonically decreasing predicate, by
11583 // replacing true with false and false with true in the above two bullets.
11585 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11586
11587 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11589 RHS);
11590
11591 if (!CtxI)
11592 return std::nullopt;
11593 // Try to prove via context.
11594 // TODO: Support other cases.
11595 switch (Pred) {
11596 default:
11597 break;
11598 case ICmpInst::ICMP_ULE:
11599 case ICmpInst::ICMP_ULT: {
11600 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11601 // Given preconditions
11602 // (1) ArLHS does not cross the border of positive and negative parts of
11603 // range because of:
11604 // - Positive step; (TODO: lift this limitation)
11605 // - nuw - does not cross zero boundary;
11606 // - nsw - does not cross SINT_MAX boundary;
11607 // (2) ArLHS <s RHS
11608 // (3) RHS >=s 0
11609 // we can replace the loop variant ArLHS <u RHS condition with loop
11610 // invariant Start(ArLHS) <u RHS.
11611 //
11612 // Because of (1) there are two options:
11613 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11614 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11615 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11616 // Because of (2) ArLHS <u RHS is trivially true.
11617 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11618 // We can strengthen this to Start(ArLHS) <u RHS.
11619 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11620 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11621 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11622 isKnownNonNegative(RHS) &&
11623 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11625 RHS);
11626 }
11627 }
11628
11629 return std::nullopt;
11630}
11631
11632std::optional<ScalarEvolution::LoopInvariantPredicate>
11634 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11635 const Instruction *CtxI, const SCEV *MaxIter) {
11637 Pred, LHS, RHS, L, CtxI, MaxIter))
11638 return LIP;
11639 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11640 // Number of iterations expressed as UMIN isn't always great for expressing
11641 // the value on the last iteration. If the straightforward approach didn't
11642 // work, try the following trick: if the a predicate is invariant for X, it
11643 // is also invariant for umin(X, ...). So try to find something that works
11644 // among subexpressions of MaxIter expressed as umin.
11645 for (SCEVUse Op : UMin->operands())
11647 Pred, LHS, RHS, L, CtxI, Op))
11648 return LIP;
11649 return std::nullopt;
11650}
11651
11652std::optional<ScalarEvolution::LoopInvariantPredicate>
11654 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11655 const Instruction *CtxI, const SCEV *MaxIter) {
11656 // Try to prove the following set of facts:
11657 // - The predicate is monotonic in the iteration space.
11658 // - If the check does not fail on the 1st iteration:
11659 // - No overflow will happen during first MaxIter iterations;
11660 // - It will not fail on the MaxIter'th iteration.
11661 // If the check does fail on the 1st iteration, we leave the loop and no
11662 // other checks matter.
11663
11664 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11665 if (!isLoopInvariant(RHS, L)) {
11666 if (!isLoopInvariant(LHS, L))
11667 return std::nullopt;
11668
11669 std::swap(LHS, RHS);
11671 }
11672
11673 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11674 if (!AR || AR->getLoop() != L)
11675 return std::nullopt;
11676
11677 // Even if both are valid, we need to consistently chose the unsigned or the
11678 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11679 // predicate.
11680 Pred = Pred.dropSameSign();
11681
11682 // The predicate must be relational (i.e. <, <=, >=, >).
11683 if (!ICmpInst::isRelational(Pred))
11684 return std::nullopt;
11685
11686 // TODO: Support steps other than +/- 1.
11687 const SCEV *Step = AR->getStepRecurrence(*this);
11688 auto *One = getOne(Step->getType());
11689 auto *MinusOne = getNegativeSCEV(One);
11690 if (Step != One && Step != MinusOne)
11691 return std::nullopt;
11692
11693 // Type mismatch here means that MaxIter is potentially larger than max
11694 // unsigned value in start type, which mean we cannot prove no wrap for the
11695 // indvar.
11696 if (AR->getType() != MaxIter->getType())
11697 return std::nullopt;
11698
11699 // Value of IV on suggested last iteration.
11700 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11701 // Does it still meet the requirement?
11702 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11703 return std::nullopt;
11704 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11705 // not exceed max unsigned value of this type), this effectively proves
11706 // that there is no wrap during the iteration. To prove that there is no
11707 // signed/unsigned wrap, we need to check that
11708 // Start <= Last for step = 1 or Start >= Last for step = -1.
11709 ICmpInst::Predicate NoOverflowPred =
11711 if (Step == MinusOne)
11712 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11713 const SCEV *Start = AR->getStart();
11714 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11715 return std::nullopt;
11716
11717 // Everything is fine.
11718 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11719}
11720
11721bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11722 SCEVUse LHS,
11723 SCEVUse RHS) {
11724 if (HasSameValue(LHS, RHS))
11725 return ICmpInst::isTrueWhenEqual(Pred);
11726
11727 auto CheckRange = [&](bool IsSigned) {
11728 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11729 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11730 return RangeLHS.icmp(Pred, RangeRHS);
11731 };
11732
11733 // The check at the top of the function catches the case where the values are
11734 // known to be equal.
11735 if (Pred == CmpInst::ICMP_EQ)
11736 return false;
11737
11738 if (Pred == CmpInst::ICMP_NE) {
11739 if (CheckRange(true) || CheckRange(false))
11740 return true;
11741 auto *Diff = getMinusSCEV(LHS, RHS);
11742 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11743 }
11744
11745 return CheckRange(CmpInst::isSigned(Pred));
11746}
11747
11748bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11750 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11751 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11752 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11753 // OutC1 and OutC2.
11754 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11755 APInt &OutC2,
11756 SCEV::NoWrapFlags ExpectedFlags) {
11757 SCEVUse XNonConstOp, XConstOp;
11758 SCEVUse YNonConstOp, YConstOp;
11759 SCEV::NoWrapFlags XFlagsPresent;
11760 SCEV::NoWrapFlags YFlagsPresent;
11761
11762 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11763 XConstOp = getZero(X->getType());
11764 XNonConstOp = X;
11765 XFlagsPresent = ExpectedFlags;
11766 }
11767 if (!isa<SCEVConstant>(XConstOp))
11768 return false;
11769
11770 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11771 YConstOp = getZero(Y->getType());
11772 YNonConstOp = Y;
11773 YFlagsPresent = ExpectedFlags;
11774 }
11775
11776 if (YNonConstOp != XNonConstOp)
11777 return false;
11778
11779 if (!isa<SCEVConstant>(YConstOp))
11780 return false;
11781
11782 // When matching ADDs with NUW flags (and unsigned predicates), only the
11783 // second ADD (with the larger constant) requires NUW.
11784 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11785 return false;
11786 if (ExpectedFlags != SCEV::FlagNUW &&
11787 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11788 return false;
11789 }
11790
11791 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11792 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11793
11794 return true;
11795 };
11796
11797 APInt C1;
11798 APInt C2;
11799
11800 switch (Pred) {
11801 default:
11802 break;
11803
11804 case ICmpInst::ICMP_SGE:
11805 std::swap(LHS, RHS);
11806 [[fallthrough]];
11807 case ICmpInst::ICMP_SLE:
11808 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11809 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11810 return true;
11811
11812 break;
11813
11814 case ICmpInst::ICMP_SGT:
11815 std::swap(LHS, RHS);
11816 [[fallthrough]];
11817 case ICmpInst::ICMP_SLT:
11818 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11819 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11820 return true;
11821
11822 break;
11823
11824 case ICmpInst::ICMP_UGE:
11825 std::swap(LHS, RHS);
11826 [[fallthrough]];
11827 case ICmpInst::ICMP_ULE:
11828 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11829 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11830 return true;
11831
11832 break;
11833
11834 case ICmpInst::ICMP_UGT:
11835 std::swap(LHS, RHS);
11836 [[fallthrough]];
11837 case ICmpInst::ICMP_ULT:
11838 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11839 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11840 return true;
11841 break;
11842 }
11843
11844 return false;
11845}
11846
11847bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11849 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11850 return false;
11851
11852 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11853 // the stack can result in exponential time complexity.
11854 SaveAndRestore Restore(ProvingSplitPredicate, true);
11855
11856 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11857 //
11858 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11859 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11860 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11861 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11862 // use isKnownPredicate later if needed.
11863 return isKnownNonNegative(RHS) &&
11866}
11867
11868bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11869 const SCEV *LHS, const SCEV *RHS) {
11870 // No need to even try if we know the module has no guards.
11871 if (!HasGuards)
11872 return false;
11873
11874 return any_of(*BB, [&](const Instruction &I) {
11875 using namespace llvm::PatternMatch;
11876
11877 Value *Condition;
11879 m_Value(Condition))) &&
11880 isImpliedCond(Pred, LHS, RHS, Condition, false);
11881 });
11882}
11883
11884/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11885/// protected by a conditional between LHS and RHS. This is used to
11886/// to eliminate casts.
11888 CmpPredicate Pred,
11889 const SCEV *LHS,
11890 const SCEV *RHS) {
11891 // Interpret a null as meaning no loop, where there is obviously no guard
11892 // (interprocedural conditions notwithstanding). Do not bother about
11893 // unreachable loops.
11894 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11895 return true;
11896
11897 if (VerifyIR)
11898 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11899 "This cannot be done on broken IR!");
11900
11901
11902 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11903 return true;
11904
11905 BasicBlock *Latch = L->getLoopLatch();
11906 if (!Latch)
11907 return false;
11908
11909 CondBrInst *LoopContinuePredicate =
11911 if (LoopContinuePredicate &&
11912 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11913 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11914 return true;
11915
11916 // We don't want more than one activation of the following loops on the stack
11917 // -- that can lead to O(n!) time complexity.
11918 if (WalkingBEDominatingConds)
11919 return false;
11920
11921 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11922
11923 // See if we can exploit a trip count to prove the predicate.
11924 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11925 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11926 if (LatchBECount != getCouldNotCompute()) {
11927 // We know that Latch branches back to the loop header exactly
11928 // LatchBECount times. This means the backdege condition at Latch is
11929 // equivalent to "{0,+,1} u< LatchBECount".
11930 Type *Ty = LatchBECount->getType();
11931 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11932 const SCEV *LoopCounter =
11933 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11934 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11935 LatchBECount))
11936 return true;
11937 }
11938
11939 // Check conditions due to any @llvm.assume intrinsics.
11940 for (auto &AssumeVH : AC.assumptions()) {
11941 if (!AssumeVH)
11942 continue;
11943 auto *CI = cast<CallInst>(AssumeVH);
11944 if (!DT.dominates(CI, Latch->getTerminator()))
11945 continue;
11946
11947 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11948 return true;
11949 }
11950
11951 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11952 return true;
11953
11954 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11955 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11956 assert(DTN && "should reach the loop header before reaching the root!");
11957
11958 BasicBlock *BB = DTN->getBlock();
11959 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11960 return true;
11961
11962 BasicBlock *PBB = BB->getSinglePredecessor();
11963 if (!PBB)
11964 continue;
11965
11967 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11968 continue;
11969
11970 // If we have an edge `E` within the loop body that dominates the only
11971 // latch, the condition guarding `E` also guards the backedge. This
11972 // reasoning works only for loops with a single latch.
11973 // We're constructively (and conservatively) enumerating edges within the
11974 // loop body that dominate the latch. The dominator tree better agree
11975 // with us on this:
11976 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11977 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11978 BB != ContBr->getSuccessor(0)))
11979 return true;
11980 }
11981
11982 return false;
11983}
11984
11986 CmpPredicate Pred,
11987 const SCEV *LHS,
11988 const SCEV *RHS) {
11989 // Do not bother proving facts for unreachable code.
11990 if (!DT.isReachableFromEntry(BB))
11991 return true;
11992 if (VerifyIR)
11993 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11994 "This cannot be done on broken IR!");
11995
11996 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11997 // the facts (a >= b && a != b) separately. A typical situation is when the
11998 // non-strict comparison is known from ranges and non-equality is known from
11999 // dominating predicates. If we are proving strict comparison, we always try
12000 // to prove non-equality and non-strict comparison separately.
12001 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
12002 const bool ProvingStrictComparison =
12003 Pred != NonStrictPredicate.dropSameSign();
12004 bool ProvedNonStrictComparison = false;
12005 bool ProvedNonEquality = false;
12006
12007 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
12008 if (!ProvedNonStrictComparison)
12009 ProvedNonStrictComparison = Fn(NonStrictPredicate);
12010 if (!ProvedNonEquality)
12011 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
12012 if (ProvedNonStrictComparison && ProvedNonEquality)
12013 return true;
12014 return false;
12015 };
12016
12017 if (ProvingStrictComparison) {
12018 auto ProofFn = [&](CmpPredicate P) {
12019 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12020 };
12021 if (SplitAndProve(ProofFn))
12022 return true;
12023 }
12024
12025 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12026 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12027 const Instruction *CtxI = &BB->front();
12028 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12029 return true;
12030 if (ProvingStrictComparison) {
12031 auto ProofFn = [&](CmpPredicate P) {
12032 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12033 };
12034 if (SplitAndProve(ProofFn))
12035 return true;
12036 }
12037 return false;
12038 };
12039
12040 // Starting at the block's predecessor, climb up the predecessor chain, as long
12041 // as there are predecessors that can be found that have unique successors
12042 // leading to the original block.
12043 const Loop *ContainingLoop = LI.getLoopFor(BB);
12044 const BasicBlock *PredBB;
12045 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12046 PredBB = ContainingLoop->getLoopPredecessor();
12047 else
12048 PredBB = BB->getSinglePredecessor();
12049 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12050 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12051 const CondBrInst *BlockEntryPredicate =
12052 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12053 if (!BlockEntryPredicate)
12054 continue;
12055
12056 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12057 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12058 return true;
12059 }
12060
12061 // Check conditions due to any @llvm.assume intrinsics.
12062 for (auto &AssumeVH : AC.assumptions()) {
12063 if (!AssumeVH)
12064 continue;
12065 auto *CI = cast<CallInst>(AssumeVH);
12066 if (!DT.dominates(CI, BB))
12067 continue;
12068
12069 if (ProveViaCond(CI->getArgOperand(0), false))
12070 return true;
12071 }
12072
12073 // Check conditions due to any @llvm.experimental.guard intrinsics.
12074 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12075 F.getParent(), Intrinsic::experimental_guard);
12076 if (GuardDecl)
12077 for (const auto *GU : GuardDecl->users())
12078 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12079 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12080 if (ProveViaCond(Guard->getArgOperand(0), false))
12081 return true;
12082 return false;
12083}
12084
12086 const SCEV *LHS,
12087 const SCEV *RHS) {
12088 // Interpret a null as meaning no loop, where there is obviously no guard
12089 // (interprocedural conditions notwithstanding).
12090 if (!L)
12091 return false;
12092
12093 // Both LHS and RHS must be available at loop entry.
12095 "LHS is not available at Loop Entry");
12097 "RHS is not available at Loop Entry");
12098
12099 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12100 return true;
12101
12102 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12103}
12104
12105bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12106 const SCEV *RHS,
12107 const Value *FoundCondValue, bool Inverse,
12108 const Instruction *CtxI) {
12109 // False conditions implies anything. Do not bother analyzing it further.
12110 if (FoundCondValue ==
12111 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12112 return true;
12113
12114 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12115 return false;
12116
12117 llvm::scope_exit ClearOnExit(
12118 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12119
12120 // Recursively handle And and Or conditions.
12121 const Value *Op0, *Op1;
12122 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12123 if (!Inverse)
12124 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12125 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12126 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12127 if (Inverse)
12128 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12129 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12130 }
12131
12132 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12133 if (!ICI) return false;
12134
12135 // Now that we found a conditional branch that dominates the loop or controls
12136 // the loop latch. Check to see if it is the comparison we are looking for.
12137 CmpPredicate FoundPred;
12138 if (Inverse)
12139 FoundPred = ICI->getInverseCmpPredicate();
12140 else
12141 FoundPred = ICI->getCmpPredicate();
12142
12143 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12144 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12145
12146 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12147}
12148
12149bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12150 const SCEV *RHS, CmpPredicate FoundPred,
12151 const SCEV *FoundLHS, const SCEV *FoundRHS,
12152 const Instruction *CtxI) {
12153 // Balance the types.
12154 if (getTypeSizeInBits(LHS->getType()) <
12155 getTypeSizeInBits(FoundLHS->getType())) {
12156 // For unsigned and equality predicates, try to prove that both found
12157 // operands fit into narrow unsigned range. If so, try to prove facts in
12158 // narrow types.
12159 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12160 !FoundRHS->getType()->isPointerTy()) {
12161 auto *NarrowType = LHS->getType();
12162 auto *WideType = FoundLHS->getType();
12163 auto BitWidth = getTypeSizeInBits(NarrowType);
12164 const SCEV *MaxValue = getZeroExtendExpr(
12166 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12167 MaxValue) &&
12168 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12169 MaxValue)) {
12170 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12171 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12172 // We cannot preserve samesign after truncation.
12173 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12174 TruncFoundLHS, TruncFoundRHS, CtxI))
12175 return true;
12176 }
12177 }
12178
12179 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12180 return false;
12181 if (CmpInst::isSigned(Pred)) {
12182 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12183 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12184 } else {
12185 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12186 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12187 }
12188 } else if (getTypeSizeInBits(LHS->getType()) >
12189 getTypeSizeInBits(FoundLHS->getType())) {
12190 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12191 return false;
12192 if (CmpInst::isSigned(FoundPred)) {
12193 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12194 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12195 } else {
12196 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12197 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12198 }
12199 }
12200 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12201 FoundRHS, CtxI);
12202}
12203
12204bool ScalarEvolution::isImpliedCondBalancedTypes(
12205 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12206 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12208 getTypeSizeInBits(FoundLHS->getType()) &&
12209 "Types should be balanced!");
12210 // Canonicalize the query to match the way instcombine will have
12211 // canonicalized the comparison.
12212 if (SimplifyICmpOperands(Pred, LHS, RHS))
12213 if (LHS == RHS)
12214 return CmpInst::isTrueWhenEqual(Pred);
12215 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12216 if (FoundLHS == FoundRHS)
12217 return CmpInst::isFalseWhenEqual(FoundPred);
12218
12219 // Check to see if we can make the LHS or RHS match.
12220 if (LHS == FoundRHS || RHS == FoundLHS) {
12221 if (isa<SCEVConstant>(RHS)) {
12222 std::swap(FoundLHS, FoundRHS);
12223 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12224 } else {
12225 std::swap(LHS, RHS);
12227 }
12228 }
12229
12230 // Check whether the found predicate is the same as the desired predicate.
12231 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12232 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12233
12234 // Check whether swapping the found predicate makes it the same as the
12235 // desired predicate.
12236 if (auto P = CmpPredicate::getMatching(
12237 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12238 // We can write the implication
12239 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12240 // using one of the following ways:
12241 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12242 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12243 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12244 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12245 // Forms 1. and 2. require swapping the operands of one condition. Don't
12246 // do this if it would break canonical constant/addrec ordering.
12248 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12249 LHS, FoundLHS, FoundRHS, CtxI);
12250 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12251 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12252
12253 // There's no clear preference between forms 3. and 4., try both. Avoid
12254 // forming getNotSCEV of pointer values as the resulting subtract is
12255 // not legal.
12256 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12257 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12258 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12259 FoundRHS, CtxI))
12260 return true;
12261
12262 if (!FoundLHS->getType()->isPointerTy() &&
12263 !FoundRHS->getType()->isPointerTy() &&
12264 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12265 getNotSCEV(FoundRHS), CtxI))
12266 return true;
12267
12268 return false;
12269 }
12270
12271 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12273 assert(P1 != P2 && "Handled earlier!");
12274 return CmpInst::isRelational(P2) &&
12276 };
12277 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12278 // Unsigned comparison is the same as signed comparison when both the
12279 // operands are non-negative or negative.
12280 if (haveSameSign(FoundLHS, FoundRHS))
12281 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12282 // Create local copies that we can freely swap and canonicalize our
12283 // conditions to "le/lt".
12284 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12285 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12286 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12287 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12288 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12289 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12290 std::swap(CanonicalLHS, CanonicalRHS);
12291 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12292 }
12293 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12294 "Must be!");
12295 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12296 ICmpInst::isLE(CanonicalFoundPred)) &&
12297 "Must be!");
12298 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12299 // Use implication:
12300 // x <u y && y >=s 0 --> x <s y.
12301 // If we can prove the left part, the right part is also proven.
12302 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12303 CanonicalRHS, CanonicalFoundLHS,
12304 CanonicalFoundRHS);
12305 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12306 // Use implication:
12307 // x <s y && y <s 0 --> x <u y.
12308 // If we can prove the left part, the right part is also proven.
12309 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12310 CanonicalRHS, CanonicalFoundLHS,
12311 CanonicalFoundRHS);
12312 }
12313
12314 // Check if we can make progress by sharpening ranges.
12315 if (FoundPred == ICmpInst::ICMP_NE &&
12316 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12317
12318 const SCEVConstant *C = nullptr;
12319 const SCEV *V = nullptr;
12320
12321 if (isa<SCEVConstant>(FoundLHS)) {
12322 C = cast<SCEVConstant>(FoundLHS);
12323 V = FoundRHS;
12324 } else {
12325 C = cast<SCEVConstant>(FoundRHS);
12326 V = FoundLHS;
12327 }
12328
12329 // The guarding predicate tells us that C != V. If the known range
12330 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12331 // range we consider has to correspond to same signedness as the
12332 // predicate we're interested in folding.
12333
12334 APInt Min = ICmpInst::isSigned(Pred) ?
12336
12337 if (Min == C->getAPInt()) {
12338 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12339 // This is true even if (Min + 1) wraps around -- in case of
12340 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12341
12342 APInt SharperMin = Min + 1;
12343
12344 switch (Pred) {
12345 case ICmpInst::ICMP_SGE:
12346 case ICmpInst::ICMP_UGE:
12347 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12348 // RHS, we're done.
12349 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12350 CtxI))
12351 return true;
12352 [[fallthrough]];
12353
12354 case ICmpInst::ICMP_SGT:
12355 case ICmpInst::ICMP_UGT:
12356 // We know from the range information that (V `Pred` Min ||
12357 // V == Min). We know from the guarding condition that !(V
12358 // == Min). This gives us
12359 //
12360 // V `Pred` Min || V == Min && !(V == Min)
12361 // => V `Pred` Min
12362 //
12363 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12364
12365 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12366 return true;
12367 break;
12368
12369 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12370 case ICmpInst::ICMP_SLE:
12371 case ICmpInst::ICMP_ULE:
12372 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12373 LHS, V, getConstant(SharperMin), CtxI))
12374 return true;
12375 [[fallthrough]];
12376
12377 case ICmpInst::ICMP_SLT:
12378 case ICmpInst::ICMP_ULT:
12379 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12380 LHS, V, getConstant(Min), CtxI))
12381 return true;
12382 break;
12383
12384 default:
12385 // No change
12386 break;
12387 }
12388 }
12389 }
12390
12391 // Check whether the actual condition is beyond sufficient.
12392 if (FoundPred == ICmpInst::ICMP_EQ)
12393 if (ICmpInst::isTrueWhenEqual(Pred))
12394 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12395 return true;
12396 if (Pred == ICmpInst::ICMP_NE)
12397 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12398 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12399 return true;
12400
12401 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12402 return true;
12403
12404 // Otherwise assume the worst.
12405 return false;
12406}
12407
12408bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12409 SCEV::NoWrapFlags &Flags) {
12410 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12411 return false;
12412
12413 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12414 return true;
12415}
12416
12417std::optional<APInt>
12419 // We avoid subtracting expressions here because this function is usually
12420 // fairly deep in the call stack (i.e. is called many times).
12421
12422 unsigned BW = getTypeSizeInBits(More->getType());
12423 APInt Diff(BW, 0);
12424 APInt DiffMul(BW, 1);
12425 // Try various simplifications to reduce the difference to a constant. Limit
12426 // the number of allowed simplifications to keep compile-time low.
12427 for (unsigned I = 0; I < 8; ++I) {
12428 if (More == Less)
12429 return Diff;
12430
12431 // Reduce addrecs with identical steps to their start value.
12433 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12434 const auto *MAR = cast<SCEVAddRecExpr>(More);
12435
12436 if (LAR->getLoop() != MAR->getLoop())
12437 return std::nullopt;
12438
12439 // We look at affine expressions only; not for correctness but to keep
12440 // getStepRecurrence cheap.
12441 if (!LAR->isAffine() || !MAR->isAffine())
12442 return std::nullopt;
12443
12444 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12445 return std::nullopt;
12446
12447 Less = LAR->getStart();
12448 More = MAR->getStart();
12449 continue;
12450 }
12451
12452 // Try to match a common constant multiply.
12453 auto MatchConstMul =
12454 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12455 const APInt *C;
12456 const SCEV *Op;
12457 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12458 return {{Op, *C}};
12459 return std::nullopt;
12460 };
12461 if (auto MatchedMore = MatchConstMul(More)) {
12462 if (auto MatchedLess = MatchConstMul(Less)) {
12463 if (MatchedMore->second == MatchedLess->second) {
12464 More = MatchedMore->first;
12465 Less = MatchedLess->first;
12466 DiffMul *= MatchedMore->second;
12467 continue;
12468 }
12469 }
12470 }
12471
12472 // Try to cancel out common factors in two add expressions.
12474 auto Add = [&](const SCEV *S, int Mul) {
12475 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12476 if (Mul == 1) {
12477 Diff += C->getAPInt() * DiffMul;
12478 } else {
12479 assert(Mul == -1);
12480 Diff -= C->getAPInt() * DiffMul;
12481 }
12482 } else
12483 Multiplicity[S] += Mul;
12484 };
12485 auto Decompose = [&](const SCEV *S, int Mul) {
12486 if (isa<SCEVAddExpr>(S)) {
12487 for (const SCEV *Op : S->operands())
12488 Add(Op, Mul);
12489 } else
12490 Add(S, Mul);
12491 };
12492 Decompose(More, 1);
12493 Decompose(Less, -1);
12494
12495 // Check whether all the non-constants cancel out, or reduce to new
12496 // More/Less values.
12497 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12498 for (const auto &[S, Mul] : Multiplicity) {
12499 if (Mul == 0)
12500 continue;
12501 if (Mul == 1) {
12502 if (NewMore)
12503 return std::nullopt;
12504 NewMore = S;
12505 } else if (Mul == -1) {
12506 if (NewLess)
12507 return std::nullopt;
12508 NewLess = S;
12509 } else
12510 return std::nullopt;
12511 }
12512
12513 // Values stayed the same, no point in trying further.
12514 if (NewMore == More || NewLess == Less)
12515 return std::nullopt;
12516
12517 More = NewMore;
12518 Less = NewLess;
12519
12520 // Reduced to constant.
12521 if (!More && !Less)
12522 return Diff;
12523
12524 // Left with variable on only one side, bail out.
12525 if (!More || !Less)
12526 return std::nullopt;
12527 }
12528
12529 // Did not reduce to constant.
12530 return std::nullopt;
12531}
12532
12533bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12534 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12535 const SCEV *FoundRHS, const Instruction *CtxI) {
12536 // Try to recognize the following pattern:
12537 //
12538 // FoundRHS = ...
12539 // ...
12540 // loop:
12541 // FoundLHS = {Start,+,W}
12542 // context_bb: // Basic block from the same loop
12543 // known(Pred, FoundLHS, FoundRHS)
12544 //
12545 // If some predicate is known in the context of a loop, it is also known on
12546 // each iteration of this loop, including the first iteration. Therefore, in
12547 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12548 // prove the original pred using this fact.
12549 if (!CtxI)
12550 return false;
12551 const BasicBlock *ContextBB = CtxI->getParent();
12552 // Make sure AR varies in the context block.
12553 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12554 const Loop *L = AR->getLoop();
12555 const auto *Latch = L->getLoopLatch();
12556 // Make sure that context belongs to the loop and executes on 1st iteration
12557 // (if it ever executes at all).
12558 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12559 return false;
12560 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12561 return false;
12562 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12563 }
12564
12565 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12566 const Loop *L = AR->getLoop();
12567 const auto *Latch = L->getLoopLatch();
12568 // Make sure that context belongs to the loop and executes on 1st iteration
12569 // (if it ever executes at all).
12570 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12571 return false;
12572 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12573 return false;
12574 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12575 }
12576
12577 return false;
12578}
12579
12580bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12581 const SCEV *LHS,
12582 const SCEV *RHS,
12583 const SCEV *FoundLHS,
12584 const SCEV *FoundRHS) {
12585 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12586 return false;
12587
12588 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12589 if (!AddRecLHS)
12590 return false;
12591
12592 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12593 if (!AddRecFoundLHS)
12594 return false;
12595
12596 // We'd like to let SCEV reason about control dependencies, so we constrain
12597 // both the inequalities to be about add recurrences on the same loop. This
12598 // way we can use isLoopEntryGuardedByCond later.
12599
12600 const Loop *L = AddRecFoundLHS->getLoop();
12601 if (L != AddRecLHS->getLoop())
12602 return false;
12603
12604 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12605 //
12606 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12607 // ... (2)
12608 //
12609 // Informal proof for (2), assuming (1) [*]:
12610 //
12611 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12612 //
12613 // Then
12614 //
12615 // FoundLHS s< FoundRHS s< INT_MIN - C
12616 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12617 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12618 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12619 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12620 // <=> FoundLHS + C s< FoundRHS + C
12621 //
12622 // [*]: (1) can be proved by ruling out overflow.
12623 //
12624 // [**]: This can be proved by analyzing all the four possibilities:
12625 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12626 // (A s>= 0, B s>= 0).
12627 //
12628 // Note:
12629 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12630 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12631 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12632 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12633 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12634 // C)".
12635
12636 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12637 if (!LDiff)
12638 return false;
12639 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12640 if (!RDiff || *LDiff != *RDiff)
12641 return false;
12642
12643 if (LDiff->isMinValue())
12644 return true;
12645
12646 APInt FoundRHSLimit;
12647
12648 if (Pred == CmpInst::ICMP_ULT) {
12649 FoundRHSLimit = -(*RDiff);
12650 } else {
12651 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12652 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12653 }
12654
12655 // Try to prove (1) or (2), as needed.
12656 return isAvailableAtLoopEntry(FoundRHS, L) &&
12657 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12658 getConstant(FoundRHSLimit));
12659}
12660
12661bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12662 const SCEV *RHS, const SCEV *FoundLHS,
12663 const SCEV *FoundRHS, unsigned Depth) {
12664 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12665
12666 llvm::scope_exit ClearOnExit([&]() {
12667 if (LPhi) {
12668 bool Erased = PendingMerges.erase(LPhi);
12669 assert(Erased && "Failed to erase LPhi!");
12670 (void)Erased;
12671 }
12672 if (RPhi) {
12673 bool Erased = PendingMerges.erase(RPhi);
12674 assert(Erased && "Failed to erase RPhi!");
12675 (void)Erased;
12676 }
12677 });
12678
12679 // Find respective Phis and check that they are not being pending.
12680 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12681 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12682 if (!PendingMerges.insert(Phi).second)
12683 return false;
12684 LPhi = Phi;
12685 }
12686 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12687 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12688 // If we detect a loop of Phi nodes being processed by this method, for
12689 // example:
12690 //
12691 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12692 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12693 //
12694 // we don't want to deal with a case that complex, so return conservative
12695 // answer false.
12696 if (!PendingMerges.insert(Phi).second)
12697 return false;
12698 RPhi = Phi;
12699 }
12700
12701 // If none of LHS, RHS is a Phi, nothing to do here.
12702 if (!LPhi && !RPhi)
12703 return false;
12704
12705 // If there is a SCEVUnknown Phi we are interested in, make it left.
12706 if (!LPhi) {
12707 std::swap(LHS, RHS);
12708 std::swap(FoundLHS, FoundRHS);
12709 std::swap(LPhi, RPhi);
12711 }
12712
12713 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12714 const BasicBlock *LBB = LPhi->getParent();
12715 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12716
12717 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12718 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12719 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12720 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12721 };
12722
12723 if (RPhi && RPhi->getParent() == LBB) {
12724 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12725 // If we compare two Phis from the same block, and for each entry block
12726 // the predicate is true for incoming values from this block, then the
12727 // predicate is also true for the Phis.
12728 for (const BasicBlock *IncBB : predecessors(LBB)) {
12729 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12730 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12731 if (!ProvedEasily(L, R))
12732 return false;
12733 }
12734 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12735 // Case two: RHS is also a Phi from the same basic block, and it is an
12736 // AddRec. It means that there is a loop which has both AddRec and Unknown
12737 // PHIs, for it we can compare incoming values of AddRec from above the loop
12738 // and latch with their respective incoming values of LPhi.
12739 // TODO: Generalize to handle loops with many inputs in a header.
12740 if (LPhi->getNumIncomingValues() != 2) return false;
12741
12742 auto *RLoop = RAR->getLoop();
12743 auto *Predecessor = RLoop->getLoopPredecessor();
12744 assert(Predecessor && "Loop with AddRec with no predecessor?");
12745 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12746 if (!ProvedEasily(L1, RAR->getStart()))
12747 return false;
12748 auto *Latch = RLoop->getLoopLatch();
12749 assert(Latch && "Loop with AddRec with no latch?");
12750 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12751 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12752 return false;
12753 } else {
12754 // In all other cases go over inputs of LHS and compare each of them to RHS,
12755 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12756 // At this point RHS is either a non-Phi, or it is a Phi from some block
12757 // different from LBB.
12758 for (const BasicBlock *IncBB : predecessors(LBB)) {
12759 // Check that RHS is available in this block.
12760 if (!dominates(RHS, IncBB))
12761 return false;
12762 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12763 // Make sure L does not refer to a value from a potentially previous
12764 // iteration of a loop.
12765 if (!properlyDominates(L, LBB))
12766 return false;
12767 // Addrecs are considered to properly dominate their loop, so are missed
12768 // by the previous check. Discard any values that have computable
12769 // evolution in this loop.
12770 if (auto *Loop = LI.getLoopFor(LBB))
12771 if (hasComputableLoopEvolution(L, Loop))
12772 return false;
12773 if (!ProvedEasily(L, RHS))
12774 return false;
12775 }
12776 }
12777 return true;
12778}
12779
12780bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12781 const SCEV *LHS,
12782 const SCEV *RHS,
12783 const SCEV *FoundLHS,
12784 const SCEV *FoundRHS) {
12785 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12786 // sure that we are dealing with same LHS.
12787 if (RHS == FoundRHS) {
12788 std::swap(LHS, RHS);
12789 std::swap(FoundLHS, FoundRHS);
12791 }
12792 if (LHS != FoundLHS)
12793 return false;
12794
12795 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12796 if (!SUFoundRHS)
12797 return false;
12798
12799 Value *Shiftee, *ShiftValue;
12800
12801 using namespace PatternMatch;
12802 if (match(SUFoundRHS->getValue(),
12803 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12804 auto *ShifteeS = getSCEV(Shiftee);
12805 // Prove one of the following:
12806 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12807 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12808 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12809 // ---> LHS <s RHS
12810 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12811 // ---> LHS <=s RHS
12812 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12813 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12814 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12815 if (isKnownNonNegative(ShifteeS))
12816 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12817 }
12818
12819 return false;
12820}
12821
12822bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12823 const SCEV *RHS,
12824 const SCEV *FoundLHS,
12825 const SCEV *FoundRHS,
12826 const Instruction *CtxI) {
12827 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12828 FoundRHS) ||
12829 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12830 FoundRHS) ||
12831 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12832 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12833 CtxI) ||
12834 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12835}
12836
12837/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12838template <typename MinMaxExprType>
12839static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12840 const SCEV *Candidate) {
12841 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12842 if (!MinMaxExpr)
12843 return false;
12844
12845 return is_contained(MinMaxExpr->operands(), Candidate);
12846}
12847
12849 CmpPredicate Pred, const SCEV *LHS,
12850 const SCEV *RHS) {
12851 // If both sides are affine addrecs for the same loop, with equal
12852 // steps, and we know the recurrences don't wrap, then we only
12853 // need to check the predicate on the starting values.
12854
12855 if (!ICmpInst::isRelational(Pred))
12856 return false;
12857
12858 const SCEV *LStart, *RStart, *Step;
12859 const Loop *L;
12860 if (!match(LHS,
12861 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12863 m_SpecificLoop(L))))
12864 return false;
12869 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12870 return false;
12871
12872 return SE.isKnownPredicate(Pred, LStart, RStart);
12873}
12874
12875/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12876/// expression?
12878 const SCEV *LHS, const SCEV *RHS) {
12879 switch (Pred) {
12880 default:
12881 return false;
12882
12883 case ICmpInst::ICMP_SGE:
12884 std::swap(LHS, RHS);
12885 [[fallthrough]];
12886 case ICmpInst::ICMP_SLE:
12887 return
12888 // min(A, ...) <= A
12890 // A <= max(A, ...)
12892
12893 case ICmpInst::ICMP_UGE:
12894 std::swap(LHS, RHS);
12895 [[fallthrough]];
12896 case ICmpInst::ICMP_ULE:
12897 return
12898 // min(A, ...) <= A
12899 // FIXME: what about umin_seq?
12901 // A <= max(A, ...)
12903 }
12904
12905 llvm_unreachable("covered switch fell through?!");
12906}
12907
12908bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12909 const SCEV *RHS,
12910 const SCEV *FoundLHS,
12911 const SCEV *FoundRHS,
12912 unsigned Depth) {
12915 "LHS and RHS have different sizes?");
12916 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12917 getTypeSizeInBits(FoundRHS->getType()) &&
12918 "FoundLHS and FoundRHS have different sizes?");
12919 // We want to avoid hurting the compile time with analysis of too big trees.
12921 return false;
12922
12923 // We only want to work with GT comparison so far.
12924 if (ICmpInst::isLT(Pred)) {
12926 std::swap(LHS, RHS);
12927 std::swap(FoundLHS, FoundRHS);
12928 }
12929
12931
12932 // For unsigned, try to reduce it to corresponding signed comparison.
12933 if (P == ICmpInst::ICMP_UGT)
12934 // We can replace unsigned predicate with its signed counterpart if all
12935 // involved values are non-negative.
12936 // TODO: We could have better support for unsigned.
12937 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12938 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12939 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12940 // use this fact to prove that LHS and RHS are non-negative.
12941 const SCEV *MinusOne = getMinusOne(LHS->getType());
12942 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12943 FoundRHS) &&
12944 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12945 FoundRHS))
12947 }
12948
12949 if (P != ICmpInst::ICMP_SGT)
12950 return false;
12951
12952 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12953 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12954 return Ext->getOperand();
12955 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12956 // the constant in some cases.
12957 return S;
12958 };
12959
12960 // Acquire values from extensions.
12961 auto *OrigLHS = LHS;
12962 auto *OrigFoundLHS = FoundLHS;
12963 LHS = GetOpFromSExt(LHS);
12964 FoundLHS = GetOpFromSExt(FoundLHS);
12965
12966 // Is the SGT predicate can be proved trivially or using the found context.
12967 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12968 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12969 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12970 FoundRHS, Depth + 1);
12971 };
12972
12973 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12974 // We want to avoid creation of any new non-constant SCEV. Since we are
12975 // going to compare the operands to RHS, we should be certain that we don't
12976 // need any size extensions for this. So let's decline all cases when the
12977 // sizes of types of LHS and RHS do not match.
12978 // TODO: Maybe try to get RHS from sext to catch more cases?
12980 return false;
12981
12982 // Should not overflow.
12983 if (!LHSAddExpr->hasNoSignedWrap())
12984 return false;
12985
12986 SCEVUse LL = LHSAddExpr->getOperand(0);
12987 SCEVUse LR = LHSAddExpr->getOperand(1);
12988 auto *MinusOne = getMinusOne(RHS->getType());
12989
12990 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12991 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12992 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12993 };
12994 // Try to prove the following rule:
12995 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12996 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12997 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12998 return true;
12999 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
13000 Value *LL, *LR;
13001 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
13002
13003 using namespace llvm::PatternMatch;
13004
13005 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
13006 // Rules for division.
13007 // We are going to perform some comparisons with Denominator and its
13008 // derivative expressions. In general case, creating a SCEV for it may
13009 // lead to a complex analysis of the entire graph, and in particular it
13010 // can request trip count recalculation for the same loop. This would
13011 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
13012 // this, we only want to create SCEVs that are constants in this section.
13013 // So we bail if Denominator is not a constant.
13014 if (!isa<ConstantInt>(LR))
13015 return false;
13016
13017 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
13018
13019 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13020 // then a SCEV for the numerator already exists and matches with FoundLHS.
13021 auto *Numerator = getExistingSCEV(LL);
13022 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13023 return false;
13024
13025 // Make sure that the numerator matches with FoundLHS and the denominator
13026 // is positive.
13027 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13028 return false;
13029
13030 auto *DTy = Denominator->getType();
13031 auto *FRHSTy = FoundRHS->getType();
13032 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13033 // One of types is a pointer and another one is not. We cannot extend
13034 // them properly to a wider type, so let us just reject this case.
13035 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13036 // to avoid this check.
13037 return false;
13038
13039 // Given that:
13040 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13041 auto *WTy = getWiderType(DTy, FRHSTy);
13042 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13043 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13044
13045 // Try to prove the following rule:
13046 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13047 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13048 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13049 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13050 if (isKnownNonPositive(RHS) &&
13051 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13052 return true;
13053
13054 // Try to prove the following rule:
13055 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13056 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13057 // If we divide it by Denominator > 2, then:
13058 // 1. If FoundLHS is negative, then the result is 0.
13059 // 2. If FoundLHS is non-negative, then the result is non-negative.
13060 // Anyways, the result is non-negative.
13061 auto *MinusOne = getMinusOne(WTy);
13062 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13063 if (isKnownNegative(RHS) &&
13064 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13065 return true;
13066 }
13067 }
13068
13069 // If our expression contained SCEVUnknown Phis, and we split it down and now
13070 // need to prove something for them, try to prove the predicate for every
13071 // possible incoming values of those Phis.
13072 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13073 return true;
13074
13075 return false;
13076}
13077
13079 const SCEV *RHS) {
13080 // zext x u<= sext x, sext x s<= zext x
13081 const SCEV *Op;
13082 switch (Pred) {
13083 case ICmpInst::ICMP_SGE:
13084 std::swap(LHS, RHS);
13085 [[fallthrough]];
13086 case ICmpInst::ICMP_SLE: {
13087 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13088 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13090 }
13091 case ICmpInst::ICMP_UGE:
13092 std::swap(LHS, RHS);
13093 [[fallthrough]];
13094 case ICmpInst::ICMP_ULE: {
13095 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13096 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13098 }
13099 default:
13100 return false;
13101 };
13102 llvm_unreachable("unhandled case");
13103}
13104
13105bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13106 SCEVUse LHS,
13107 SCEVUse RHS) {
13108 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13109 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13110 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13111 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13112 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13113}
13114
13115bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13116 const SCEV *LHS,
13117 const SCEV *RHS,
13118 const SCEV *FoundLHS,
13119 const SCEV *FoundRHS) {
13120 switch (Pred) {
13121 default:
13122 llvm_unreachable("Unexpected CmpPredicate value!");
13123 case ICmpInst::ICMP_EQ:
13124 case ICmpInst::ICMP_NE:
13125 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13126 return true;
13127 break;
13128 case ICmpInst::ICMP_SLT:
13129 case ICmpInst::ICMP_SLE:
13130 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13131 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13132 return true;
13133 break;
13134 case ICmpInst::ICMP_SGT:
13135 case ICmpInst::ICMP_SGE:
13136 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13137 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13138 return true;
13139 break;
13140 case ICmpInst::ICMP_ULT:
13141 case ICmpInst::ICMP_ULE:
13142 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13143 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13144 return true;
13145 break;
13146 case ICmpInst::ICMP_UGT:
13147 case ICmpInst::ICMP_UGE:
13148 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13149 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13150 return true;
13151 break;
13152 }
13153
13154 // Maybe it can be proved via operations?
13155 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13156 return true;
13157
13158 return false;
13159}
13160
13161bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13162 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13163 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13164 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13165 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13166 // reduce the compile time impact of this optimization.
13167 return false;
13168
13169 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13170 if (!Addend)
13171 return false;
13172
13173 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13174
13175 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13176 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13177 ConstantRange FoundLHSRange =
13178 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13179
13180 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13181 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13182
13183 // We can also compute the range of values for `LHS` that satisfy the
13184 // consequent, "`LHS` `Pred` `RHS`":
13185 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13186 // The antecedent implies the consequent if every value of `LHS` that
13187 // satisfies the antecedent also satisfies the consequent.
13188 return LHSRange.icmp(Pred, ConstRHS);
13189}
13190
13191bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13192 bool IsSigned) {
13193 assert(isKnownPositive(Stride) && "Positive stride expected!");
13194
13195 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13196 const SCEV *One = getOne(Stride->getType());
13197
13198 if (IsSigned) {
13199 APInt MaxRHS = getSignedRangeMax(RHS);
13200 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13201 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13202
13203 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13204 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13205 }
13206
13207 APInt MaxRHS = getUnsignedRangeMax(RHS);
13208 APInt MaxValue = APInt::getMaxValue(BitWidth);
13209 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13210
13211 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13212 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13213}
13214
13215bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13216 bool IsSigned) {
13217
13218 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13219 const SCEV *One = getOne(Stride->getType());
13220
13221 if (IsSigned) {
13222 APInt MinRHS = getSignedRangeMin(RHS);
13223 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13224 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13225
13226 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13227 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13228 }
13229
13230 APInt MinRHS = getUnsignedRangeMin(RHS);
13231 APInt MinValue = APInt::getMinValue(BitWidth);
13232 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13233
13234 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13235 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13236}
13237
  // Unsigned ceiling division ceil(N / D). For D != 0, every intermediate
  // term below is at most N, so no "N + D - 1" style sum is ever formed.
  // NOTE(review): this function's signature line is absent from this listing.
13239 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13240 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13241 // expression fixes the case of N=0.
13242 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13243 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13244 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13245}
13246
13247const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13248 const SCEV *Stride,
13249 const SCEV *End,
13250 unsigned BitWidth,
13251 bool IsSigned) {
13252 // The logic in this function assumes we can represent a positive stride.
13253 // If we can't, the backedge-taken count must be zero.
13254 if (IsSigned && BitWidth == 1)
13255 return getZero(Stride->getType());
13256
13257 // This code below only been closely audited for negative strides in the
13258 // unsigned comparison case, it may be correct for signed comparison, but
13259 // that needs to be established.
13260 if (IsSigned && isKnownNegative(Stride))
13261 return getCouldNotCompute();
13262
13263 // Calculate the maximum backedge count based on the range of values
13264 // permitted by Start, End, and Stride.
13265 APInt MinStart =
13266 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13267
13268 APInt MinStride =
13269 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13270
13271 // We assume either the stride is positive, or the backedge-taken count
13272 // is zero. So force StrideForMaxBECount to be at least one.
13273 APInt One(BitWidth, 1);
13274 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13275 : APIntOps::umax(One, MinStride);
13276
13277 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13278 : APInt::getMaxValue(BitWidth);
13279 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13280
13281 // Although End can be a MAX expression we estimate MaxEnd considering only
13282 // the case End = RHS of the loop termination condition. This is safe because
13283 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13284 // taken count.
13285 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13286 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13287
13288 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13289 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13290 : APIntOps::umax(MaxEnd, MinStart);
13291
13292 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13293 getConstant(StrideForMaxBECount) /* Step */);
13294}
13295
// Compute the exit limit for an exit of loop L that is controlled by the
// comparison "LHS < RHS" (signed when IsSigned, unsigned otherwise). LHS must
// be an affine AddRec of L, or be rewritten into one: either from a zext of
// such an AddRec, or via runtime predicates when AllowPredicates is set.
// Otherwise CouldNotCompute is returned. When only a bound is known, the
// exact count is CouldNotCompute while a maximum count is still reported.
// NOTE(review): some source lines of this function are absent from this
// listing; the embedded line numbers skip. Among the missing lines are,
// apparently, the declarations of IV, Predicates and Cond. Consult the
// original file before editing.
13297ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13298 const Loop *L, bool IsSigned,
13299 bool ControlsOnlyExit, bool AllowPredicates) {
13301
13303 bool PredicatedIV = false;
13304 if (!IV) {
13305 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13306 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13307 if (AR && AR->getLoop() == L && AR->isAffine()) {
13308 auto canProveNUW = [&]() {
13309 // We can use the comparison to infer no-wrap flags only if it fully
13310 // controls the loop exit.
13311 if (!ControlsOnlyExit)
13312 return false;
13313
13314 if (!isLoopInvariant(RHS, L))
13315 return false;
13316
13317 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13318 // We need the sequence defined by AR to strictly increase in the
13319 // unsigned integer domain for the logic below to hold.
13320 return false;
13321
13322 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13323 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13324 // If RHS <=u Limit, then there must exist a value V in the sequence
13325 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13326 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13327 // overflow occurs. This limit also implies that a signed comparison
13328 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13329 // the high bits on both sides must be zero.
13330 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13331 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13332 Limit = Limit.zext(OuterBitWidth);
13333 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13334 };
13335 auto Flags = AR->getNoWrapFlags();
13336 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13337 Flags = setFlags(Flags, SCEV::FlagNUW);
13338
13339 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13340 if (AR->hasNoUnsignedWrap()) {
13341 // Emulate what getZeroExtendExpr would have done during construction
13342 // if we'd been able to infer the fact just above at that time.
13343 const SCEV *Step = AR->getStepRecurrence(*this);
13344 Type *Ty = ZExt->getType();
13345 auto *S = getAddRecExpr(
13347 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13349 }
13350 }
13351 }
13352 }
13353
13354
13355 if (!IV && AllowPredicates) {
13356 // Try to make this an AddRec using runtime tests, in the first X
13357 // iterations of this loop, where X is the SCEV expression found by the
13358 // algorithm below.
13359 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13360 PredicatedIV = true;
13361 }
13362
13363 // Avoid weird loops
13364 if (!IV || IV->getLoop() != L || !IV->isAffine())
13365 return getCouldNotCompute();
13366
13367 // A precondition of this method is that the condition being analyzed
13368 // reaches an exiting branch which dominates the latch. Given that, we can
13369 // assume that an increment which violates the nowrap specification and
13370 // produces poison must cause undefined behavior when the resulting poison
13371 // value is branched upon and thus we can conclude that the backedge is
13372 // taken no more often than would be required to produce that poison value.
13373 // Note that a well defined loop can exit on the iteration which violates
13374 // the nowrap specification if there is another exit (either explicit or
13375 // implicit/exceptional) which causes the loop to execute before the
13376 // exiting instruction we're analyzing would trigger UB.
13377 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13378 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13380
13381 const SCEV *Stride = IV->getStepRecurrence(*this);
13382
13383 bool PositiveStride = isKnownPositive(Stride);
13384
13385 // Avoid negative or zero stride values.
13386 if (!PositiveStride) {
13387 // We can compute the correct backedge taken count for loops with unknown
13388 // strides if we can prove that the loop is not an infinite loop with side
13389 // effects. Here's the loop structure we are trying to handle -
13390 //
13391 // i = start
13392 // do {
13393 // A[i] = i;
13394 // i += s;
13395 // } while (i < end);
13396 //
13397 // The backedge taken count for such loops is evaluated as -
13398 // (max(end, start + stride) - start - 1) /u stride
13399 //
13400 // The additional preconditions that we need to check to prove correctness
13401 // of the above formula is as follows -
13402 //
13403 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13404 // NoWrap flag).
13405 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13406 // no side effects within the loop)
13407 // c) loop has a single static exit (with no abnormal exits)
13408 //
13409 // Precondition a) implies that if the stride is negative, this is a single
13410 // trip loop. The backedge taken count formula reduces to zero in this case.
13411 //
13412 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13413 // then a zero stride means the backedge can't be taken without executing
13414 // undefined behavior.
13415 //
13416 // The positive stride case is the same as isKnownPositive(Stride) returning
13417 // true (original behavior of the function).
13418 //
13419 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13421 return getCouldNotCompute();
13422
13423 if (!isKnownNonZero(Stride)) {
13424 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13425 // if it might eventually be greater than start and if so, on which
13426 // iteration. We can't even produce a useful upper bound.
13427 if (!isLoopInvariant(RHS, L))
13428 return getCouldNotCompute();
13429
13430 // We allow a potentially zero stride, but we need to divide by stride
13431 // below. Since the loop can't be infinite and this check must control
13432 // the sole exit, we can infer the exit must be taken on the first
13433 // iteration (i.e. backedge count = 0) if the stride is zero. Given that,
13434 // we know the numerator in the divides below must be zero, so we can
13435 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13436 // and produce the right result.
13437 // FIXME: Handle the case where Stride is poison?
13438 auto wouldZeroStrideBeUB = [&]() {
13439 // Proof by contradiction. Suppose the stride were zero. If we can
13440 // prove that the backedge *is* taken on the first iteration, then since
13441 // we know this condition controls the sole exit, we must have an
13442 // infinite loop. We can't have a (well defined) infinite loop per
13443 // check just above.
13444 // Note: The (Start - Stride) term is used to get the start' term from
13445 // (start' + stride,+,stride). Remember that we only care about the
13446 // result of this expression when stride == 0 at runtime.
13447 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13448 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13449 };
13450 if (!wouldZeroStrideBeUB()) {
13451 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13452 }
13453 }
13454 } else if (!NoWrap) {
13455 // Avoid proven overflow cases: this will ensure that the backedge taken
13456 // count will not generate any unsigned overflow.
13457 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13458 return getCouldNotCompute();
13459 }
13460
13461 // On all paths just preceding, we established the following invariant:
13462 // IV can be assumed not to overflow up to and including the exiting
13463 // iteration. We proved this in one of two ways:
13464 // 1) We can show overflow doesn't occur before the exiting iteration
13465 // 1a) canIVOverflowOnLT, and b) step of one
13466 // 2) We can show that if overflow occurs, the loop must execute UB
13467 // before any possible exit.
13468 // Note that we have not yet proved RHS invariant (in general).
13469
13470 const SCEV *Start = IV->getStart();
13471
13472 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13473 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13474 // Use integer-typed versions for actual computation; we can't subtract
13475 // pointers in general.
13476 const SCEV *OrigStart = Start;
13477 const SCEV *OrigRHS = RHS;
13478 if (Start->getType()->isPointerTy()) {
13480 if (isa<SCEVCouldNotCompute>(Start))
13481 return Start;
13482 }
13483 if (RHS->getType()->isPointerTy()) {
13486 return RHS;
13487 }
13488
13489 const SCEV *End = nullptr, *BECount = nullptr,
13490 *BECountIfBackedgeTaken = nullptr;
13491 if (!isLoopInvariant(RHS, L)) {
13492 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13493 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13494 any(RHSAddRec->getNoWrapFlags())) {
13495 // The structure of loop we are trying to calculate backedge count of:
13496 //
13497 // left = left_start
13498 // right = right_start
13499 //
13500 // while(left < right){
13501 // ... do something here ...
13502 // left += s1; // stride of left is s1 (s1 > 0)
13503 // right += s2; // stride of right is s2 (s2 < 0)
13504 // }
13505 //
13506
13507 const SCEV *RHSStart = RHSAddRec->getStart();
13508 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13509
13510 // If Stride - RHSStride is positive and does not overflow, we can write
13511 // backedge count as ->
13512 // ceil((End - Start) /u (Stride - RHSStride))
13513 // Where, End = max(RHSStart, Start)
13514
13515 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13516 if (isKnownNegative(RHSStride) &&
13517 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13518 RHSStride)) {
13519
13520 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13521 if (isKnownPositive(Denominator)) {
13522 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13523 : getUMaxExpr(RHSStart, Start);
13524
13525 // We can do this because End >= Start, as End = max(RHSStart, Start)
13526 const SCEV *Delta = getMinusSCEV(End, Start);
13527
13528 BECount = getUDivCeilSCEV(Delta, Denominator);
13529 BECountIfBackedgeTaken =
13530 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13531 }
13532 }
13533 }
13534 if (BECount == nullptr) {
13535 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13536 // given the start, stride and max value for the end bound of the
13537 // loop (RHS), and the fact that IV does not overflow (which is
13538 // checked above).
13539 const SCEV *MaxBECount = computeMaxBECountForLT(
13540 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13541 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13542 MaxBECount, false /*MaxOrZero*/, Predicates);
13543 }
13544 } else {
13545 // We use the expression (max(End,Start)-Start)/Stride to describe the
13546 // backedge count, as if the backedge is taken at least once
13547 // max(End,Start) is End and so the result is as above, and if not
13548 // max(End,Start) is Start so we get a backedge count of zero.
13549 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13550 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13551 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13552 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13553 // Can we prove max(RHS,Start) > Start - Stride?
13554 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13555 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13556 // In this case, we can use a refined formula for computing backedge
13557 // taken count. The general formula remains:
13558 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13559 // We want to use the alternate formula:
13560 // "((End - 1) - (Start - Stride)) /u Stride"
13561 // Let's do a quick case analysis to show these are equivalent under
13562 // our precondition that max(RHS,Start) > Start - Stride.
13563 // * For RHS <= Start, the backedge-taken count must be zero.
13564 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13565 // "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
13566 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13567 // of Stride. For 0 stride, we've used umax(Stride, 1) above,
13568 // reducing this to the stride of 1 case.
13569 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13570 // Stride".
13571 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13572 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13573 // "((RHS - (Start - Stride) - 1) /u Stride".
13574 // Our preconditions trivially imply no overflow in that form.
13575 const SCEV *MinusOne = getMinusOne(Stride->getType());
13576 const SCEV *Numerator =
13577 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13578 BECount = getUDivExpr(Numerator, Stride);
13579 }
13580
13581 if (!BECount) {
13582 auto canProveRHSGreaterThanEqualStart = [&]() {
13583 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13584 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13585 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13586
13587 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13588 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13589 return true;
13590
13591 // (RHS > Start - 1) implies RHS >= Start.
13592 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13593 // "Start - 1" doesn't overflow.
13594 // * For signed comparison, if Start - 1 does overflow, it's equal
13595 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13596 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13597 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13598 //
13599 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13600 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13601 auto *StartMinusOne =
13602 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13603 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13604 };
13605
13606 // If we know that RHS >= Start in the context of loop, then we know
13607 // that max(RHS, Start) = RHS at this point.
13608 if (canProveRHSGreaterThanEqualStart()) {
13609 End = RHS;
13610 } else {
13611 // If RHS < Start, the backedge will be taken zero times. So in
13612 // general, we can write the backedge-taken count as:
13613 //
13614 // RHS >= Start ? ceil((RHS - Start) / Stride) : 0
13615 //
13616 // We convert it to the following to make it more convenient for SCEV:
13617 //
13618 // ceil((max(RHS, Start) - Start) / Stride)
13619 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13620
13621 // See what would happen if we assume the backedge is taken. This is
13622 // used to compute MaxBECount.
13623 BECountIfBackedgeTaken =
13624 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13625 }
13626
13627 // At this point, we know:
13628 //
13629 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13630 // 2. The index variable doesn't overflow.
13631 //
13632 // Therefore, we know N exists such that
13633 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13634 // doesn't overflow.
13635 //
13636 // Using this information, try to prove whether the addition in
13637 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13638 const SCEV *One = getOne(Stride->getType());
13639 bool MayAddOverflow = [&] {
13640 if (isKnownToBeAPowerOfTwo(Stride)) {
13641 // Suppose Stride is a power of two, and Start/End are unsigned
13642 // integers. Let UMAX be the largest representable unsigned
13643 // integer.
13644 //
13645 // By the preconditions of this function, we know
13646 // "(Start + Stride * N) >= End", and this doesn't overflow.
13647 // As a formula:
13648 //
13649 // End <= (Start + Stride * N) <= UMAX
13650 //
13651 // Subtracting Start from all the terms:
13652 //
13653 // End - Start <= Stride * N <= UMAX - Start
13654 //
13655 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13656 //
13657 // End - Start <= Stride * N <= UMAX
13658 //
13659 // Stride * N is a multiple of Stride. Therefore,
13660 //
13661 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13662 //
13663 // Since Stride is a power of two, UMAX + 1 is divisible by
13664 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13665 // write:
13666 //
13667 // End - Start <= Stride * N <= UMAX - (Stride - 1)
13668 //
13669 // Dropping the middle term:
13670 //
13671 // End - Start <= UMAX - (Stride - 1)
13672 //
13673 // Adding Stride - 1 to both sides:
13674 //
13675 // (End - Start) + (Stride - 1) <= UMAX
13676 //
13677 // In other words, the addition doesn't have unsigned overflow.
13678 //
13679 // A similar proof works if we treat Start/End as signed values.
13680 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13681 // to use signed max instead of unsigned max. Note that we're
13682 // trying to prove a lack of unsigned overflow in either case.
13683 return false;
13684 }
13685 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13686 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13687 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13688 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13689 // 1 <s End.
13690 //
13691 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13692 // End.
13693 return false;
13694 }
13695 return true;
13696 }();
13697
13698 const SCEV *Delta = getMinusSCEV(End, Start);
13699 if (!MayAddOverflow) {
13700 // floor((D + (S - 1)) / S)
13701 // We prefer this formulation if it's legal because it's fewer
13702 // operations.
13703 BECount =
13704 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13705 } else {
13706 BECount = getUDivCeilSCEV(Delta, Stride);
13707 }
13708 }
13709 }
13710
13711 const SCEV *ConstantMaxBECount;
13712 bool MaxOrZero = false;
13713 if (isa<SCEVConstant>(BECount)) {
13714 ConstantMaxBECount = BECount;
13715 } else if (BECountIfBackedgeTaken &&
13716 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13717 // If we know exactly how many times the backedge will be taken if it's
13718 // taken at least once, then the backedge count will either be that or
13719 // zero.
13720 ConstantMaxBECount = BECountIfBackedgeTaken;
13721 MaxOrZero = true;
13722 } else {
13723 ConstantMaxBECount = computeMaxBECountForLT(
13724 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13725 }
13726
13727 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13728 !isa<SCEVCouldNotCompute>(BECount))
13729 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13730
13731 const SCEV *SymbolicMaxBECount =
13732 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13733 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13734 Predicates);
13735}
13736
// Compute the exit limit for an exit of loop L controlled by the comparison
// "LHS > RHS" (signed when IsSigned, unsigned otherwise). RHS must be
// loop-invariant. LHS must be an affine AddRec of L whose negated step is
// known positive; it may be obtained under runtime predicates when
// AllowPredicates is set. Otherwise CouldNotCompute is returned.
// NOTE(review): some source lines of this function are absent from this
// listing (the embedded line numbers skip), e.g. where Predicates and Cond
// are apparently declared. Consult the original file before editing.
13737ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13738 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13739 bool ControlsOnlyExit, bool AllowPredicates) {
13741 // We handle only IV > Invariant
13742 if (!isLoopInvariant(RHS, L))
13743 return getCouldNotCompute();
13744
13745 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13746 if (!IV && AllowPredicates)
13747 // Try to make this an AddRec using runtime tests, in the first X
13748 // iterations of this loop, where X is the SCEV expression found by the
13749 // algorithm below.
13750 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13751
13752 // Avoid weird loops
13753 if (!IV || IV->getLoop() != L || !IV->isAffine())
13754 return getCouldNotCompute();
13755
13756 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13757 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13759
  // Work with the positive magnitude of the (decreasing) IV's step.
13760 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13761
13762 // Avoid negative or zero stride values
13763 if (!isKnownPositive(Stride))
13764 return getCouldNotCompute();
13765
13766 // Avoid proven overflow cases: this will ensure that the backedge taken count
13767 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13768 // exploit NoWrapFlags, allowing optimization in the presence of undefined
13769 // behavior, as in C.
13770 if (!Stride->isOne() && !NoWrap)
13771 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13772 return getCouldNotCompute();
13773
13774 const SCEV *Start = IV->getStart();
13775 const SCEV *End = RHS;
13776 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13777 // If we know that Start >= RHS in the context of loop, then we know that
13778 // min(RHS, Start) = RHS at this point.
13780 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13781 End = RHS;
13782 else
13783 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13784 }
13785
13786 if (Start->getType()->isPointerTy()) {
13788 if (isa<SCEVCouldNotCompute>(Start))
13789 return Start;
13790 }
13791 if (End->getType()->isPointerTy()) {
13792 End = getLosslessPtrToIntExpr(End);
13793 if (isa<SCEVCouldNotCompute>(End))
13794 return End;
13795 }
13796
13797 // Compute ((Start - End) + (Stride - 1)) / Stride.
13798 // FIXME: This can overflow. Holding off on fixing this for now;
13799 // howManyGreaterThans will hopefully be gone soon.
13800 const SCEV *One = getOne(Stride->getType());
13801 const SCEV *BECount = getUDivExpr(
13802 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13803
13804 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13806
13807 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13808 : getUnsignedRangeMin(Stride);
13809
  // Lowest End value the IV can approach by MinStride without wrapping below
  // the minimum of the type.
13810 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13811 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13812 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13813
13814 // Although End can be a MIN expression we estimate MinEnd considering only
13815 // the case End = RHS. This is safe because in the other case (Start - End)
13816 // is zero, leading to a zero maximum backedge taken count.
13817 APInt MinEnd =
13818 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13819 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13820
  // Prefer an exact constant count; otherwise bound it from the ranges.
13821 const SCEV *ConstantMaxBECount =
13822 isa<SCEVConstant>(BECount)
13823 ? BECount
13824 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13825 getConstant(MinStride));
13826
13827 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13828 ConstantMaxBECount = BECount;
13829 const SCEV *SymbolicMaxBECount =
13830 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13831
13832 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13833 Predicates);
13834}
13835
13837 ScalarEvolution &SE) const {
  // Return, as a constant SCEV, the index of the first iteration whose value
  // of this recurrence lies outside Range. Returns CouldNotCompute when this
  // cannot be determined: a full Range, non-constant operands, or a
  // recurrence that is neither affine nor quadratic.
  // NOTE(review): a few source lines of this function, including its first
  // signature line, are absent from this listing.
13838 if (Range.isFullSet()) // Infinite loop.
13839 return SE.getCouldNotCompute();
13840
13841 // If the start is a non-zero constant, shift the range to simplify things.
13842 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13843 if (!SC->getValue()->isZero()) {
13845 Operands[0] = SE.getZero(SC->getType());
13846 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13848 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13849 return ShiftedAddRec->getNumIterationsInRange(
13850 Range.subtract(SC->getAPInt()), SE);
13851 // This is strange and shouldn't happen.
13852 return SE.getCouldNotCompute();
13853 }
13854
13855 // The only time we can solve this is when we have all constant indices.
13856 // Otherwise, we cannot determine the overflow conditions.
13857 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13858 return SE.getCouldNotCompute();
13859
13860 // Okay at this point we know that all elements of the chrec are constants and
13861 // that the start element is zero.
13862
13863 // First check to see if the range contains zero. If not, the first
13864 // iteration exits.
13865 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13866 if (!Range.contains(APInt(BitWidth, 0)))
13867 return SE.getZero(getType());
13868
13869 if (isAffine()) {
13870 // If this is an affine expression then we have this situation:
13871 // Solve {0,+,A} in Range === Ax in Range
13872
13873 // We know that zero is in the range. If A is positive then we know that
13874 // the upper value of the range must be the first possible exit value.
13875 // If A is negative then the lower bound of the range is the last possible
13876 // loop value. Also note that we already checked for a full range.
13877 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13878 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13879
13880 // The exit value should be (End+A)/A.
13881 APInt ExitVal = (End + A).udiv(A);
13882 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13883
13884 // Evaluate at the exit value. If we really did fall out of the valid
13885 // range, then we computed our trip count, otherwise wrap around or other
13886 // things must have happened.
13887 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13888 if (Range.contains(Val->getValue()))
13889 return SE.getCouldNotCompute(); // Something strange happened
13890
13891 // Ensure that the previous value is in the range.
13892 assert(Range.contains(
13894 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13895 "Linear scev computation is off in a bad way!");
13896 return SE.getConstant(ExitValue);
13897 }
13898
13899 if (isQuadratic()) {
13900 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13901 return SE.getConstant(*S);
13902 }
13903
13904 return SE.getCouldNotCompute();
13905}
13906
// Return this recurrence advanced by one step, {A+B,+,B+C,+,...,+,N},
// constructed explicitly so that the result is guaranteed to be an AddRec.
// NOTE(review): several source lines of this definition (its name line, the
// declaration of Ops and the final return) are absent from this listing.
13907const SCEVAddRecExpr *
13909 assert(getNumOperands() > 1 && "AddRec with zero step?");
13910 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13911 // but in this case we cannot guarantee that the value returned will be an
13912 // AddRec because SCEV does not have a fixed point where it stops
13913 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13914 // may happen if we reach arithmetic depth limit while simplifying. So we
13915 // construct the returned value explicitly.
13917 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13918 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13919 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13920 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13921 // We know that the last operand is not a constant zero (otherwise it would
13922 // have been popped out earlier). This guarantees us that if the result has
13923 // the same last operand, then it will also not be popped out, meaning that
13924 // the returned value will be an AddRec.
13925 const SCEV *Last = getOperand(getNumOperands() - 1);
13926 assert(!Last->isZero() && "Recurrency with zero step?");
13927 Ops.push_back(Last);
13930}
13931
13932// Return true when S contains at least one undef or poison value.
  // Visit the sub-expressions of S; any one matching m_scev_UndefOrPoison()
  // makes the result true.
13934 return SCEVExprContains(
13935 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13936}
13937
13938// Return true when S contains a SCEVUnknown whose underlying Value is null.
  // Only SCEVUnknown nodes wrap an IR Value; a null Value there is what this
  // check looks for. All other node kinds never match.
13940 return SCEVExprContains(S, [](const SCEV *S) {
13941 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13942 return SU->getValue() == nullptr;
13943 return false;
13944 });
13945}
13946
13947/// Return the size of an element read or written by Inst.
  // Only stores (type of the stored value) and loads (type of the loaded
  // result) are handled; any other instruction yields nullptr.
13949 Type *Ty;
13950 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13951 Ty = Store->getValueOperand()->getType();
13952 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13953 Ty = Load->getType();
13954 else
13955 return nullptr;
13956
  // NOTE(review): ETy is declared on a source line absent from this listing.
13958 return getSizeOfExpr(ETy, Ty);
13959}
13960
13961//===----------------------------------------------------------------------===//
13962// SCEVCallbackVH Class Implementation
13963//===----------------------------------------------------------------------===//
13964
13966 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  // The tracked value is going away: drop the cached constant-evolution exit
  // value keyed on it (if it is a PHI) and remove it from SE's value map.
13967 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13968 SE->ConstantEvolutionLoopExitValue.erase(PN);
13969 SE->eraseValueFromMap(getValPtr());
  // NOTE(review): eraseValueFromMap presumably destroys this handle, so no
  // member may be touched after the call above.
13970 // this now dangles!
13971}
13972
13973void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13974 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13975
13976 // Forget all the expressions associated with users of the old value,
13977 // so that future queries will recompute the expressions using the new
13978 // value.
13979 SE->forgetValue(getValPtr());
13980 // this now dangles!
13981}
13982
13983ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13984 : CallbackVH(V), SE(se) {}
13985
13986//===----------------------------------------------------------------------===//
13987// ScalarEvolution Class Implementation
13988//===----------------------------------------------------------------------===//
13989
13992 LoopInfo &LI)
13993 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13994 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13995 LoopDispositions(64), BlockDispositions(64) {
13996 // To use guards for proving predicates, we need to scan every instruction in
13997 // relevant basic blocks, and not just terminators. Doing this is a waste of
13998 // time if the IR does not actually contain any calls to
13999 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
14000 //
14001 // This pessimizes the case where a pass that preserves ScalarEvolution wants
14002 // to _add_ guards to the module when there weren't any before, and wants
14003 // ScalarEvolution to optimize based on those guards. For now we prefer to be
14004 // efficient in lieu of being smart in that rather obscure case.
14005
14006 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
14007 F.getParent(), Intrinsic::experimental_guard);
14008 HasGuards = GuardDecl && !GuardDecl->use_empty();
14009}
14010
14012 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
14013 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
14014 ValueExprMap(std::move(Arg.ValueExprMap)),
14015 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14016 PendingMerges(std::move(Arg.PendingMerges)),
14017 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14018 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14019 PredicatedBackedgeTakenCounts(
14020 std::move(Arg.PredicatedBackedgeTakenCounts)),
14021 BECountUsers(std::move(Arg.BECountUsers)),
14022 ConstantEvolutionLoopExitValue(
14023 std::move(Arg.ConstantEvolutionLoopExitValue)),
14024 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14025 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14026 LoopDispositions(std::move(Arg.LoopDispositions)),
14027 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14028 BlockDispositions(std::move(Arg.BlockDispositions)),
14029 SCEVUsers(std::move(Arg.SCEVUsers)),
14030 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14031 SignedRanges(std::move(Arg.SignedRanges)),
14032 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14033 UniquePreds(std::move(Arg.UniquePreds)),
14034 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14035 LoopUsers(std::move(Arg.LoopUsers)),
14036 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14037 FirstUnknown(Arg.FirstUnknown) {
14038 Arg.FirstUnknown = nullptr;
14039}
14040
14042 // Iterate through all the SCEVUnknown instances and call their
14043 // destructors, so that they release their references to their values.
14044 for (SCEVUnknown *U = FirstUnknown; U;) {
14045 SCEVUnknown *Tmp = U;
14046 U = U->Next;
14047 Tmp->~SCEVUnknown();
14048 }
14049 FirstUnknown = nullptr;
14050
14051 ExprValueMap.clear();
14052 ValueExprMap.clear();
14053 HasRecMap.clear();
14054 BackedgeTakenCounts.clear();
14055 PredicatedBackedgeTakenCounts.clear();
14056
14057 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14058 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14059 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14060 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14061}
14062
14066
14067/// When printing a top-level SCEV for trip counts, it's helpful to include
14068/// a type for constants which are otherwise hard to disambiguate.
14069static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14070 if (isa<SCEVConstant>(S))
14071 OS << *S->getType() << " ";
14072 OS << *S;
14073}
14074
14076 const Loop *L) {
14077 // Print all inner loops first
14078 for (Loop *I : *L)
14079 PrintLoopInfo(OS, SE, I);
14080
14081 OS << "Loop ";
14082 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14083 OS << ": ";
14084
14085 SmallVector<BasicBlock *, 8> ExitingBlocks;
14086 L->getExitingBlocks(ExitingBlocks);
14087 if (ExitingBlocks.size() != 1)
14088 OS << "<multiple exits> ";
14089
14090 auto *BTC = SE->getBackedgeTakenCount(L);
14091 if (!isa<SCEVCouldNotCompute>(BTC)) {
14092 OS << "backedge-taken count is ";
14093 PrintSCEVWithTypeHint(OS, BTC);
14094 } else
14095 OS << "Unpredictable backedge-taken count.";
14096 OS << "\n";
14097
14098 if (ExitingBlocks.size() > 1)
14099 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14100 OS << " exit count for " << ExitingBlock->getName() << ": ";
14101 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14102 PrintSCEVWithTypeHint(OS, EC);
14103 if (isa<SCEVCouldNotCompute>(EC)) {
14104 // Retry with predicates.
14106 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14107 if (!isa<SCEVCouldNotCompute>(EC)) {
14108 OS << "\n predicated exit count for " << ExitingBlock->getName()
14109 << ": ";
14110 PrintSCEVWithTypeHint(OS, EC);
14111 OS << "\n Predicates:\n";
14112 for (const auto *P : Predicates)
14113 P->print(OS, 4);
14114 }
14115 }
14116 OS << "\n";
14117 }
14118
14119 OS << "Loop ";
14120 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14121 OS << ": ";
14122
14123 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14124 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14125 OS << "constant max backedge-taken count is ";
14126 PrintSCEVWithTypeHint(OS, ConstantBTC);
14128 OS << ", actual taken count either this or zero.";
14129 } else {
14130 OS << "Unpredictable constant max backedge-taken count. ";
14131 }
14132
14133 OS << "\n"
14134 "Loop ";
14135 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14136 OS << ": ";
14137
14138 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14139 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14140 OS << "symbolic max backedge-taken count is ";
14141 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14143 OS << ", actual taken count either this or zero.";
14144 } else {
14145 OS << "Unpredictable symbolic max backedge-taken count. ";
14146 }
14147 OS << "\n";
14148
14149 if (ExitingBlocks.size() > 1)
14150 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14151 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14152 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14154 PrintSCEVWithTypeHint(OS, ExitBTC);
14155 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14156 // Retry with predicates.
14158 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14160 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14161 OS << "\n predicated symbolic max exit count for "
14162 << ExitingBlock->getName() << ": ";
14163 PrintSCEVWithTypeHint(OS, ExitBTC);
14164 OS << "\n Predicates:\n";
14165 for (const auto *P : Predicates)
14166 P->print(OS, 4);
14167 }
14168 }
14169 OS << "\n";
14170 }
14171
14173 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14174 if (PBT != BTC) {
14175 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14176 OS << "Loop ";
14177 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14178 OS << ": ";
14179 if (!isa<SCEVCouldNotCompute>(PBT)) {
14180 OS << "Predicated backedge-taken count is ";
14181 PrintSCEVWithTypeHint(OS, PBT);
14182 } else
14183 OS << "Unpredictable predicated backedge-taken count.";
14184 OS << "\n";
14185 OS << " Predicates:\n";
14186 for (const auto *P : Preds)
14187 P->print(OS, 4);
14188 }
14189 Preds.clear();
14190
14191 auto *PredConstantMax =
14193 if (PredConstantMax != ConstantBTC) {
14194 assert(!Preds.empty() &&
14195 "different predicated constant max BTC but no predicates");
14196 OS << "Loop ";
14197 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14198 OS << ": ";
14199 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14200 OS << "Predicated constant max backedge-taken count is ";
14201 PrintSCEVWithTypeHint(OS, PredConstantMax);
14202 } else
14203 OS << "Unpredictable predicated constant max backedge-taken count.";
14204 OS << "\n";
14205 OS << " Predicates:\n";
14206 for (const auto *P : Preds)
14207 P->print(OS, 4);
14208 }
14209 Preds.clear();
14210
14211 auto *PredSymbolicMax =
14213 if (SymbolicBTC != PredSymbolicMax) {
14214 assert(!Preds.empty() &&
14215 "Different predicated symbolic max BTC, but no predicates");
14216 OS << "Loop ";
14217 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14218 OS << ": ";
14219 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14220 OS << "Predicated symbolic max backedge-taken count is ";
14221 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14222 } else
14223 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14224 OS << "\n";
14225 OS << " Predicates:\n";
14226 for (const auto *P : Preds)
14227 P->print(OS, 4);
14228 }
14229
14231 OS << "Loop ";
14232 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14233 OS << ": ";
14234 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14235 }
14236}
14237
14238namespace llvm {
14239// Note: these overloaded operators need to be in the llvm namespace for them
14240// to be resolved correctly. If we put them outside the llvm namespace, the
14241//
14242// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14243//
14244// code below "breaks" and starts printing raw enum values as opposed to the
14245// string values.
14248 switch (LD) {
14250 OS << "Variant";
14251 break;
14253 OS << "Invariant";
14254 break;
14256 OS << "Uniform";
14257 break;
14259 OS << "Computable";
14260 break;
14261 }
14262 return OS;
14263}
14264
14267 switch (BD) {
14269 OS << "DoesNotDominate";
14270 break;
14272 OS << "Dominates";
14273 break;
14275 OS << "ProperlyDominates";
14276 break;
14277 }
14278 return OS;
14279}
14280} // namespace llvm
14281
14283 // ScalarEvolution's implementation of the print method is to print
14284 // out SCEV values of all instructions that are interesting. Doing
14285 // this potentially causes it to create new SCEV objects though,
14286 // which technically conflicts with the const qualifier. This isn't
14287 // observable from outside the class though, so casting away the
14288 // const isn't dangerous.
14289 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14290
14291 if (ClassifyExpressions) {
14292 OS << "Classifying expressions for: ";
14293 F.printAsOperand(OS, /*PrintType=*/false);
14294 OS << "\n";
14295 for (Instruction &I : instructions(F))
14296 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14297 OS << I << '\n';
14298 OS << " --> ";
14299 const SCEV *SV = SE.getSCEV(&I);
14300 SV->print(OS);
14301 if (!isa<SCEVCouldNotCompute>(SV)) {
14302 OS << " U: ";
14303 SE.getUnsignedRange(SV).print(OS);
14304 OS << " S: ";
14305 SE.getSignedRange(SV).print(OS);
14306 }
14307
14308 const Loop *L = LI.getLoopFor(I.getParent());
14309
14310 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14311 if (AtUse != SV) {
14312 OS << " --> ";
14313 AtUse->print(OS);
14314 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14315 OS << " U: ";
14316 SE.getUnsignedRange(AtUse).print(OS);
14317 OS << " S: ";
14318 SE.getSignedRange(AtUse).print(OS);
14319 }
14320 }
14321
14322 if (L) {
14323 OS << "\t\t" "Exits: ";
14324 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14325 if (!SE.isLoopInvariant(ExitValue, L)) {
14326 OS << "<<Unknown>>";
14327 } else {
14328 OS << *ExitValue;
14329 }
14330
14331 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14332 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14333 OS << LS;
14334 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14335 OS << ": " << SE.getLoopDisposition(SV, Iter);
14336 }
14337
14338 for (const auto *InnerL : depth_first(L)) {
14339 if (InnerL == L)
14340 continue;
14341 OS << LS;
14342 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14343 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14344 }
14345
14346 OS << " }";
14347 }
14348
14349 OS << "\n";
14350 }
14351 }
14352
14353 OS << "Determining loop execution counts for: ";
14354 F.printAsOperand(OS, /*PrintType=*/false);
14355 OS << "\n";
14356 for (Loop *I : LI)
14357 PrintLoopInfo(OS, &SE, I);
14358}
14359
14362 auto &Values = LoopDispositions[S];
14363 for (auto &V : Values) {
14364 if (V.getPointer() == L)
14365 return V.getInt();
14366 }
14367 Values.emplace_back(L, LoopVariant);
14368 LoopDisposition D = computeLoopDisposition(S, L);
14369 auto &Values2 = LoopDispositions[S];
14370 for (auto &V : llvm::reverse(Values2)) {
14371 if (V.getPointer() == L) {
14372 V.setInt(D);
14373 break;
14374 }
14375 }
14376 return D;
14377}
14378
14380ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14381 switch (S->getSCEVType()) {
14382 case scConstant:
14383 case scVScale:
14384 return LoopInvariant;
14385 case scAddRecExpr: {
14386 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14387
14388 // If L is the addrec's loop, it's computable.
14389 if (AR->getLoop() == L)
14390 return LoopComputable;
14391
14392 // Add recurrences are never invariant in the function-body (null loop).
14393 if (!L)
14394 return LoopVariant;
14395
14396 // Everything that is not defined at loop entry is variant.
14397 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) {
14398 if (L->contains(AR->getLoop()) &&
14399 llvm::all_of(AR->operands(),
14400 [&](const SCEV *Op) { return isLoopUniform(Op, L); }))
14401 return LoopUniform;
14402
14403 return LoopVariant;
14404 }
14405 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14406 " dominate the contained loop's header?");
14407
14408 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14409 if (AR->getLoop()->contains(L))
14410 return LoopInvariant;
14411
14412 // This recurrence is variant w.r.t. L if any of its operands
14413 // are variant.
14414 for (SCEVUse Op : AR->operands())
14415 if (!isLoopInvariant(Op, L))
14416 return LoopVariant;
14417
14418 // Otherwise it's loop-invariant.
14419 return LoopInvariant;
14420 }
14421 case scTruncate:
14422 case scZeroExtend:
14423 case scSignExtend:
14424 case scPtrToAddr:
14425 case scPtrToInt:
14426 case scAddExpr:
14427 case scMulExpr:
14428 case scUDivExpr:
14429 case scUMaxExpr:
14430 case scSMaxExpr:
14431 case scUMinExpr:
14432 case scSMinExpr:
14433 case scSequentialUMinExpr: {
14434 bool HasVarying = false;
14435 bool HasUniform = false;
14436 for (SCEVUse Op : S->operands()) {
14438 if (D == LoopVariant)
14439 return LoopVariant;
14440 if (D == LoopComputable)
14441 HasVarying = true;
14442 if (D == LoopUniform)
14443 HasUniform = true;
14444 }
14445 return HasVarying ? (HasUniform ? LoopVariant : LoopComputable)
14446 : (HasUniform ? LoopUniform : LoopInvariant);
14447 }
14448 case scUnknown:
14449 // All non-instruction values are loop invariant. All instructions are loop
14450 // invariant if they are not contained in the specified loop.
14451 // Instructions are never considered invariant in the function body
14452 // (null loop) because they are defined within the "loop".
14454 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14455 return LoopInvariant;
14456 case scCouldNotCompute:
14457 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14458 }
14459 llvm_unreachable("Unknown SCEV kind!");
14460}
14461
14462bool ScalarEvolution::isLoopUniform(const SCEV *S, const Loop *L) {
14464 return D == LoopUniform || D == LoopInvariant;
14465}
14466
14468 return getLoopDisposition(S, L) == LoopInvariant;
14469}
14470
14472 return getLoopDisposition(S, L) == LoopComputable;
14473}
14474
14477 auto &Values = BlockDispositions[S];
14478 for (auto &V : Values) {
14479 if (V.getPointer() == BB)
14480 return V.getInt();
14481 }
14482 Values.emplace_back(BB, DoesNotDominateBlock);
14483 BlockDisposition D = computeBlockDisposition(S, BB);
14484 auto &Values2 = BlockDispositions[S];
14485 for (auto &V : llvm::reverse(Values2)) {
14486 if (V.getPointer() == BB) {
14487 V.setInt(D);
14488 break;
14489 }
14490 }
14491 return D;
14492}
14493
14495ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14496 switch (S->getSCEVType()) {
14497 case scConstant:
14498 case scVScale:
14500 case scAddRecExpr: {
14501 // This uses a "dominates" query instead of "properly dominates" query
14502 // to test for proper dominance too, because the instruction which
14503 // produces the addrec's value is a PHI, and a PHI effectively properly
14504 // dominates its entire containing block.
14505 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14506 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14507 return DoesNotDominateBlock;
14508
14509 // Fall through into SCEVNAryExpr handling.
14510 [[fallthrough]];
14511 }
14512 case scTruncate:
14513 case scZeroExtend:
14514 case scSignExtend:
14515 case scPtrToAddr:
14516 case scPtrToInt:
14517 case scAddExpr:
14518 case scMulExpr:
14519 case scUDivExpr:
14520 case scUMaxExpr:
14521 case scSMaxExpr:
14522 case scUMinExpr:
14523 case scSMinExpr:
14524 case scSequentialUMinExpr: {
14525 bool Proper = true;
14526 for (const SCEV *NAryOp : S->operands()) {
14528 if (D == DoesNotDominateBlock)
14529 return DoesNotDominateBlock;
14530 if (D == DominatesBlock)
14531 Proper = false;
14532 }
14533 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14534 }
14535 case scUnknown:
14536 if (Instruction *I =
14538 if (I->getParent() == BB)
14539 return DominatesBlock;
14540 if (DT.properlyDominates(I->getParent(), BB))
14542 return DoesNotDominateBlock;
14543 }
14545 case scCouldNotCompute:
14546 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14547 }
14548 llvm_unreachable("Unknown SCEV kind!");
14549}
14550
14551bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14552 return getBlockDisposition(S, BB) >= DominatesBlock;
14553}
14554
14557}
14558
14559bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14560 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14561}
14562
14563void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14564 bool Predicated) {
14565 auto &BECounts =
14566 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14567 auto It = BECounts.find(L);
14568 if (It != BECounts.end()) {
14569 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14570 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14571 if (!isa<SCEVConstant>(S)) {
14572 auto UserIt = BECountUsers.find(S);
14573 assert(UserIt != BECountUsers.end());
14574 UserIt->second.erase({L, Predicated});
14575 }
14576 }
14577 }
14578 BECounts.erase(It);
14579 }
14580}
14581
14582void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14583 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14584 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14585
14586 while (!Worklist.empty()) {
14587 const SCEV *Curr = Worklist.pop_back_val();
14588 auto Users = SCEVUsers.find(Curr);
14589 if (Users != SCEVUsers.end())
14590 for (const auto *User : Users->second)
14591 if (ToForget.insert(User).second)
14592 Worklist.push_back(User);
14593 }
14594
14595 for (const auto *S : ToForget)
14596 forgetMemoizedResultsImpl(S);
14597
14598 PredicatedSCEVRewrites.remove_if(
14599 [&](const auto &Entry) { return ToForget.count(Entry.first.first); });
14600}
14601
14602void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14603 LoopDispositions.erase(S);
14604 BlockDispositions.erase(S);
14605 UnsignedRanges.erase(S);
14606 SignedRanges.erase(S);
14607 HasRecMap.erase(S);
14608 ConstantMultipleCache.erase(S);
14609
14610 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14611 UnsignedWrapViaInductionTried.erase(AR);
14612 SignedWrapViaInductionTried.erase(AR);
14613 }
14614
14615 auto ExprIt = ExprValueMap.find(S);
14616 if (ExprIt != ExprValueMap.end()) {
14617 for (Value *V : ExprIt->second) {
14618 auto ValueIt = ValueExprMap.find_as(V);
14619 if (ValueIt != ValueExprMap.end())
14620 ValueExprMap.erase(ValueIt);
14621 }
14622 ExprValueMap.erase(ExprIt);
14623 }
14624
14625 auto ScopeIt = ValuesAtScopes.find(S);
14626 if (ScopeIt != ValuesAtScopes.end()) {
14627 for (const auto &Pair : ScopeIt->second)
14628 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14629 llvm::erase(ValuesAtScopesUsers[Pair.second],
14630 std::make_pair(Pair.first, S));
14631 ValuesAtScopes.erase(ScopeIt);
14632 }
14633
14634 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14635 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14636 for (const auto &Pair : ScopeUserIt->second)
14637 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14638 ValuesAtScopesUsers.erase(ScopeUserIt);
14639 }
14640
14641 auto BEUsersIt = BECountUsers.find(S);
14642 if (BEUsersIt != BECountUsers.end()) {
14643 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14644 auto Copy = BEUsersIt->second;
14645 for (const auto &Pair : Copy)
14646 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14647 BECountUsers.erase(BEUsersIt);
14648 }
14649
14650 auto FoldUser = FoldCacheUser.find(S);
14651 if (FoldUser != FoldCacheUser.end())
14652 for (auto &KV : FoldUser->second)
14653 FoldCache.erase(KV);
14654 FoldCacheUser.erase(S);
14655}
14656
14657void
14658ScalarEvolution::getUsedLoops(const SCEV *S,
14659 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14660 struct FindUsedLoops {
14661 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14662 : LoopsUsed(LoopsUsed) {}
14663 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14664 bool follow(const SCEV *S) {
14665 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14666 LoopsUsed.insert(AR->getLoop());
14667 return true;
14668 }
14669
14670 bool isDone() const { return false; }
14671 };
14672
14673 FindUsedLoops F(LoopsUsed);
14674 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14675}
14676
14677void ScalarEvolution::getReachableBlocks(
14680 Worklist.push_back(&F.getEntryBlock());
14681 while (!Worklist.empty()) {
14682 BasicBlock *BB = Worklist.pop_back_val();
14683 if (!Reachable.insert(BB).second)
14684 continue;
14685
14686 Value *Cond;
14687 BasicBlock *TrueBB, *FalseBB;
14688 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14689 m_BasicBlock(FalseBB)))) {
14690 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14691 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14692 continue;
14693 }
14694
14695 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14696 const SCEV *L = getSCEV(Cmp->getOperand(0));
14697 const SCEV *R = getSCEV(Cmp->getOperand(1));
14698 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14699 Worklist.push_back(TrueBB);
14700 continue;
14701 }
14702 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14703 R)) {
14704 Worklist.push_back(FalseBB);
14705 continue;
14706 }
14707 }
14708 }
14709
14710 append_range(Worklist, successors(BB));
14711 }
14712}
14713
14715 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14716 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14717
14718 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14719
14720 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14721 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14722 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14723
14724 const SCEV *visitConstant(const SCEVConstant *Constant) {
14725 return SE.getConstant(Constant->getAPInt());
14726 }
14727
14728 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14729 return SE.getUnknown(Expr->getValue());
14730 }
14731
14732 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14733 return SE.getCouldNotCompute();
14734 }
14735 };
14736
14737 SCEVMapper SCM(SE2);
14738 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14739 SE2.getReachableBlocks(ReachableBlocks, F);
14740
14741 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14742 if (containsUndefs(Old) || containsUndefs(New)) {
14743 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14744 // not propagate undef aggressively). This means we can (and do) fail
14745 // verification in cases where a transform makes a value go from "undef"
14746 // to "undef+1" (say). The transform is fine, since in both cases the
14747 // result is "undef", but SCEV thinks the value increased by 1.
14748 return nullptr;
14749 }
14750
14751 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14752 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14753 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14754 return nullptr;
14755
14756 return Delta;
14757 };
14758
14759 while (!LoopStack.empty()) {
14760 auto *L = LoopStack.pop_back_val();
14761 llvm::append_range(LoopStack, *L);
14762
14763 // Only verify BECounts in reachable loops. For an unreachable loop,
14764 // any BECount is legal.
14765 if (!ReachableBlocks.contains(L->getHeader()))
14766 continue;
14767
14768 // Only verify cached BECounts. Computing new BECounts may change the
14769 // results of subsequent SCEV uses.
14770 auto It = BackedgeTakenCounts.find(L);
14771 if (It == BackedgeTakenCounts.end())
14772 continue;
14773
14774 auto *CurBECount =
14775 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14776 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14777
14778 if (CurBECount == SE2.getCouldNotCompute() ||
14779 NewBECount == SE2.getCouldNotCompute()) {
14780 // NB! This situation is legal, but is very suspicious -- whatever pass
14781 // change the loop to make a trip count go from could not compute to
14782 // computable or vice-versa *should have* invalidated SCEV. However, we
14783 // choose not to assert here (for now) since we don't want false
14784 // positives.
14785 continue;
14786 }
14787
14788 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14789 SE.getTypeSizeInBits(NewBECount->getType()))
14790 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14791 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14792 SE.getTypeSizeInBits(NewBECount->getType()))
14793 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14794
14795 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14796 if (Delta && !Delta->isZero()) {
14797 dbgs() << "Trip Count for " << *L << " Changed!\n";
14798 dbgs() << "Old: " << *CurBECount << "\n";
14799 dbgs() << "New: " << *NewBECount << "\n";
14800 dbgs() << "Delta: " << *Delta << "\n";
14801 std::abort();
14802 }
14803 }
14804
14805 // Collect all valid loops currently in LoopInfo.
14806 SmallPtrSet<Loop *, 32> ValidLoops;
14807 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14808 while (!Worklist.empty()) {
14809 Loop *L = Worklist.pop_back_val();
14810 if (ValidLoops.insert(L).second)
14811 Worklist.append(L->begin(), L->end());
14812 }
14813 for (const auto &KV : ValueExprMap) {
14814#ifndef NDEBUG
14815 // Check for SCEV expressions referencing invalid/deleted loops.
14816 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14817 assert(ValidLoops.contains(AR->getLoop()) &&
14818 "AddRec references invalid loop");
14819 }
14820#endif
14821
14822 // Check that the value is also part of the reverse map.
14823 auto It = ExprValueMap.find(KV.second);
14824 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14825 dbgs() << "Value " << *KV.first
14826 << " is in ValueExprMap but not in ExprValueMap\n";
14827 std::abort();
14828 }
14829
14830 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14831 if (!ReachableBlocks.contains(I->getParent()))
14832 continue;
14833 const SCEV *OldSCEV = SCM.visit(KV.second);
14834 const SCEV *NewSCEV = SE2.getSCEV(I);
14835 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14836 if (Delta && !Delta->isZero()) {
14837 dbgs() << "SCEV for value " << *I << " changed!\n"
14838 << "Old: " << *OldSCEV << "\n"
14839 << "New: " << *NewSCEV << "\n"
14840 << "Delta: " << *Delta << "\n";
14841 std::abort();
14842 }
14843 }
14844 }
14845
14846 for (const auto &KV : ExprValueMap) {
14847 for (Value *V : KV.second) {
14848 const SCEV *S = ValueExprMap.lookup(V);
14849 if (!S) {
14850 dbgs() << "Value " << *V
14851 << " is in ExprValueMap but not in ValueExprMap\n";
14852 std::abort();
14853 }
14854 if (S != KV.first) {
14855 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14856 << *KV.first << "\n";
14857 std::abort();
14858 }
14859 }
14860 }
14861
14862 // Verify integrity of SCEV users.
14863 for (const auto &S : UniqueSCEVs) {
14864 for (SCEVUse Op : S.operands()) {
14865 // We do not store dependencies of constants.
14866 if (isa<SCEVConstant>(Op))
14867 continue;
14868 auto It = SCEVUsers.find(Op);
14869 if (It != SCEVUsers.end() && It->second.count(&S))
14870 continue;
14871 dbgs() << "Use of operand " << *Op << " by user " << S
14872 << " is not being tracked!\n";
14873 std::abort();
14874 }
14875 }
14876
14877 // Verify integrity of ValuesAtScopes users.
14878 for (const auto &ValueAndVec : ValuesAtScopes) {
14879 const SCEV *Value = ValueAndVec.first;
14880 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14881 const Loop *L = LoopAndValueAtScope.first;
14882 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14883 if (!isa<SCEVConstant>(ValueAtScope)) {
14884 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14885 if (It != ValuesAtScopesUsers.end() &&
14886 is_contained(It->second, std::make_pair(L, Value)))
14887 continue;
14888 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14889 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14890 std::abort();
14891 }
14892 }
14893 }
14894
14895 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14896 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14897 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14898 const Loop *L = LoopAndValue.first;
14899 const SCEV *Value = LoopAndValue.second;
14901 auto It = ValuesAtScopes.find(Value);
14902 if (It != ValuesAtScopes.end() &&
14903 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14904 continue;
14905 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14906 << *ValueAtScope << " missing in ValuesAtScopes\n";
14907 std::abort();
14908 }
14909 }
14910
14911 // Verify integrity of BECountUsers.
14912 auto VerifyBECountUsers = [&](bool Predicated) {
14913 auto &BECounts =
14914 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14915 for (const auto &LoopAndBEInfo : BECounts) {
14916 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14917 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14918 if (!isa<SCEVConstant>(S)) {
14919 auto UserIt = BECountUsers.find(S);
14920 if (UserIt != BECountUsers.end() &&
14921 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14922 continue;
14923 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14924 << " missing from BECountUsers\n";
14925 std::abort();
14926 }
14927 }
14928 }
14929 }
14930 };
14931 VerifyBECountUsers(/* Predicated */ false);
14932 VerifyBECountUsers(/* Predicated */ true);
14933
14934 // Verify intergity of loop disposition cache.
14935 for (auto &[S, Values] : LoopDispositions) {
14936 for (auto [Loop, CachedDisposition] : Values) {
14937 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14938 if (CachedDisposition != RecomputedDisposition) {
14939 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14940 << " is incorrect: cached " << CachedDisposition << ", actual "
14941 << RecomputedDisposition << "\n";
14942 std::abort();
14943 }
14944 }
14945 }
14946
14947 // Verify integrity of the block disposition cache.
14948 for (auto &[S, Values] : BlockDispositions) {
14949 for (auto [BB, CachedDisposition] : Values) {
14950 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14951 if (CachedDisposition != RecomputedDisposition) {
14952 dbgs() << "Cached disposition of " << *S << " for block %"
14953 << BB->getName() << " is incorrect: cached " << CachedDisposition
14954 << ", actual " << RecomputedDisposition << "\n";
14955 std::abort();
14956 }
14957 }
14958 }
14959
14960 // Verify FoldCache/FoldCacheUser caches.
14961 for (auto [FoldID, Expr] : FoldCache) {
14962 auto I = FoldCacheUser.find(Expr);
14963 if (I == FoldCacheUser.end()) {
14964 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14965 << "!\n";
14966 std::abort();
14967 }
14968 if (!is_contained(I->second, FoldID)) {
14969 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14970 std::abort();
14971 }
14972 }
14973 for (auto [Expr, IDs] : FoldCacheUser) {
14974 for (auto &FoldID : IDs) {
14975 const SCEV *S = FoldCache.lookup(FoldID);
14976 if (!S) {
14977 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14978 << "!\n";
14979 std::abort();
14980 }
14981 if (S != Expr) {
14982 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14983 << " != " << *Expr << "!\n";
14984 std::abort();
14985 }
14986 }
14987 }
14988
14989 // Verify that ConstantMultipleCache computations are correct. We check that
14990 // cached multiples and recomputed multiples are multiples of each other to
14991 // verify correctness. It is possible that a recomputed multiple is different
14992 // from the cached multiple due to strengthened no wrap flags or changes in
14993 // KnownBits computations.
14994 for (auto [S, Multiple] : ConstantMultipleCache) {
14995 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14996 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14997 Multiple.urem(RecomputedMultiple) != 0 &&
14998 RecomputedMultiple.urem(Multiple) != 0)) {
14999 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
15000 << *S << " : Computed " << RecomputedMultiple
15001 << " but cache contains " << Multiple << "!\n";
15002 std::abort();
15003 }
15004 }
15005}
15006
15008 Function &F, const PreservedAnalyses &PA,
15009 FunctionAnalysisManager::Invalidator &Inv) {
15010 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
15011 // of its dependencies is invalidated.
15012 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
15013 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
15014 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
15015 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
15016 Inv.invalidate<LoopAnalysis>(F, PA);
15017}
15018
// Out-of-line definition of the static AnalysisKey member declared by
// ScalarEvolutionAnalysis (the new-pass-manager analysis whose run() builds a
// ScalarEvolution from TLI, AC, DT and LI).
15019AnalysisKey ScalarEvolutionAnalysis::Key;
15020
15023 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
15024 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15025 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15026 auto &LI = AM.getResult<LoopAnalysis>(F);
15027 return ScalarEvolution(F, TLI, AC, DT, LI);
15028}
15029
15035
15038 // For compatibility with opt's -analyze feature under legacy pass manager
15039 // which was not ported to NPM. This keeps tests using
15040 // update_analyze_test_checks.py working.
15041 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15042 << F.getName() << "':\n";
15044 return PreservedAnalyses::all();
15045}
15046
15048 "Scalar Evolution Analysis", false, true)
15054 "Scalar Evolution Analysis", false, true)
15055
15057
15059
15061 SE.reset(new ScalarEvolution(
15063 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15065 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15066 return false;
15067}
15068
15070
15072 SE->print(OS);
15073}
15074
15076 if (!VerifySCEV)
15077 return;
15078
15079 SE->verify();
15080}
15081
15089
15091 const SCEV *RHS) {
15092 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15093}
15094
// Returns the uniqued SCEVComparePredicate for `LHS Pred RHS`. Compare
// predicates are interned in UniquePreds, keyed on (P_Compare, Pred, LHS,
// RHS), so repeated requests with identical operands yield the same object.
// LHS and RHS must have the same type (asserted).
15095const SCEVPredicate *
15097 const SCEV *LHS, const SCEV *RHS) {
15099 assert(LHS->getType() == RHS->getType() &&
15100 "Type mismatch between LHS and RHS");
15101 // Unique this node based on the arguments
15102 ID.AddInteger(SCEVPredicate::P_Compare);
15103 ID.AddInteger(Pred);
15104 ID.AddPointer(LHS);
15105 ID.AddPointer(RHS);
15106 void *IP = nullptr;
// Reuse an existing node if one with the same profile was already created.
15107 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15108 return S;
// First request for this predicate: allocate it in SCEVAllocator and register
// it at the insert position found by the lookup above.
15109 SCEVComparePredicate *Eq = new (SCEVAllocator)
15110 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15111 UniquePreds.InsertNode(Eq, IP);
15112 return Eq;
15113}
15114
15116 const SCEVAddRecExpr *AR,
15119 // Unique this node based on the arguments
15120 ID.AddInteger(SCEVPredicate::P_Wrap);
15121 ID.AddPointer(AR);
15122 ID.AddInteger(AddedFlags);
15123 void *IP = nullptr;
15124 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15125 return S;
15126 auto *OF = new (SCEVAllocator)
15127 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15128 UniquePreds.InsertNode(OF, IP);
15129 return OF;
15130}
15131
15132namespace {
15133
// Rewrites SCEV expressions in the context of a single loop L under SCEV
// predication:
//  * a SCEVUnknown is replaced by RHS when Pred contains (directly, or as a
//    member of a union) an ICMP_EQ compare predicate `Unknown == RHS`;
//    otherwise a PHI is turned into an AddRec via
//    createAddRecFromPHIWithCasts when the predicates it needs are usable;
//  * zext/sext of an affine AddRec of L is pushed inside the recurrence when
//    the matching NUSW/NSSW wrap predicate is usable.
// When NewPreds is non-null, required predicates are appended to it; when it
// is null, a predicate is only usable if Pred already implies it.
15134class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15135public:
15136
15137 /// Rewrites \p S in the context of a loop L and the SCEV predication
15138 /// infrastructure.
15139 ///
15140 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15141 /// equivalences present in \p Pred.
15142 ///
15143 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15144 /// \p NewPreds such that the result will be an AddRecExpr.
15145 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15147 const SCEVPredicate *Pred) {
15148 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15149 return Rewriter.visit(S);
15150 }
15151
// Substitutes Expr by the RHS of an `Expr == RHS` compare predicate taken from
// Pred (Pred itself, or any member when Pred is a union predicate). If no such
// equality exists, falls back to convertToAddRecWithPreds.
// NOTE(review): the loop variable `Pred` below shadows the member `Pred`.
15152 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15153 if (Pred) {
15154 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15155 for (const auto *Pred : U->getPredicates())
15156 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15157 if (IPred->getLHS() == Expr &&
15158 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15159 return IPred->getRHS();
15160 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15161 if (IPred->getLHS() == Expr &&
15162 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15163 return IPred->getRHS();
15164 }
15165 }
15166 return convertToAddRecWithPreds(Expr);
15167 }
15168
// Rewrites zext({Start,+,Step}<L>) to {zext(Start),+,sext(Step)}<L> when an
// IncrementNUSW wrap predicate on the (already rewritten) affine operand is
// usable; otherwise returns the zext of the rewritten operand.
15169 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15170 const SCEV *Operand = visit(Expr->getOperand());
15171 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15172 if (AR && AR->getLoop() == L && AR->isAffine()) {
15173 // This couldn't be folded because the operand didn't have the nuw
15174 // flag. Add the nusw flag as an assumption that we could make.
15175 const SCEV *Step = AR->getStepRecurrence(SE);
15176 Type *Ty = Expr->getType();
15177 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15178 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15179 SE.getSignExtendExpr(Step, Ty), L,
15180 AR->getNoWrapFlags());
15181 }
15182 return SE.getZeroExtendExpr(Operand, Expr->getType());
15183 }
15184
// Rewrites sext({Start,+,Step}<L>) to {sext(Start),+,sext(Step)}<L> when an
// IncrementNSSW wrap predicate on the (already rewritten) affine operand is
// usable; otherwise returns the sext of the rewritten operand.
15185 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15186 const SCEV *Operand = visit(Expr->getOperand());
15187 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15188 if (AR && AR->getLoop() == L && AR->isAffine()) {
15189 // This couldn't be folded because the operand didn't have the nsw
15190 // flag. Add the nssw flag as an assumption that we could make.
15191 const SCEV *Step = AR->getStepRecurrence(SE);
15192 Type *Ty = Expr->getType();
15193 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15194 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15195 SE.getSignExtendExpr(Step, Ty), L,
15196 AR->getNoWrapFlags());
15197 }
15198 return SE.getSignExtendExpr(Operand, Expr->getType());
15199 }
15200
15201private:
// Only constructed through the static rewrite() entry point.
15202 explicit SCEVPredicateRewriter(
15203 const Loop *L, ScalarEvolution &SE,
15204 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15205 const SCEVPredicate *Pred)
15206 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15207
// Records that the rewrite relies on predicate P and returns whether P may be
// relied on. With a NewPreds vector, P is appended and accepted. Without one,
// no new predicate may be introduced, so P is accepted only if the existing
// Pred already implies it.
15208 bool addOverflowAssumption(const SCEVPredicate *P) {
15209 if (!NewPreds) {
15210 // Check if we've already made this assumption.
15211 return Pred && Pred->implies(P, SE);
15212 }
15213 NewPreds->push_back(P);
15214 return true;
15215 }
15216
// Builds the (uniqued) wrap predicate for AR with AddedFlags and records it
// through the overload that takes a generic predicate.
15217 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15219 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15220 return addOverflowAssumption(A);
15221 }
15222
15223 // If \p Expr represents a PHINode, we try to see if it can be represented
15224 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15225 // to add this predicate as a runtime overflow check, we return the AddRec.
15226 // If \p Expr does not meet these conditions (is not a PHI node, or we
15227 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15228 // return \p Expr.
15229 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15230 if (!isa<PHINode>(Expr->getValue()))
15231 return Expr;
15232 std::optional<
15233 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15234 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15235 if (!PredicatedRewrite)
15236 return Expr;
// Every predicate the PHI rewrite needs must be usable; give up on the first
// one that is not and keep the original unknown.
15237 for (const auto *P : PredicatedRewrite->second){
15238 // Wrap predicates from outer loops are not supported.
15239 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15240 if (L != WP->getExpr()->getLoop())
15241 return Expr;
15242 }
15243 if (!addOverflowAssumption(P))
15244 return Expr;
15245 }
15246 return PredicatedRewrite->first;
15247 }
15248
// Sink for newly required predicates; null when none may be added.
15249 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
// Existing predicate whose equalities and implications may be used; may be
// null.
15250 const SCEVPredicate *Pred;
// Loop that AddRecs (and wrap predicates) must belong to for the rewrites.
15251 const Loop *L;
15252};
15253
15254} // end anonymous namespace
15255
// Rewrites S for loop L using only what Preds already establishes. NewPreds is
// passed as null, so the rewriter cannot introduce new predicates: equalities
// from Preds are substituted, and wrap/PHI assumptions are used only when
// Preds implies them.
15256const SCEV *
15258 const SCEVPredicate &Preds) {
15259 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15260}
15261
15263 const SCEV *S, const Loop *L,
15266 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15267 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15268
15269 if (!AddRec)
15270 return nullptr;
15271
15272 // Check if any of the transformed predicates is known to be false. In that
15273 // case, it doesn't make sense to convert to a predicated AddRec, as the
15274 // versioned loop will never execute.
15275 for (const SCEVPredicate *Pred : TransformPreds) {
15276 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15277 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15278 continue;
15279
15280 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15281 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15282 if (isa<SCEVCouldNotCompute>(ExitCount))
15283 continue;
15284
15285 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15286 if (!Step->isOne())
15287 continue;
15288
15289 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15290 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15291 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15292 return nullptr;
15293 }
15294
15295 // Since the transformation was successful, we can now transfer the SCEV
15296 // predicates.
15297 Preds.append(TransformPreds.begin(), TransformPreds.end());
15298
15299 return AddRec;
15300}
15301
15302/// SCEV predicates
15306
15308 const ICmpInst::Predicate Pred,
15309 const SCEV *LHS, const SCEV *RHS)
15310 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15311 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15312 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15313}
15314
15316 ScalarEvolution &SE) const {
15317 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15318
15319 if (!Op)
15320 return false;
15321
15322 if (Pred != ICmpInst::ICMP_EQ)
15323 return false;
15324
15325 return Op->LHS == LHS && Op->RHS == RHS;
15326}
15327
15328bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15329
15331 if (Pred == ICmpInst::ICMP_EQ)
15332 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15333 else
15334 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15335 << *RHS << "\n";
15336
15337}
15338
15340 const SCEVAddRecExpr *AR,
15341 IncrementWrapFlags Flags)
15342 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15343
15344const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15345
15347 ScalarEvolution &SE) const {
15348 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15349 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15350 return false;
15351
15352 if (Op->AR == AR)
15353 return true;
15354
15355 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15357 return false;
15358
15359 const SCEV *Start = AR->getStart();
15360 const SCEV *OpStart = Op->AR->getStart();
15361 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15362 return false;
15363
15364 // Reject pointers to different address spaces.
15365 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15366 return false;
15367
15368 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15369 // narrower-type AddRec.
15370 if (SE.getTypeSizeInBits(AR->getType()) >
15371 SE.getTypeSizeInBits(Op->AR->getType()))
15372 return false;
15373
15374 const SCEV *Step = AR->getStepRecurrence(SE);
15375 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15376 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15377 return false;
15378
15379 // If both steps are positive, this implies N, if N's start and step are
15380 // ULE/SLE (for NSUW/NSSW) than this'.
15381 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15382 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15383 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15384
15385 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15386 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15387 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15388 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15389 : SE.getNoopOrSignExtend(Start, WiderTy);
15391 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15392 SE.isKnownPredicate(Pred, OpStart, Start);
15393}
15394
15396 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15397 IncrementWrapFlags IFlags = Flags;
15398
15399 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15400 IFlags = clearFlags(IFlags, IncrementNSSW);
15401
15402 return IFlags == IncrementAnyWrap;
15403}
15404
// Prints the guarded add recurrence, indented by Depth, followed by
// " Added Flags: ", the wrap-flag tags (<nusw>, <nssw>) and a newline.
// NOTE(review): presumably each tag is emitted only when the corresponding
// IncrementWrapFlags bit is set in Flags — confirm against the guarding
// conditions.
15405void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15406 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15408 OS << "<nusw>";
15410 OS << "<nssw>";
15411 OS << "\n";
15412}
15413
15416 ScalarEvolution &SE) {
15417 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15418 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15419
15420 // We can safely transfer the NSW flag as NSSW.
15421 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15422 ImpliedFlags = IncrementNSSW;
15423
15424 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15425 // If the increment is positive, the SCEV NUW flag will also imply the
15426 // WrapPredicate NUSW flag.
15427 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15428 if (Step->getValue()->getValue().isNonNegative())
15429 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15430 }
15431
15432 return ImpliedFlags;
15433}
15434
15435/// Union predicates don't get cached so create a dummy set ID for it.
15437 ScalarEvolution &SE)
15438 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15439 for (const auto *P : Preds)
15440 add(P, SE);
15441}
15442
15444 return all_of(Preds,
15445 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15446}
15447
15449 ScalarEvolution &SE) const {
15450 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15451 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15452 return this->implies(I, SE);
15453 });
15454
15455 return any_of(Preds,
15456 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15457}
15458
15460 for (const auto *Pred : Preds)
15461 Pred->print(OS, Depth);
15462}
15463
// Adds predicate N to this union.
//  * A nested union is flattened: each of its members is added recursively.
//  * Otherwise N is dropped if the union already implies it, and existing
//    members implied by N are removed before N is appended.
// Both implication checks are skipped once the union holds 16 or more
// predicates, in which case N is simply appended.
15464void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15465 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15466 for (const auto *Pred : Set->Preds)
15467 add(Pred, SE);
15468 return;
15469 }
15470
15471 // Implication checks are quadratic in the number of predicates. Stop doing
15472 // them once there are many predicates, as such a union is likely too
15473 // expensive to use anyway at that point.
15474 bool CheckImplies = Preds.size() < 16;
15475
15476 // Only add predicate if it is not already implied by this union predicate.
15477 if (CheckImplies && implies(N, SE))
15478 return;
15479
15480 // Build a new vector containing the current predicates, except the ones that
15481 // are implied by the new predicate N.
15483 for (auto *P : Preds) {
15484 if (CheckImplies && N->implies(P, SE))
15485 continue;
15486 PrunedPreds.push_back(P);
15487 }
15488 Preds = std::move(PrunedPreds);
15489 Preds.push_back(N);
15490}
15491
15493 Loop &L)
15494 : SE(SE), L(L) {
15496 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15497}
15498
15501 for (const auto *Op : Ops)
15502 // We do not expect that forgetting cached data for SCEVConstants will ever
15503 // open any prospects for sharpening or introduce any correctness issues,
15504 // so we don't bother storing their dependencies.
15505 if (!isa<SCEVConstant>(Op))
15506 SCEVUsers[Op].insert(User);
15507}
15508
15510 for (const SCEV *Op : Ops)
15511 // We do not expect that forgetting cached data for SCEVConstants will ever
15512 // open any prospects for sharpening or introduce any correctness issues,
15513 // so we don't bother storing their dependencies.
15514 if (!isa<SCEVConstant>(Op))
15515 SCEVUsers[Op].insert(User);
15516}
15517
15519 const SCEV *Expr = SE.getSCEV(V);
15520 return getPredicatedSCEV(Expr);
15521}
15522
15524 RewriteEntry &Entry = RewriteMap[Expr];
15525
15526 // If we already have an entry and the version matches, return it.
15527 if (Entry.second && Generation == Entry.first)
15528 return Entry.second;
15529
15530 // We found an entry but it's stale. Rewrite the stale entry
15531 // according to the current predicate.
15532 if (Entry.second)
15533 Expr = Entry.second;
15534
15535 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15536 Entry = {Generation, NewSCEV};
15537
15538 return NewSCEV;
15539}
15540
15542 if (!BackedgeCount) {
15544 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15545 for (const auto *P : Preds)
15546 addPredicate(*P);
15547 }
15548 return BackedgeCount;
15549}
15550
15552 if (!SymbolicMaxBackedgeCount) {
15554 SymbolicMaxBackedgeCount =
15555 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15556 for (const auto *P : Preds)
15557 addPredicate(*P);
15558 }
15559 return SymbolicMaxBackedgeCount;
15560}
15561
15563 if (!SmallConstantMaxTripCount) {
15565 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15566 for (const auto *P : Preds)
15567 addPredicate(*P);
15568 }
15569 return *SmallConstantMaxTripCount;
15570}
15571
15573 if (Preds->implies(&Pred, SE))
15574 return;
15575
15576 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15577 NewPreds.push_back(&Pred);
15578 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15579 updateGeneration();
15580}
15581
15583 return *Preds;
15584}
15585
15586void PredicatedScalarEvolution::updateGeneration() {
15587 // If the generation number wrapped recompute everything.
15588 if (++Generation == 0) {
15589 for (auto &II : RewriteMap) {
15590 const SCEV *Rewritten = II.second.second;
15591 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15592 }
15593 }
15594}
15595
15598 const SCEV *Expr = getSCEV(V);
15599 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15600
15601 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15602
15603 // Clear the statically implied flags.
15604 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15605 addPredicate(*SE.getWrapPredicate(AR, Flags));
15606
15607 auto II = FlagsMap.insert({V, Flags});
15608 if (!II.second)
15609 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15610}
15611
15614 const SCEV *Expr = getSCEV(V);
15615 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15616
15618 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15619
15620 auto II = FlagsMap.find(V);
15621
15622 if (II != FlagsMap.end())
15623 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15624
15626}
15627
15629 const SCEV *Expr = this->getSCEV(V);
15631 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15632
15633 if (!New)
15634 return nullptr;
15635
15636 for (const auto *P : NewPreds)
15637 addPredicate(*P);
15638
15639 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15640 return New;
15641}
15642
15645 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15646 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15647 SE)),
15648 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15649 for (auto I : Init.FlagsMap)
15650 FlagsMap.insert(I);
15651}
15652
15654 // For each block.
15655 for (auto *BB : L.getBlocks())
15656 for (auto &I : *BB) {
15657 if (!SE.isSCEVable(I.getType()))
15658 continue;
15659
15660 auto *Expr = SE.getSCEV(&I);
15661 auto II = RewriteMap.find(Expr);
15662
15663 if (II == RewriteMap.end())
15664 continue;
15665
15666 // Don't print things that are not interesting.
15667 if (II->second.second == Expr)
15668 continue;
15669
15670 OS.indent(Depth) << "[PSE]" << I << ":\n";
15671 OS.indent(Depth + 2) << *Expr << "\n";
15672 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15673 }
15674}
15675
15678 BasicBlock *Header = L->getHeader();
15679 BasicBlock *Pred = L->getLoopPredecessor();
15680 LoopGuards Guards(SE);
15681 if (!Pred)
15682 return Guards;
15684 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15685 return Guards;
15686}
15687
// Derives a guard for PHI node Phi from the guards of its incoming blocks.
// For every incoming edge:
//  * the guards collected for that block are taken from IncomingGuards, or
//    computed on first use with collectFromBlock at Depth + 1;
//  * they are searched for a rewrite of the incoming value to a min/max
//    expression whose operand 0 is a constant.
// If every incoming value has such a rewrite with the same min/max kind, the
// weakest constant bound across all edges is kept, and Guards.RewriteMap maps
// the PHI's SCEV to `minmax(C, PHI)`. Non-SCEVable PHIs are ignored.
15688void ScalarEvolution::LoopGuards::collectFromPHI(
15692 unsigned Depth) {
15693 if (!SE.isSCEVable(Phi.getType()))
15694 return;
15695
// (constant bound, min/max kind); {nullptr, scCouldNotCompute} means "no
// usable pattern".
15696 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
// Looks up the guard-derived min/max bound for incoming value IncomingIdx.
// Blocks already in VisitedBlocks or unreachable from entry yield no pattern.
15697 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15698 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15699 if (!VisitedBlocks.insert(InBlock).second)
15700 return {nullptr, scCouldNotCompute};
15701
15702 // Avoid analyzing unreachable blocks so that we don't get trapped
15703 // traversing cycles with ill-formed dominance or infinite cycles
15704 if (!SE.DT.isReachableFromEntry(InBlock))
15705 return {nullptr, scCouldNotCompute};
15706
15707 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15708 if (Inserted)
15709 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15710 Depth + 1);
15711 auto &RewriteMap = G->second.RewriteMap;
15712 if (RewriteMap.empty())
15713 return {nullptr, scCouldNotCompute};
15714 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15715 if (S == RewriteMap.end())
15716 return {nullptr, scCouldNotCompute};
15717 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15718 if (!SM)
15719 return {nullptr, scCouldNotCompute};
15720 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15721 return {C0, SM->getSCEVType()};
15722 return {nullptr, scCouldNotCompute};
15723 };
// Combines the bounds from two edges. Both must exist and be of the same kind.
// The weaker bound is kept so the result holds on both edges: the smaller
// constant for umax/smax (lower bounds) and the larger one for umin/smin
// (upper bounds).
15724 auto MergeMinMaxConst = [](MinMaxPattern P1,
15725 MinMaxPattern P2) -> MinMaxPattern {
15726 auto [C1, T1] = P1;
15727 auto [C2, T2] = P2;
15728 if (!C1 || !C2 || T1 != T2)
15729 return {nullptr, scCouldNotCompute};
15730 switch (T1) {
15731 case scUMaxExpr:
15732 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15733 case scSMaxExpr:
15734 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15735 case scUMinExpr:
15736 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15737 case scSMinExpr:
15738 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15739 default:
15740 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15741 }
15742 };
// Fold the bounds of all incoming edges, stopping early once any edge fails.
15743 auto P = GetMinMaxConst(0);
15744 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15745 if (!P.first)
15746 break;
15747 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15748 }
15749 if (P.first) {
15750 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15751 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15752 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15753 Guards.RewriteMap.insert({LHS, RHS});
15754 }
15755}
15756
15757// Return a new SCEV that rounds \p Expr down to the closest number that is
15758// divisible by \p Divisor and less than or equal to Expr. For now, only a
15759// constant Expr is handled.
15761 const APInt &DivisorVal,
15762 ScalarEvolution &SE) {
15763 const APInt *ExprVal;
15764 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15765 DivisorVal.isNonPositive())
15766 return Expr;
15767 APInt Rem = ExprVal->urem(DivisorVal);
15768 // return the SCEV: Expr - Expr % Divisor
15769 return SE.getConstant(*ExprVal - Rem);
15770}
15771
15772// Return a new SCEV that modifies \p Expr to the closest number divides by
15773// \p Divisor and greater or equal than Expr. For now, only handle constant
15774// Expr.
15775static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15776 const APInt &DivisorVal,
15777 ScalarEvolution &SE) {
15778 const APInt *ExprVal;
15779 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15780 DivisorVal.isNonPositive())
15781 return Expr;
15782 APInt Rem = ExprVal->urem(DivisorVal);
15783 if (Rem.isZero())
15784 return Expr;
15785 // return the SCEV: Expr + Divisor - Expr % Divisor
15786 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15787}
15788
15790 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15793 // If we have LHS == 0, check if LHS is computing a property of some unknown
15794 // SCEV %v which we can rewrite %v to express explicitly.
15796 return false;
15797 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15798 // explicitly express that.
15799 const SCEVUnknown *URemLHS = nullptr;
15800 const SCEV *URemRHS = nullptr;
15801 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15802 return false;
15803
15804 const SCEV *Multiple =
15805 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15806 DivInfo[URemLHS] = Multiple;
15807 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15808 Multiples[URemLHS] = C->getAPInt();
15809 return true;
15810}
15811
15812// Check if the condition is a divisibility guard (A % B == 0).
15813static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15814 ScalarEvolution &SE) {
15815 const SCEV *X, *Y;
15816 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15817}
15818
15819// Tighten the constant operand of a two-operand min/max expression so that it
15820// is a multiple of \p Divisor: for smin/umin the constant is rounded down, for
15821// smax/umax it is rounded up. The non-constant operand is processed
// recursively, so nested min/max chains are tightened as well. Any expression
// that is not a two-operand min/max with a non-negative constant operand 0 is
// returned unchanged.
// NOTE(review): rounding the bound is presumably only sound when the whole
// expression is known to be a multiple of Divisor — confirm at call sites.
15822static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15823 APInt Divisor,
15824 ScalarEvolution &SE) {
15825 // Returns true if \p Expr is a two-operand MinMax SCEV whose operand 0 is a
15826 // non-negative constant. If so, \p SCTy receives the SCEV kind, \p LHS the
15827 // constant operand and \p RHS the non-constant operand.
15828 auto IsMinMaxSCEVWithNonNegativeConstant =
15829 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15830 const SCEV *&RHS) {
15831 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15832 if (MinMax->getNumOperands() != 2)
15833 return false;
15834 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15835 if (C->getAPInt().isNegative())
15836 return false;
15837 SCTy = MinMax->getSCEVType();
15838 LHS = MinMax->getOperand(0);
15839 RHS = MinMax->getOperand(1);
15840 return true;
15841 }
15842 }
15843 return false;
15844 };
15845
15846 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15847 SCEVTypes SCTy;
15848 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15849 MinMaxRHS))
15850 return MinMaxExpr;
// Upper-bounding (min) expressions round the constant down; lower-bounding
// (max) expressions round it up.
15851 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15852 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15853 auto *DivisibleExpr =
15854 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15855 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15857 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15858 return SE.getMinMaxExpr(SCTy, Ops);
15859}
15860
15861void ScalarEvolution::LoopGuards::collectFromBlock(
15862 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15863 const BasicBlock *Block, const BasicBlock *Pred,
15864 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15865
15867
15868 SmallVector<SCEVUse> ExprsToRewrite;
15869 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15870 const SCEV *RHS,
15871 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15872 const LoopGuards &DivGuards) {
15873 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15874 // replacement SCEV which isn't directly implied by the structure of that
15875 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15876 // legal. See the scoping rules for flags in the header to understand why.
15877
15878 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15879 // create this form when combining two checks of the form (X u< C2 + C1) and
15880 // (X >=u C1).
15881 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15882 &ExprsToRewrite]() {
15883 const SCEVConstant *C1;
15884 const SCEVUnknown *LHSUnknown;
15885 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15886 if (!match(LHS,
15887 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15888 !C2)
15889 return false;
15890
15891 auto ExactRegion =
15892 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15893 .sub(C1->getAPInt());
15894
15895 // Bail out, unless we have a non-wrapping, monotonic range.
15896 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15897 return false;
15898 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15899 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15900 I->second = SE.getUMaxExpr(
15901 SE.getConstant(ExactRegion.getUnsignedMin()),
15902 SE.getUMinExpr(RewrittenLHS,
15903 SE.getConstant(ExactRegion.getUnsignedMax())));
15904 ExprsToRewrite.push_back(LHSUnknown);
15905 return true;
15906 };
15907 if (MatchRangeCheckIdiom())
15908 return;
15909
15910 // Do not apply information for constants or if RHS contains an AddRec.
15912 return;
15913
15914 // If RHS is SCEVUnknown, make sure the information is applied to it.
15916 std::swap(LHS, RHS);
15918 }
15919
15920 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15921 // and \p FromRewritten are the same (i.e. there has been no rewrite
15922 // registered for \p From), then puts this value in the list of rewritten
15923 // expressions.
15924 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15925 const SCEV *To) {
15926 if (From == FromRewritten)
15927 ExprsToRewrite.push_back(From);
15928 RewriteMap[From] = To;
15929 };
15930
15931 // Checks whether \p S has already been rewritten. In that case returns the
15932 // existing rewrite because we want to chain further rewrites onto the
15933 // already rewritten value. Otherwise returns \p S.
15934 auto GetMaybeRewritten = [&](const SCEV *S) {
15935 return RewriteMap.lookup_or(S, S);
15936 };
15937
15938 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15939 // Apply divisibility information when computing the constant multiple.
15940 const APInt &DividesBy =
15941 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15942
15943 // Collect rewrites for LHS and its transitive operands based on the
15944 // condition.
15945 // For min/max expressions, also apply the guard to its operands:
15946 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15947 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15948 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15949 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15950
15951 // We cannot express strict predicates in SCEV, so instead we replace them
15952 // with non-strict ones against plus or minus one of RHS depending on the
15953 // predicate.
15954 const SCEV *One = SE.getOne(RHS->getType());
15955 switch (Predicate) {
15956 case CmpInst::ICMP_ULT:
15957 if (RHS->getType()->isPointerTy())
15958 return;
15959 RHS = SE.getUMaxExpr(RHS, One);
15960 [[fallthrough]];
15961 case CmpInst::ICMP_SLT: {
15962 RHS = SE.getMinusSCEV(RHS, One);
15963 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15964 break;
15965 }
15966 case CmpInst::ICMP_UGT:
15967 case CmpInst::ICMP_SGT:
15968 RHS = SE.getAddExpr(RHS, One);
15969 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15970 break;
15971 case CmpInst::ICMP_ULE:
15972 case CmpInst::ICMP_SLE:
15973 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15974 break;
15975 case CmpInst::ICMP_UGE:
15976 case CmpInst::ICMP_SGE:
15977 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15978 break;
15979 default:
15980 break;
15981 }
15982
15983 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15984 SmallPtrSet<const SCEV *, 16> Visited;
15985
15986 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15987 append_range(Worklist, S->operands());
15988 };
15989
15990 while (!Worklist.empty()) {
15991 const SCEV *From = Worklist.pop_back_val();
15992 if (isa<SCEVConstant>(From))
15993 continue;
15994 if (!Visited.insert(From).second)
15995 continue;
15996 const SCEV *FromRewritten = GetMaybeRewritten(From);
15997 const SCEV *To = nullptr;
15998
15999 switch (Predicate) {
16000 case CmpInst::ICMP_ULT:
16001 case CmpInst::ICMP_ULE:
16002 To = SE.getUMinExpr(FromRewritten, RHS);
16003 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
16004 EnqueueOperands(UMax);
16005 break;
16006 case CmpInst::ICMP_SLT:
16007 case CmpInst::ICMP_SLE:
16008 To = SE.getSMinExpr(FromRewritten, RHS);
16009 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
16010 EnqueueOperands(SMax);
16011 break;
16012 case CmpInst::ICMP_UGT:
16013 case CmpInst::ICMP_UGE:
16014 To = SE.getUMaxExpr(FromRewritten, RHS);
16015 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
16016 EnqueueOperands(UMin);
16017 break;
16018 case CmpInst::ICMP_SGT:
16019 case CmpInst::ICMP_SGE:
16020 To = SE.getSMaxExpr(FromRewritten, RHS);
16021 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
16022 EnqueueOperands(SMin);
16023 break;
16024 case CmpInst::ICMP_EQ:
16026 To = RHS;
16027 break;
16028 case CmpInst::ICMP_NE:
16029 if (match(RHS, m_scev_Zero())) {
16030 const SCEV *OneAlignedUp =
16031 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16032 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16033 } else {
16034 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16035 // but creating the subtraction eagerly is expensive. Track the
16036 // inequalities in a separate map, and materialize the rewrite lazily
16037 // when encountering a suitable subtraction while re-writing.
16038 if (LHS->getType()->isPointerTy()) {
16042 break;
16043 }
16044 const SCEVConstant *C;
16045 const SCEV *A, *B;
16048 RHS = A;
16049 LHS = B;
16050 }
16051 if (LHS > RHS)
16052 std::swap(LHS, RHS);
16053 Guards.NotEqual.insert({LHS, RHS});
16054 continue;
16055 }
16056 break;
16057 default:
16058 break;
16059 }
16060
16061 if (To)
16062 AddRewrite(From, FromRewritten, To);
16063 }
16064 };
16065
16067 // First, collect information from assumptions dominating the loop.
16068 for (auto &AssumeVH : SE.AC.assumptions()) {
16069 if (!AssumeVH)
16070 continue;
16071 auto *AssumeI = cast<CallInst>(AssumeVH);
16072 if (!SE.DT.dominates(AssumeI, Block))
16073 continue;
16074 Terms.emplace_back(AssumeI->getOperand(0), true);
16075 }
16076
16077 // Second, collect information from llvm.experimental.guards dominating the loop.
16078 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16079 SE.F.getParent(), Intrinsic::experimental_guard);
16080 if (GuardDecl)
16081 for (const auto *GU : GuardDecl->users())
16082 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16083 if (Guard->getFunction() == Block->getParent() &&
16084 SE.DT.dominates(Guard, Block))
16085 Terms.emplace_back(Guard->getArgOperand(0), true);
16086
16087 // Third, collect conditions from dominating branches. Starting at the loop
16088 // predecessor, climb up the predecessor chain, as long as there are
16089 // predecessors that can be found that have unique successors leading to the
16090 // original header.
16091 // TODO: share this logic with isLoopEntryGuardedByCond.
16092 unsigned NumCollectedConditions = 0;
16094 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16095 for (; Pair.first;
16096 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16097 VisitedBlocks.insert(Pair.second);
16098 const CondBrInst *LoopEntryPredicate =
16099 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16100 if (!LoopEntryPredicate)
16101 continue;
16102
16103 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16104 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16105 NumCollectedConditions++;
16106
16107 // If we are recursively collecting guards stop after 2
16108 // conditions to limit compile-time impact for now.
16109 if (Depth > 0 && NumCollectedConditions == 2)
16110 break;
16111 }
16112 // Finally, if we stopped climbing the predecessor chain because
16113 // there wasn't a unique one to continue, try to collect conditions
16114 // for PHINodes by recursively following all of their incoming
16115 // blocks and try to merge the found conditions to build a new one
16116 // for the Phi.
16117 if (Pair.second->hasNPredecessorsOrMore(2) &&
16119 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16120 for (auto &Phi : Pair.second->phis())
16121 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16122 }
16123
16124 // Now apply the information from the collected conditions to
16125 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16126 // earliest conditions is processed first, except guards with divisibility
16127 // information, which are moved to the back. This ensures the SCEVs with the
16128 // shortest dependency chains are constructed first.
16130 GuardsToProcess;
16131 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16132 SmallVector<Value *, 8> Worklist;
16133 SmallPtrSet<Value *, 8> Visited;
16134 Worklist.push_back(Term);
16135 while (!Worklist.empty()) {
16136 Value *Cond = Worklist.pop_back_val();
16137 if (!Visited.insert(Cond).second)
16138 continue;
16139
16140 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16141 auto Predicate =
16142 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16143 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16144 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16145 // If LHS is a constant, apply information to the other expression.
16146 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16147 // can improve results.
16148 if (isa<SCEVConstant>(LHS)) {
16149 std::swap(LHS, RHS);
16151 }
16152 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16153 continue;
16154 }
16155
16156 Value *L, *R;
16157 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16158 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16159 Worklist.push_back(L);
16160 Worklist.push_back(R);
16161 }
16162 }
16163 }
16164
16165 // Process divisibility guards in reverse order to populate DivGuards early.
16166 DenseMap<const SCEV *, APInt> Multiples;
16167 LoopGuards DivGuards(SE);
16168 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16169 if (!isDivisibilityGuard(LHS, RHS, SE))
16170 continue;
16171 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16172 Multiples, SE);
16173 }
16174
16175 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16176 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16177
16178 // Apply divisibility information last. This ensures it is applied to the
16179 // outermost expression after other rewrites for the given value.
16180 for (const auto &[K, Divisor] : Multiples) {
16181 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16182 Guards.RewriteMap[K] =
16184 Guards.rewrite(K), Divisor, SE),
16185 DivisorSCEV),
16186 DivisorSCEV);
16187 ExprsToRewrite.push_back(K);
16188 }
16189
16190 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16191 // the replacement expressions are contained in the ranges of the replaced
16192 // expressions.
16193 Guards.PreserveNUW = true;
16194 Guards.PreserveNSW = true;
16195 for (const SCEV *Expr : ExprsToRewrite) {
16196 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16197 Guards.PreserveNUW &=
16198 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16199 Guards.PreserveNSW &=
16200 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16201 }
16202
16203 // Now that all rewrite information is collected, rewrite the collected
16204 // expressions with the information in the map. This applies information to
16205 // sub-expressions.
16206 if (ExprsToRewrite.size() > 1) {
16207 for (const SCEV *Expr : ExprsToRewrite) {
16208 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16209 Guards.RewriteMap.erase(Expr);
16210 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16211 }
16212 }
16213}
16214
16216 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16217 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16218 /// replacement is loop invariant in the loop of the AddRec.
16219 class SCEVLoopGuardRewriter
16220 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16223
16225
16226 public:
16227 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16228 const ScalarEvolution::LoopGuards &Guards)
16229 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16230 NotEqual(Guards.NotEqual) {
16231 if (Guards.PreserveNUW)
16232 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16233 if (Guards.PreserveNSW)
16234 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16235 }
16236
// AddRecs are returned unchanged: a replacement taken from the guard map is
// not guaranteed to be loop invariant in the loop of the AddRec, so rewriting
// inside one could be unsound.
16237 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16238
16239 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16240 return Map.lookup_or(Expr, Expr);
16241 }
16242
16243 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16244 if (const SCEV *S = Map.lookup(Expr))
16245 return S;
16246
16247 // If we didn't find the exact ZExt expr in the map, check if there's
16248 // an entry for a smaller ZExt we can use instead.
16249 Type *Ty = Expr->getType();
16250 const SCEV *Op = Expr->getOperand(0);
16251 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16252 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16253 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16254 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16255 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16256 if (const SCEV *S = Map.lookup(NarrowExt))
16257 return SE.getZeroExtendExpr(S, Ty);
16258 Bitwidth = Bitwidth / 2;
16259 }
16260
16262 Expr);
16263 }
16264
16265 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16266 if (const SCEV *S = Map.lookup(Expr))
16267 return S;
16269 Expr);
16270 }
16271
16272 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16273 if (const SCEV *S = Map.lookup(Expr))
16274 return S;
16276 }
16277
16278 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16279 if (const SCEV *S = Map.lookup(Expr))
16280 return S;
16282 }
16283
16284 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16285 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16286 // return UMax(S, 1).
16287 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16288 SCEVUse LHS, RHS;
16289 if (MatchBinarySub(S, LHS, RHS)) {
16290 if (LHS > RHS)
16291 std::swap(LHS, RHS);
16292 if (NotEqual.contains({LHS, RHS})) {
16293 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16294 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16295 return SE.getUMaxExpr(OneAlignedUp, S);
16296 }
16297 }
16298 return nullptr;
16299 };
16300
16301 // Check if Expr itself is a subtraction pattern with guard info.
16302 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16303 return Rewritten;
16304
16305 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16306 // (Const + A + B). There may be guard info for A + B, and if so, apply
16307 // it.
16308 // TODO: Could more generally apply guards to Add sub-expressions.
16309 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16310 Expr->getNumOperands() == 3) {
16311 const SCEV *Add =
16312 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16313 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16314 return SE.getAddExpr(
16315 Expr->getOperand(0), Rewritten,
16316 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16317 if (const SCEV *S = Map.lookup(Add))
16318 return SE.getAddExpr(Expr->getOperand(0), S);
16319 }
16320 SmallVector<SCEVUse, 2> Operands;
16321 bool Changed = false;
16322 for (SCEVUse Op : Expr->operands()) {
16323 Operands.push_back(
16325 Changed |= Op != Operands.back();
16326 }
16327 // We are only replacing operands with equivalent values, so transfer the
16328 // flags from the original expression.
16329 return !Changed ? Expr
16330 : SE.getAddExpr(Operands,
16332 Expr->getNoWrapFlags(), FlagMask));
16333 }
16334
16335 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16336 SmallVector<SCEVUse, 2> Operands;
16337 bool Changed = false;
16338 for (SCEVUse Op : Expr->operands()) {
16339 Operands.push_back(
16341 Changed |= Op != Operands.back();
16342 }
16343 // We are only replacing operands with equivalent values, so transfer the
16344 // flags from the original expression.
16345 return !Changed ? Expr
16346 : SE.getMulExpr(Operands,
16348 Expr->getNoWrapFlags(), FlagMask));
16349 }
16350 };
16351
16352 if (RewriteMap.empty() && NotEqual.empty())
16353 return Expr;
16354
16355 SCEVLoopGuardRewriter Rewriter(SE, *this);
16356 return Rewriter.visit(Expr);
16357}
16358
16359const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16360 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16361}
16362
16364 const LoopGuards &Guards) {
16365 return Guards.rewrite(Expr);
16366}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
static constexpr Value * getValue(Ty &ValueOrUse)
static Value * getOpcode(Value &V, Type &Ty, InstrumentationConfig &IConf, InstrumentorIRBuilderTy &IIRB)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2023
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1709
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1317
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1028
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1244
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:130
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator begin() const
Definition ArrayRef.h:129
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:409
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
@ ICMP_SLT
signed less than
Definition InstrTypes.h:769
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:770
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:764
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:763
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:767
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:765
@ ICMP_NE
not equal
Definition InstrTypes.h:762
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:768
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:766
bool isSigned() const
Definition InstrTypes.h:993
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:890
bool isTrueWhenEqual() const
This is just a convenience.
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:852
bool isUnsigned() const
Definition InstrTypes.h:999
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:989
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1491
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:254
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:78
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:85
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:390
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:239
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:314
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be a useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:350
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:612
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:68
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:113
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:107
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI bool isLoopUniform(const SCEV *S, const Loop *L)
Returns true if the given SCEV is loop-uniform with respect to the specified loop L.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
@ LoopUniform
The SCEV is loop-uniform.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:301
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:309
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:282
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:197
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:306
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:270
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:313
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2864
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:830
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2115
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2110
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2199
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:371
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2087
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1916
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2018
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:874
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:876
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.