LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
// Statistic counters for this analysis (STATISTIC from llvm/ADT/Statistic.h).
// Exit-count computations that succeeded vs. failed, per loop exit.
STATISTIC(NumExitCountsComputed,
          "Number of loop exits with predictable exit counts");
STATISTIC(NumExitCountsNotComputed,
          "Number of loop exits without predictable exit counts");
// Loops whose trip count had to be found by brute-force evaluation.
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
///
/// This member only forwards this recurrence's operands to the overload that
/// takes an explicit operand list; that overload returns SCEVCouldNotCompute
/// when a binomial coefficient cannot be formed.
// NOTE(review): listing line 1075 (the start of this signature) is missing
// from this extract; restore it from upstream before relying on this text.
1076 ScalarEvolution &SE) const {
  // Delegate to the operand-list overload with this recurrence's operands.
1077 return evaluateAtIteration(operands(), It, SE);
1078}
1079
// Evaluate the chain of recurrences {Operands[0],+,Operands[1],+,...} at
// iteration It, i.e. compute sum(Operands[i] * BC(It, i)) where BC is the
// binomial coefficient. Returns SCEVCouldNotCompute as soon as any
// coefficient cannot be computed.
// NOTE(review): listing line 1080 (the start of this signature, presumably
// declaring the Operands list) is missing from this extract.
1081 const SCEV *It,
1082 ScalarEvolution &SE) {
1083 assert(Operands.size() > 0);
  // The i == 0 term is Operands[0] * BC(It, 0) == Operands[0].
1084 const SCEV *Result = Operands[0].getPointer();
1085 for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1086 // The computation is correct in the face of overflow provided that the
1087 // multiplication is performed _after_ the evaluation of the binomial
1088 // coefficient.
    // The coefficient is requested in the type of the accumulated result.
1089 const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1090 if (isa<SCEVCouldNotCompute>(Coeff))
1091 return Coeff;
1092
    // Result += Operands[i] * BC(It, i).
1093 Result =
1094 SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
1095 }
1096 return Result;
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
1103/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
1104/// which computes a pointer-typed value, and rewrites the whole expression
1105/// tree so that *all* the computations are done on integers, and the only
1106/// pointer-typed operands in the expression are SCEVUnknown.
1107/// The CreatePtrCast callback is invoked to create the actual conversion
1108/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
///
/// Non-pointer subexpressions are returned unchanged. Add and mul nodes are
/// rebuilt with their original no-wrap flags. Per the comment in
/// visitUnknown, a null-pointer leaf folds to a zero of TargetTy instead of
/// going through CreatePtrCast.
// NOTE(review): listing lines 1109 (the class head) and 1111 (presumably the
// `Base` alias used by the constructor and visit()) are missing from this
// extract.
1110 : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
  // Callback that builds the integer cast for one pointer-typed SCEVUnknown.
1112 using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
  // Target integer type; in this class it is only used to build the zero
  // returned for null-pointer leaves.
1113 Type *TargetTy;
1114 ConversionFn CreatePtrCast;
1115
1116public:
  // NOTE(review): listing line 1117 (the start of the constructor) is missing.
1118 ConversionFn CreatePtrCast)
1119 : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}
1120
  /// Rewrite \p Scev so that pointer casts are sunk to its SCEVUnknown leaves,
  /// using a fresh rewriter instance.
1121 static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
1122 Type *TargetTy, ConversionFn CreatePtrCast) {
1123 SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
1124 return Rewriter.visit(Scev);
1125 }
1126
1127 const SCEV *visit(const SCEV *S) {
1128 Type *STy = S->getType();
1129 // If the expression is not pointer-typed, just keep it as-is.
1130 if (!STy->isPointerTy())
1131 return S;
1132 // Else, recursively sink the cast down into it.
1133 return Base::visit(S);
1134 }
1135
1136 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1137 // Preserve wrap flags on rewritten SCEVAddExpr, which the default
1138 // implementation drops.
1139 SmallVector<SCEVUse, 2> Operands;
1140 bool Changed = false;
1141 for (SCEVUse Op : Expr->operands()) {
1142 Operands.push_back(visit(Op.getPointer()));
1143 Changed |= Op.getPointer() != Operands.back();
1144 }
    // Rebuild only when at least one operand changed; otherwise reuse Expr.
1145 return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1146 }
1147
1148 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    // As in visitAddExpr, keep the original no-wrap flags when rebuilding.
1149 SmallVector<SCEVUse, 2> Operands;
1150 bool Changed = false;
1151 for (SCEVUse Op : Expr->operands()) {
1152 Operands.push_back(visit(Op.getPointer()));
1153 Changed |= Op.getPointer() != Operands.back();
1154 }
1155 return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1156 }
1157
1158 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1159 assert(Expr->getType()->isPointerTy() &&
1160 "Should only reach pointer-typed SCEVUnknown's.");
1161 // Perform some basic constant folding. If the operand of the cast is a
1162 // null pointer, don't create a cast SCEV expression (that will be left
1163 // as-is), but produce a zero constant.
    // NOTE(review): the condition guarding the next return (listing line
    // 1164, presumably the null-pointer check described above) is missing.
1165 return SE.getZero(TargetTy);
1166 return CreatePtrCast(Expr);
1167 }
1168};
1169
// Convert the pointer-typed SCEV Op into an integer SCEV of the data layout's
// integer-pointer type (IntPtrTy), by sinking a ptrtoint to every
// pointer-typed SCEVUnknown leaf via the cast-sinking rewriter. Returns
// SCEVCouldNotCompute for non-integral pointers, or when the width check
// described below fails.
// NOTE(review): listing lines 1170 (the signature), 1184 (first half of the
// width check), 1189 (presumably the start of the rewriter call assigning
// IntOp) and 1191 (presumably the FoldingSetNodeID declaration) are missing
// from this extract.
1171 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1172
1173 // It isn't legal for optimizations to construct new ptrtoint expressions
1174 // for non-integral pointers.
1175 if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1176 return getCouldNotCompute();
1177
1178 Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1179
1180 // We can only trivially model ptrtoint if SCEV's effective (integer) type
1181 // is sufficiently wide to represent all possible pointer values.
1182 // We could theoretically teach SCEV to truncate wider pointers, but
1183 // that isn't implemented for now.
1185 getDataLayout().getTypeSizeInBits(IntPtrTy))
1186 return getCouldNotCompute();
1187
1188 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1190 Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
        // Reuse an existing uniqued ptrtoint node for this leaf. In the
        // visible lines the key is (scPtrToInt, U) only; IntPtrTy is not added.
1192 ID.AddInteger(scPtrToInt);
1193 ID.AddPointer(U);
1194 void *IP = nullptr;
1195 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1196 return S;
        // Otherwise create, unique and register a new node of type IntPtrTy.
1197 SCEV *S = new (SCEVAllocator)
1198 SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
1199 UniqueSCEVs.InsertNode(S, IP);
1200 S->computeAndSetCanonical(*this);
1201 registerUser(S, U);
1202 return static_cast<const SCEV *>(S);
1203 });
1204 assert(IntOp->getType()->isIntegerTy() &&
1205 "We must have succeeded in sinking the cast, "
1206 "and ending up with an integer-typed expression!");
1207 return IntOp;
1208}
1209
// Convert the pointer-typed SCEV Op into an integer SCEV of the data layout's
// address type for Op's pointer type, by sinking a ptrtoaddr to every
// pointer-typed SCEVUnknown leaf via the cast-sinking rewriter. Returns
// SCEVCouldNotCompute for pointers with an unstable representation.
// NOTE(review): listing lines 1210 (the signature), 1222 (presumably the
// start of the rewriter call assigning IntOp) and 1224 (presumably the
// FoldingSetNodeID declaration) are missing from this extract.
1211 assert(Op->getType()->isPointerTy() && "Op must be a pointer");
1212
1213 // Treat pointers with unstable representation conservatively, since the
1214 // address bits may change.
1215 if (DL.hasUnstableRepresentation(Op->getType()))
1216 return getCouldNotCompute();
1217
1218 Type *Ty = DL.getAddressType(Op->getType());
1219
1220 // Use the rewriter to sink the cast down to SCEVUnknown leaves.
1221 // The rewriter handles null pointer constant folding.
1223 Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
        // Reuse an existing uniqued node keyed on (scPtrToAddr, U, Ty).
        // Unlike the ptrtoint path, the result type is part of the key here.
1225 ID.AddInteger(scPtrToAddr);
1226 ID.AddPointer(U);
1227 ID.AddPointer(Ty);
1228 void *IP = nullptr;
1229 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1230 return S;
        // Otherwise create, unique and register a new node of type Ty.
1231 SCEV *S = new (SCEVAllocator)
1232 SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
1233 UniqueSCEVs.InsertNode(S, IP);
1234 S->computeAndSetCanonical(*this);
1235 registerUser(S, U);
1236 return static_cast<const SCEV *>(S);
1237 });
1238 assert(IntOp->getType()->isIntegerTy() &&
1239 "We must have succeeded in sinking the cast, "
1240 "and ending up with an integer-typed expression!");
1241 return IntOp;
1242}
1243
// Convert Op to an integer SCEV of type Ty. Op is first lowered losslessly
// with getLosslessPtrToIntExpr, and that result is then truncated or
// zero-extended to Ty. A SCEVCouldNotCompute from the lossless step is
// returned unchanged.
// NOTE(review): listing line 1244 (the signature) is missing from this extract.
1245 assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1246
1247 const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1248 if (isa<SCEVCouldNotCompute>(IntOp))
1249 return IntOp;
1250
1251 return getTruncateOrZeroExtend(IntOp, Ty);
1252}
1253
// Return a SCEV for Op truncated to Ty (normalized to its effective SCEV
// type). Nodes are uniqued in UniqueSCEVs. Before creating an explicit
// SCEVTruncateExpr, the code below tries, in order:
//   - constant folding;
//   - collapsing trunc/sext/zext operands;
//   - distributing over add/mul when at most one new truncate results;
//   - truncating each addrec operand;
//   - folding to zero when Op's known trailing zeros cover all of Ty's bits.
// Depth bounds the recursion: beyond MaxCastDepth a plain node is created.
// NOTE(review): listing lines 1254 (function name and leading parameters),
// 1263, 1276, 1280, 1284, 1300 and 1308 are missing from this extract, so
// several conditions below are only partly visible.
1255 unsigned Depth) {
1256 assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1257 "This is not a truncating conversion!");
1258 assert(isSCEVable(Ty) &&
1259 "This is not a conversion to a SCEVable type!");
1260 assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1261 Ty = getEffectiveSCEVType(Ty);
1262
  // Look up an existing node keyed on (scTruncate, Op, Ty); IP receives the
  // insert position used by the InsertNode calls below. (The declaration of
  // ID, listing line 1263, is missing from this extract.)
1264 ID.AddInteger(scTruncate);
1265 ID.AddPointer(Op);
1266 ID.AddPointer(Ty);
1267 void *IP = nullptr;
1268 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1269
1270 // Fold if the operand is constant.
1271 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1272 return getConstant(
1273 cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1274
1275 // trunc(trunc(x)) --> trunc(x)
1277 return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1278
1279 // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1281 return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1282
1283 // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1285 return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1286
  // Recursion guard: past MaxCastDepth, stop folding and emit the node as-is.
1287 if (Depth > MaxCastDepth) {
1288 SCEV *S =
1289 new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1290 UniqueSCEVs.InsertNode(S, IP);
1291 S->computeAndSetCanonical(*this);
1292 registerUser(S, Op);
1293 return S;
1294 }
1295
1296 // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1297 // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1298 // if after transforming we have at most one truncate, not counting truncates
1299 // that replace other casts.
1301 auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1302 SmallVector<SCEVUse, 4> Operands;
1303 unsigned numTruncs = 0;
    // Stop early once a second new truncate appears; the result is then
    // discarded anyway.
1304 for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1305 ++i) {
1306 const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      // Count a new truncate only when the original operand was not already
      // an integral cast. The rest of this condition (listing line 1308) is
      // missing from this extract.
1307 if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1309 numTruncs++;
1310 Operands.push_back(S);
1311 }
    // At most one new truncate: rebuild the add/mul from the truncated operands.
1312 if (numTruncs < 2) {
1313 if (isa<SCEVAddExpr>(Op))
1314 return getAddExpr(Operands);
1315 if (isa<SCEVMulExpr>(Op))
1316 return getMulExpr(Operands);
1317 llvm_unreachable("Unexpected SCEV type for Op.");
1318 }
1319 // Although we checked in the beginning that ID is not in the cache, it is
1320 // possible that the recursive getTruncateExpr calls above inserted ID into
1321 // the cache. So if we find it, just return it.
1322 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1323 return S;
1324 }
1325
1326 // If the input value is a chrec scev, truncate the chrec's operands.
  // The truncated recurrence is built with FlagAnyWrap; no wrap flags carry over.
1327 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1328 SmallVector<SCEVUse, 4> Operands;
    // NOTE(review): the loop variable Op shadows the function parameter Op.
1329 for (const SCEV *Op : AddRec->operands())
1330 Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1331 return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1332 }
1333
1334 // Return zero if truncating to known zeros.
1335 uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
1336 if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1337 return getZero(Ty);
1338
1339 // The cast wasn't folded; create an explicit cast node. We can reuse
1340 // the existing insert position since if we get here, we won't have
1341 // made any changes which would invalidate it.
1342 SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1343 Op, Ty);
1344 UniqueSCEVs.InsertNode(S, IP);
1345 S->computeAndSetCanonical(*this);
1346 registerUser(S, Op);
1347 return S;
1348}
1349
1350// Get the limit of a recurrence such that incrementing by Step cannot cause
1351// signed overflow as long as the value of the recurrence within the
1352// loop does not exceed this limit before incrementing.
//
// *Pred receives the predicate the recurrence value must satisfy against the
// returned limit: ICMP_SLT for a known-positive Step, ICMP_SGT for a
// known-negative Step. Returns nullptr, leaving *Pred untouched, when the
// sign of Step is not known.
// NOTE(review): listing lines 1359 and 1364 (the first halves of the two
// return expressions) are missing from this extract. They presumably combine
// the BitWidth-bit signed extreme with the step's signed range max/min shown
// below; confirm against upstream.
1353static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1354 ICmpInst::Predicate *Pred,
1355 ScalarEvolution *SE) {
1356 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1357 if (SE->isKnownPositive(Step)) {
1358 *Pred = ICmpInst::ICMP_SLT;
1360 SE->getSignedRangeMax(Step));
1361 }
1362 if (SE->isKnownNegative(Step)) {
1363 *Pred = ICmpInst::ICMP_SGT;
1365 SE->getSignedRangeMin(Step));
1366 }
1367 return nullptr;
1368}
1369
1370// Get the limit of a recurrence such that incrementing by Step cannot cause
1371// unsigned overflow as long as the value of the recurrence within the loop does
1372// not exceed this limit before incrementing.
//
// Unlike the signed variant, *Pred is set unconditionally, always to ICMP_ULT.
// NOTE(review): listing lines 1373 (return type and function name, presumably
// getUnsignedOverflowLimitForStep as called from ExtendOpTraits) and 1379
// (start of the return expression, which ends with the step's unsigned range
// maximum) are missing from this extract.
1374 ICmpInst::Predicate *Pred,
1375 ScalarEvolution *SE) {
1376 unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1377 *Pred = ICmpInst::ICMP_ULT;
1378
1380 SE->getUnsignedRangeMax(Step));
1381}
1382
1383namespace {
1384
// Shared base for the traits below. It names the pointer-to-member type of a
// ScalarEvolution extension builder, called as
// (SE->*GetExtendExpr)(operand, destination type, depth).
1385struct ExtendOpTraitsBase {
1386 typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1387 unsigned);
1388};
1389
1390// Used to make code generic over signed and unsigned overflow.
1391template <typename ExtendOp> struct ExtendOpTraits {
1392 // Members present:
1393 //
1394 // static const SCEV::NoWrapFlags WrapType;
1395 //
1396 // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1397 //
1398 // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1399 // ICmpInst::Predicate *Pred,
1400 // ScalarEvolution *SE);
1401};
1402
// Signed flavour: the NSW flag and the signed overflow limit.
1403template <>
1404struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1405 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1406
1407 static const GetExtendExprTy GetExtendExpr;
1408
1409 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1410 ICmpInst::Predicate *Pred,
1411 ScalarEvolution *SE) {
1412 return getSignedOverflowLimitForStep(Step, Pred, SE);
1413 }
1414};
1415
// Out-of-line definition of the signed GetExtendExpr.
// NOTE(review): its initializer (listing line 1417, presumably the address of
// ScalarEvolution's sign-extension builder) is missing from this extract.
1416const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1418
// Unsigned flavour: the NUW flag and the unsigned overflow limit.
1419template <>
1420struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1421 static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1422
1423 static const GetExtendExprTy GetExtendExpr;
1424
1425 static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1426 ICmpInst::Predicate *Pred,
1427 ScalarEvolution *SE) {
1428 return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1429 }
1430};
1431
// Out-of-line definition of the unsigned GetExtendExpr.
// NOTE(review): its initializer (listing line 1433, presumably the address of
// ScalarEvolution's zero-extension builder) is missing from this extract.
1432const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1434
1435} // end anonymous namespace
1436
1437// The recurrence AR has been shown to have no signed/unsigned wrap or something
1438// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1439// easily prove NSW/NUW for its preincrement or postincrement sibling. This
1440// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1441// Start),+,Step} => {Step + sext/zext(Start),+,Step}. As a result, the
1442// expression "Step + sext/zext(PreIncAR)" is congruent with
1443// "sext/zext(PostIncAR)"
//
// Returns PreStart (AR's start with one copy of Step removed) when one of
// three checks shows that PreStart + Step does not wrap in the WrapType
// sense:
//   (1) the pre-increment recurrence already carries WrapType and the
//       backedge-taken count is known positive;
//   (2) extending to double width commutes with the addition;
//   (3) a loop-entry guard establishes `PreStart Pred OverflowLimit`.
// Returns nullptr otherwise, including when the start is not an add
// expression or does not contain Step as an operand.
// NOTE(review): Ty is not referenced in the visible body.
1444template <typename ExtendOpTy>
1445static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1446 ScalarEvolution *SE, unsigned Depth) {
1447 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1448 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1449
1450 const Loop *L = AR->getLoop();
1451 const SCEV *Start = AR->getStart();
1452 const SCEV *Step = AR->getStepRecurrence(*SE);
1453
1454 // Check for a simple looking step prior to loop entry.
1455 const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1456 if (!SA)
1457 return nullptr;
1458
1459 // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1460 // subtraction is expensive. For this purpose, perform a quick and dirty
1461 // difference, by checking for Step in the operand list. Note, that
1462 // SA might have repeated ops, like %a + %a + ..., so only remove one.
1463 SmallVector<SCEVUse, 4> DiffOps(SA->operands());
1464 for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
1465 if (*It == Step) {
1466 DiffOps.erase(It);
1467 break;
1468 }
1469
  // No operand was removed: Step is not a term of Start.
1470 if (DiffOps.size() == SA->getNumOperands())
1471 return nullptr;
1472
1473 // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1474 // `Step`:
1475
1476 // 1. NSW/NUW flags on the step increment.
  // NOTE(review): listing lines 1478 (the initializer of PreStartFlags) and
  // 1480 (the declaration of PreAR, which is presumably the recurrence built
  // on line 1481 when it is an addrec) are missing from this extract.
1477 auto PreStartFlags =
1479 const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1481 SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1482
1483 // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1484 // "S+X does not sign/unsign-overflow".
1485 //
1486
1487 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1488 if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
1489 !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1490 return PreStart;
1491
1492 // 2. Direct overflow check on the step operation's expression.
  // The comparison is done at twice the bit width, where the sum of two
  // extended BitWidth-bit values cannot itself overflow.
1493 unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1494 Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1495 const SCEV *OperandExtendedStart =
1496 SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1497 (SE->*GetExtendExpr)(Step, WideTy, Depth));
1498 if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1499 if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
1500 // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1501 // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1502 // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1503 SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1504 }
1505 return PreStart;
1506 }
1507
1508 // 3. Loop precondition.
  // NOTE(review): the declaration of Pred (listing line 1509) is missing.
1510 const SCEV *OverflowLimit =
1511 ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1512
1513 if (OverflowLimit &&
1514 SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1515 return PreStart;
1516
1517 return nullptr;
1518}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
1536// Try to prove away overflow by looking at "nearby" add recurrences. A
1537// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1538// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1539//
1540// Formally:
1541//
1542// {S,+,X} == {S-T,+,X} + T
1543// => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1544//
1545// If ({S-T,+,X} + T) does not overflow ... (1)
1546//
1547// RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1548//
1549// If {S-T,+,X} does not overflow ... (2)
1550//
1551// RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1552// == {Ext(S-T)+Ext(T),+,Ext(X)}
1553//
1554// If (S-T)+T does not overflow ... (3)
1555//
1556// RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1557// == {Ext(S),+,Ext(X)} == LHS
1558//
1559// Thus, if (1), (2) and (3) are true for some T, then
1560// Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1561//
1562// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1563// does not overflow" restricted to the 0th iteration. Therefore we only need
1564// to check for (1) and (2).
1565//
1566// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1567// is `Delta` (defined below).
//
// Returns true when some Delta in the candidate list proves (1) and (2);
// returns false immediately for a non-constant Start.
1568template <typename ExtendOpTy>
1569bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1570 const SCEV *Step,
1571 const Loop *L) {
1572 auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1573
1574 // We restrict `Start` to a constant to prevent SCEV from spending too much
1575 // time here. It is correct (but more expensive) to continue with a
1576 // non-constant `Start` and do a general SCEV subtraction to compute
1577 // `PreStart` below.
1578 const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1579 if (!StartC)
1580 return false;
1581
1582 APInt StartAI = StartC->getAPInt();
1583
  // NOTE(review): Delta is declared `unsigned`, so the entries -2 and -1 are
  // converted to UINT_MAX - 1 and UINT_MAX. PreStart and DeltaS below are
  // both derived from that converted value. For types wider than `unsigned`,
  // the probed starts are presumably not Start+2 / Start+1 as the "nearby"
  // wording suggests; confirm whether this is intended.
1584 for (unsigned Delta : {-2, -1, 1, 2}) {
1585 const SCEV *PreStart = getConstant(StartAI - Delta);
1586
    // Look up (without creating) the recurrence {PreStart,+,Step} in L.
1587 FoldingSetNodeID ID;
1588 ID.AddInteger(scAddRecExpr);
1589 ID.AddPointer(PreStart);
1590 ID.AddPointer(Step);
1591 ID.AddPointer(L);
1592 void *IP = nullptr;
1593 const auto *PreAR =
1594 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1595
1596 // Skip this Delta if we don't already have the add recurrence we need,
1597 // because actually constructing an add recurrence is relatively expensive.
1598 if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
1599 const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      // NOTE(review): the declaration of Pred (listing line 1600) is missing.
1601 const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1602 DeltaS, &Pred, this);
1603 if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1604 return true;
1605 }
1606 }
1607
1608 return false;
1609}
1610
1611// Finds an integer D for an expression (C + x + y + ...) such that the top
1612// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1613// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1614// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1615// the (C + x + y + ...) expression is \p WholeAddExpr.
//
// Why no wrap: D keeps only the low TZ bits of C, so D < 2^TZ, while (C - D)
// and every other operand are multiples of 2^TZ. Adding D therefore only
// fills zero low bits and never carries into the high bits. Returns 0 when
// TZ is 0.
// NOTE(review): listing line 1616 (return type, name and the SE parameter)
// is missing from this extract.
1617 const SCEVConstant *ConstantTerm,
1618 const SCEVAddExpr *WholeAddExpr) {
1619 const APInt &C = ConstantTerm->getAPInt();
1620 const unsigned BitWidth = C.getBitWidth();
1621 // Find number of trailing zeros of (x + y + ...) w/o the C first:
  // Operand 0 is skipped as the C term (presumably the caller guarantees that
  // ConstantTerm is operand 0 of WholeAddExpr). The loop stops once TZ is 0.
1622 uint32_t TZ = BitWidth;
1623 for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1624 TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
1625 if (TZ) {
1626 // Set D to be as many least significant bits of C as possible while still
1627 // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1628 return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1629 }
1630 return APInt(BitWidth, 0);
1631}
1632
1633// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1634// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1635// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1636// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
//
// Why no wrap: TZ is a lower bound on Step's trailing zeros, so x * n is a
// multiple of 2^TZ for every n. D keeps only the low TZ bits of C (D < 2^TZ),
// so (C - D + x * n) is also a multiple of 2^TZ and adding D never carries.
// Returns 0 when Step has no known trailing zeros.
// NOTE(review): listing line 1637 (return type, name and the SE parameter)
// is missing from this extract.
1638 const APInt &ConstantStart,
1639 const SCEV *Step) {
1640 const unsigned BitWidth = ConstantStart.getBitWidth();
1641 const uint32_t TZ = SE.getMinTrailingZeros(Step);
1642 if (TZ)
1643 return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1644 : ConstantStart;
1645 return APInt(BitWidth, 0);
1646}
1647
// Record in FoldCache that folding ID produced S, and keep the reverse map
// FoldCacheUser (S -> list of FoldIDs whose cached result is S) in sync. If
// ID already had an entry, ID is first removed from the old value's user
// list, and then the entry is overwritten with S.
// NOTE(review): listing lines 1648, 1650 and 1651 (the function name and the
// FoldCache / FoldCacheUser parameter declarations) are missing from this
// extract.
1649 const ScalarEvolution::FoldID &ID, const SCEV *S,
1652 &FoldCacheUser) {
  // insert() does not overwrite: I.second is false if ID was already cached.
1653 auto I = FoldCache.insert({ID, S});
1654 if (!I.second) {
1655 // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
1656 // entry.
1657 auto &UserIDs = FoldCacheUser[I.first->second];
1658 assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
    // Swap-and-pop: move the matching ID to the back, then drop it. The order
    // of UserIDs is not preserved.
    // NOTE(review): this loop index `I` shadows the insert result `I` above;
    // the `I.first` used after the loop refers to the insert result.
1659 for (unsigned I = 0; I != UserIDs.size(); ++I)
1660 if (UserIDs[I] == ID) {
1661 std::swap(UserIDs[I], UserIDs.back());
1662 break;
1663 }
1664 UserIDs.pop_back();
1665 I.first->second = S;
1666 }
1667 FoldCacheUser[S].push_back(ID);
1668}
1669
// Return zext(Op) to Ty, memoized through FoldCache. Ty is first normalized
// with getEffectiveSCEVType, and the cache key is (scZeroExtend, Op, Ty);
// Depth is not part of the key. On a miss the result comes from
// getZeroExtendExprImpl and is recorded with insertFoldCacheEntry (subject
// to the missing listing line 1684 noted below).
// NOTE(review): listing line 1671 (the function name and parameters) is
// missing from this extract.
1670const SCEV *
1672 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1673 "This is not an extending conversion!");
1674 assert(isSCEVable(Ty) &&
1675 "This is not a conversion to a SCEVable type!");
1676 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1677 Ty = getEffectiveSCEVType(Ty);
1678
1679 FoldID ID(scZeroExtend, Op, Ty);
1680 if (const SCEV *S = FoldCache.lookup(ID))
1681 return S;
1682
  // NOTE(review): listing line 1684, between the computation and the cache
  // insert, is missing from this extract. Confirm against upstream whether
  // the insert is conditional before assuming every result is cached.
1683 const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
1685 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
1686 return S;
1687}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the latter case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two issue possible causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 const APInt *C, *C2;
1914 // zext (C + A)<nsw> -> (sext(C) + sext(A))<nsw> if zext (C + A)<nsw> >=s 0.
1915 // Currently the non-negative check is done manually, as isKnownNonNegative
1916 // is too expensive.
1917 if (SA->hasNoSignedWrap() &&
1919 m_scev_SMax(m_scev_APInt(C2), m_SCEV()))) &&
1920 C->isNegative() && !C->isMinSignedValue() && C2->sge(C->abs())) {
1921 assert(isKnownNonNegative(SA) && "incorrectly determined non-negative");
1922 return getAddExpr(getSignExtendExpr(SA->getOperand(0), Ty, Depth + 1),
1923 getSignExtendExpr(SA->getOperand(1), Ty, Depth + 1),
1924 SCEV::FlagNSW, Depth + 1);
1925 }
1926
1927 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1928 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1929 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1930 //
1931 // Often address arithmetics contain expressions like
1932 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1933 // This transformation is useful while proving that such expressions are
1934 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1935 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1936 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1937 if (D != 0) {
1938 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1939 const SCEV *SResidual =
1941 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1942 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1943 Depth + 1);
1944 }
1945 }
1946 }
1947
1948 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1949 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1950 if (SM->hasNoUnsignedWrap()) {
1951 // If the multiply does not unsign overflow then we can, by definition,
1952 // commute the zero extension with the multiply operation.
1954 for (SCEVUse Op : SM->operands())
1955 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1956 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1957 }
1958
1959 // zext(2^K * (trunc X to iN)) to iM ->
1960 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1961 //
1962 // Proof:
1963 //
1964 // zext(2^K * (trunc X to iN)) to iM
1965 // = zext((trunc X to iN) << K) to iM
1966 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1967 // (because shl removes the top K bits)
1968 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1969 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1970 //
1971 const APInt *C;
1972 const SCEV *TruncRHS;
1973 if (match(SM,
1974 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1975 C->isPowerOf2()) {
1976 int NewTruncBits =
1977 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1978 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1979 return getMulExpr(
1980 getZeroExtendExpr(SM->getOperand(0), Ty),
1981 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1982 SCEV::FlagNUW, Depth + 1);
1983 }
1984 }
1985
1986 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1987 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1990 SmallVector<SCEVUse, 4> Operands;
1991 for (SCEVUse Operand : MinMax->operands())
1992 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1994 return getUMinExpr(Operands);
1995 return getUMaxExpr(Operands);
1996 }
1997
1998 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
2000 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
2001 SmallVector<SCEVUse, 4> Operands;
2002 for (SCEVUse Operand : MinMax->operands())
2003 Operands.push_back(getZeroExtendExpr(Operand, Ty));
2004 return getUMinExpr(Operands, /*Sequential*/ true);
2005 }
2006
2007 // The cast wasn't folded; create an explicit cast node.
2008 // Recompute the insert position, as it may have been invalidated.
2009 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2010 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
2011 Op, Ty);
2012 UniqueSCEVs.InsertNode(S, IP);
2013 S->computeAndSetCanonical(*this);
2014 registerUser(S, Op);
2015 return S;
2016}
2017
// getSignExtendExpr - memoizing entry point for sign-extending Op to Ty.
// Asserts that the conversion strictly widens Op to a SCEVable, non-pointer
// type, normalizes Ty with getEffectiveSCEVType, and returns the FoldCache
// entry for (scSignExtend, Op, Ty) when one exists. Otherwise the fold is
// computed by getSignExtendExprImpl and recorded with insertFoldCacheEntry.
// NOTE(review): the line just before insertFoldCacheEntry is not shown in this
// listing and may make the cache insertion conditional — confirm.
2018const SCEV *
2020 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2021 "This is not an extending conversion!");
2022 assert(isSCEVable(Ty) &&
2023 "This is not a conversion to a SCEVable type!");
2024 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2025 Ty = getEffectiveSCEVType(Ty);
2026
// Fast path: reuse a previously computed fold for this (kind, Op, Ty) key.
2027 FoldID ID(scSignExtend, Op, Ty);
2028 if (const SCEV *S = FoldCache.lookup(ID))
2029 return S;
2030
2031 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2033 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2034 return S;
2035}
2036
2038 unsigned Depth) {
// Worker behind getSignExtendExpr: sign-extends Op to Ty. It tries, in order:
// constant folding, collapsing sext(sext)/sext(zext), reusing an existing
// uniqued node, simplifying sext(trunc), distributing over adds, pushing the
// extension into affine addrecs whose no-wrap behavior can be established,
// rewriting as zext for provably positive operands, and distributing over
// smin/smax. An explicit SCEVSignExtendExpr node is created only when none of
// these applies, or immediately once Depth exceeds MaxCastDepth.
2039 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2040 "This is not an extending conversion!");
2041 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2042 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2043 Ty = getEffectiveSCEVType(Ty);
2044
2045 // Fold if the operand is constant.
2046 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2047 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2048
2049 // sext(sext(x)) --> sext(x)
2051 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2052
2053 // sext(zext(x)) --> zext(x)
2055 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2056
2057 // Before doing any expensive analysis, check to see if we've already
2058 // computed a SCEV for this Op and Ty.
2060 ID.AddInteger(scSignExtend);
2061 ID.AddPointer(Op);
2062 ID.AddPointer(Ty);
2063 void *IP = nullptr;
2064 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2065 // Limit recursion depth.
// Past MaxCastDepth, build the explicit cast node without trying any folds.
2066 if (Depth > MaxCastDepth) {
2067 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2068 Op, Ty);
2069 UniqueSCEVs.InsertNode(S, IP);
2070 S->computeAndSetCanonical(*this);
2071 registerUser(S, Op);
2072 return S;
2073 }
2074
2075 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2077 // It's possible the bits taken off by the truncate were all sign bits. If
2078 // so, we should be able to simplify this further.
2079 const SCEV *X = ST->getOperand();
2081 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2082 unsigned NewBits = getTypeSizeInBits(Ty);
2083 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2084 CR.sextOrTrunc(NewBits)))
2085 return getTruncateOrSignExtend(X, Ty, Depth);
2086 }
2087
2088 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2089 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2090 if (SA->hasNoSignedWrap()) {
2091 // If the addition does not overflow in the signed sense then we can, by
2092 // definition, commute the sign extension with the addition operation.
2094 for (SCEVUse Op : SA->operands())
2095 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2096 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2097 }
2098
2099 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2100 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2101 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2102 //
2103 // For instance, this will bring two seemingly different expressions:
2104 // 1 + sext(5 + 20 * %x + 24 * %y) and
2105 // sext(6 + 20 * %x + 24 * %y)
2106 // to the same form:
2107 // 2 + sext(4 + 20 * %x + 24 * %y)
2108 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2109 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2110 if (D != 0) {
2111 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2112 const SCEV *SResidual =
2114 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2115 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2116 Depth + 1);
2117 }
2118 }
2119 }
2120 // If the input value is a chrec scev, and we can prove that the value
2121 // did not overflow the old, smaller, value, we can sign extend all of the
2122 // operands (often constants). This allows analysis of something like
2123 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2125 if (AR->isAffine()) {
2126 const SCEV *Start = AR->getStart();
2127 const SCEV *Step = AR->getStepRecurrence(*this);
2128 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2129 const Loop *L = AR->getLoop();
2130
2131 // If we have special knowledge that this addrec won't overflow,
2132 // we don't need to do any further analysis.
2133 if (AR->hasNoSignedWrap()) {
2134 Start =
2136 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2137 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2138 }
2139
2140 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2141 // Note that this serves two purposes: It filters out loops that are
2142 // simply not analyzable, and it covers the case where this code is
2143 // being called from within backedge-taken count analysis, such that
2144 // attempting to ask for the backedge-taken count would likely result
2145 // in infinite recursion. In the latter case, the analysis code will
2146 // cope with a conservative value, and it will take care to purge
2147 // that value once it has finished.
2148 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2149 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2150 // Manually compute the final value for AR, checking for
2151 // overflow.
2152
2153 // Check whether the backedge-taken count can be losslessly casted to
2154 // the addrec's type. The count is always unsigned.
2155 const SCEV *CastedMaxBECount =
2156 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2157 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2158 CastedMaxBECount, MaxBECount->getType(), Depth);
2159 if (MaxBECount == RecastedMaxBECount) {
// WideTy has twice AR's bit width; the end value is compared in it below.
2160 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2161 // Check whether Start+Step*MaxBECount has no signed overflow.
2162 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2164 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2166 Depth + 1),
2167 WideTy, Depth + 1);
2168 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2169 const SCEV *WideMaxBECount =
2170 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2171 const SCEV *OperandExtendedAdd =
2172 getAddExpr(WideStart,
2173 getMulExpr(WideMaxBECount,
2174 getSignExtendExpr(Step, WideTy, Depth + 1),
2177 if (SAdd == OperandExtendedAdd) {
2178 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2179 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2180 // Return the expression with the addrec on the outside.
2181 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2182 Depth + 1);
2183 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2184 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2185 }
2186 // Similar to above, only this time treat the step value as unsigned.
2187 // This covers loops that count up with an unsigned step.
2188 OperandExtendedAdd =
2189 getAddExpr(WideStart,
2190 getMulExpr(WideMaxBECount,
2191 getZeroExtendExpr(Step, WideTy, Depth + 1),
2194 if (SAdd == OperandExtendedAdd) {
2195 // If AR wraps around then
2196 //
2197 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2198 // => SAdd != OperandExtendedAdd
2199 //
2200 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2201 // (SAdd == OperandExtendedAdd => AR is NW)
2202
2203 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2204
2205 // Return the expression with the addrec on the outside.
2206 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2207 Depth + 1);
2208 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2209 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2210 }
2211 }
2212 }
2213
// Fall back to an induction-based no-signed-wrap proof; whatever flags it
// establishes are recorded on AR before re-checking hasNoSignedWrap.
2214 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2215 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2216 if (AR->hasNoSignedWrap()) {
2217 // Same as nsw case above - duplicated here to avoid a compile time
2218 // issue. It's not clear that the order of checks does matter, but
2219 // it's one of two possible causes for a change which was
2220 // reverted. Be conservative for the moment.
2221 Start =
2223 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2224 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2225 }
2226
2227 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2228 // if D + (C - D + Step * n) could be proven to not signed wrap
2229 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2230 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2231 const APInt &C = SC->getAPInt();
2232 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2233 if (D != 0) {
2234 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2235 const SCEV *SResidual =
2236 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2237 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2238 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2239 Depth + 1);
2240 }
2241 }
2242
// If NSW can be proven by varying the start value, record FlagNSW on AR and
// push the sign extension into the start and step.
2243 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2244 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2245 Start =
2247 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2248 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2249 }
2250 }
2251
2252 // If the input value is provably positive and we could not simplify
2253 // away the sext, build a zext instead.
2255 return getZeroExtendExpr(Op, Ty, Depth + 1);
2256
2257 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2258 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2261 SmallVector<SCEVUse, 4> Operands;
2262 for (SCEVUse Operand : MinMax->operands())
2263 Operands.push_back(getSignExtendExpr(Operand, Ty));
2265 return getSMinExpr(Operands);
2266 return getSMaxExpr(Operands);
2267 }
2268
2269 // The cast wasn't folded; create an explicit cast node.
2270 // Recompute the insert position, as it may have been invalidated.
2271 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2272 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2273 Op, Ty);
2274 UniqueSCEVs.InsertNode(S, IP);
2275 S->computeAndSetCanonical(*this);
2276 registerUser(S, Op);
2277 return S;
2278}
2279
2281 Type *Ty) {
// Dispatch a cast of Op to Ty to the builder matching Kind. Only truncate,
// zero-extend, sign-extend and ptrtoint kinds are supported; any other Kind
// is a programming error.
2282 switch (Kind) {
2283 case scTruncate:
2284 return getTruncateExpr(Op, Ty);
2285 case scZeroExtend:
2286 return getZeroExtendExpr(Op, Ty);
2287 case scSignExtend:
2288 return getSignExtendExpr(Op, Ty);
2289 case scPtrToInt:
2290 return getPtrToIntExpr(Op, Ty);
2291 default:
2292 llvm_unreachable("Not a SCEV cast expression!");
2293 }
2294}
2295
2296/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2297/// unspecified bits out to the given type.
///
/// Strategies are tried in this order:
///  - negative constants are sign-extended;
///  - a truncate is peeled, recursing if its source is still narrower than
///    Ty and otherwise truncating (or no-op) that source to Ty;
///  - a zext that folds to something other than a SCEVZeroExtendExpr is used;
///  - a sext that folds to something other than a SCEVSignExtendExpr is used;
///  - for an addrec, every operand is any-extended and the result is FlagNW;
///  - for an smax, the sext form is returned;
///  - otherwise the zext form is returned.
2299 Type *Ty) {
2300 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2301 "This is not an extending conversion!");
2302 assert(isSCEVable(Ty) &&
2303 "This is not a conversion to a SCEVable type!");
2304 Ty = getEffectiveSCEVType(Ty);
2305
2306 // Sign-extend negative constants.
2307 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2308 if (SC->getAPInt().isNegative())
2309 return getSignExtendExpr(Op, Ty);
2310
2311 // Peel off a truncate cast.
2313 const SCEV *NewOp = T->getOperand();
2314 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2315 return getAnyExtendExpr(NewOp, Ty);
2316 return getTruncateOrNoop(NewOp, Ty);
2317 }
2318
2319 // Next try a zext cast. If the cast is folded, use it.
2320 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2321 if (!isa<SCEVZeroExtendExpr>(ZExt))
2322 return ZExt;
2323
2324 // Next try a sext cast. If the cast is folded, use it.
2325 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2326 if (!isa<SCEVSignExtendExpr>(SExt))
2327 return SExt;
2328
2329 // Force the cast to be folded into the operands of an addrec.
2330 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2332 for (const SCEV *Op : AR->operands())
2333 Ops.push_back(getAnyExtendExpr(Op, Ty));
2334 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2335 }
2336
2337 // If the expression is obviously signed, use the sext cast value.
2338 if (isa<SCEVSMaxExpr>(Op))
2339 return SExt;
2340
2341 // Absent any other information, use the zext cast value.
2342 return ZExt;
2343}
2344
2345/// Process the given Ops list, which is a list of operands to be added under
2346/// the given scale, and update the given map. This is a helper function for
2347/// getAddRecExpr. As an example of what it does, given a sequence of operands
2348/// that would form an add expression like this:
2349///
2350/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2351///
2352/// where A and B are constants, update the map with these values:
2353///
2354/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2355///
2356/// and add 13 + A*B*29 to AccumulatedConstant.
2357/// This will allow getAddRecExpr to produce this:
2358///
2359/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2360///
2361/// This form often exposes folding opportunities that are hidden in
2362/// the original operand list.
2363///
2364/// Return true iff it appears that any interesting folding opportunities
2365/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2366/// the common case where no interesting opportunities are present, and
2367/// is also used as a check to avoid infinite recursion.
///
/// Each distinct key is appended to NewOps the first time it is inserted
/// into M, so NewOps lists the keys in first-seen order. A SCEVMulExpr whose
/// first operand is a constant contributes its remaining factors under
/// Scale * that constant. If its only other factor is a SCEVAddExpr, that add
/// is flattened recursively.
///
/// NOTE(review): the leading-constant loop does not bounds-check i, so it
/// relies on Ops containing at least one non-constant operand — confirm that
/// all callers guarantee this.
2370 APInt &AccumulatedConstant,
2372 const APInt &Scale,
2373 ScalarEvolution &SE) {
2374 bool Interesting = false;
2375
2376 // Iterate over the add operands. They are sorted, with constants first.
2377 unsigned i = 0;
2378 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2379 ++i;
2380 // Pull a buried constant out to the outside.
// A constant that needs scaling, must be merged with an earlier constant,
// or is zero signals a folding opportunity.
2381 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2382 Interesting = true;
2383 AccumulatedConstant += Scale * C->getAPInt();
2384 }
2385
2386 // Next comes everything else. We're especially interested in multiplies
2387 // here, but they're in the middle, so just visit the rest with one loop.
2388 for (; i != Ops.size(); ++i) {
2390 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2391 APInt NewScale =
2392 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2393 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2394 // A multiplication of a constant with another add; recurse.
2395 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2396 Interesting |= CollectAddOperandsWithScales(
2397 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2398 } else {
2399 // A multiplication of a constant with some other value. Update
2400 // the map.
2401 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2402 const SCEV *Key = SE.getMulExpr(MulOps);
2403 auto Pair = M.insert({Key, NewScale});
2404 if (Pair.second) {
2405 NewOps.push_back(Pair.first->first);
2406 } else {
2407 Pair.first->second += NewScale;
2408 // The map already had an entry for this value, which may indicate
2409 // a folding opportunity.
2410 Interesting = true;
2411 }
2412 }
2413 } else {
2414 // An ordinary operand. Update the map.
2415 auto Pair = M.insert({Ops[i], Scale});
2416 if (Pair.second) {
2417 NewOps.push_back(Pair.first->first);
2418 } else {
2419 Pair.first->second += Scale;
2420 // The map already had an entry for this value, which may indicate
2421 // a folding opportunity.
2422 Interesting = true;
2423 }
2424 }
2425 }
2426
2427 return Interesting;
2428}
2429
2431 const SCEV *LHS, const SCEV *RHS,
2432 const Instruction *CtxI) {
// Returns true if LHS BinOp RHS is known not to overflow, in the signed or
// unsigned sense selected by Signed. BinOp must be Add, Sub or Mul.
//
// Primary check: widen to twice the bit width and test whether
// ext(LHS op RHS) and ext(LHS) op ext(RHS) fold to the same SCEV.
//
// Fallback (add/sub only, and only when CtxI is given and RHS is a constant
// C): use isKnownPredicateAt at CtxI to prove that LHS is at least MIN + |C|
// when the operation could overflow downward, or at most MAX - |C| when it
// could overflow upward.
2434 unsigned);
2435 switch (BinOp) {
2436 default:
2437 llvm_unreachable("Unsupported binary op");
2438 case Instruction::Add:
2440 break;
2441 case Instruction::Sub:
2443 break;
2444 case Instruction::Mul:
2446 break;
2447 }
2448
2449 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2452
2453 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2454 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2455 auto *WideTy =
2456 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2457
2458 const SCEV *A = (this->*Extension)(
2459 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2460 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2461 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2462 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2463 if (A == B)
2464 return true;
2465 // Can we use context to prove the fact we need?
2466 if (!CtxI)
2467 return false;
2468 // TODO: Support mul.
2469 if (BinOp == Instruction::Mul)
2470 return false;
2471 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2472 // TODO: Lift this limitation.
2473 if (!RHSC)
2474 return false;
2475 APInt C = RHSC->getAPInt();
2476 unsigned NumBits = C.getBitWidth();
2477 bool IsSub = (BinOp == Instruction::Sub);
2478 bool IsNegativeConst = (Signed && C.isNegative());
2479 // Compute the direction and magnitude by which we need to check overflow.
// Subtracting a non-negative constant, or adding a negative one, moves
// toward MIN; the other combinations move toward MAX.
2480 bool OverflowDown = IsSub ^ IsNegativeConst;
2481 APInt Magnitude = C;
2482 if (IsNegativeConst) {
2483 if (C == APInt::getSignedMinValue(NumBits))
2484 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2485 // want to deal with that.
2486 return false;
2487 Magnitude = -C;
2488 }
2489
2491 if (OverflowDown) {
2492 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2493 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2494 : APInt::getMinValue(NumBits);
2495 APInt Limit = Min + Magnitude;
2496 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2497 } else {
2498 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2499 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2500 : APInt::getMaxValue(NumBits);
2501 APInt Limit = Max - Magnitude;
2502 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2503 }
2504}
2505
// Try to deduce nuw/nsw flags for an add, sub or mul OBO beyond those the
// instruction already carries, by running an overflow check on the SCEVs of
// its two operands with context instruction CtxI.
// Returns the combined flags when at least one new flag was deduced. Returns
// std::nullopt when OBO already has both flags, has another opcode, or
// nothing new could be proven.
// NOTE(review): the lines that seed Flags from OBO's existing flags, compute
// CtxI, and make the overflow-check calls are not shown in this listing;
// presumably the calls go to willNotOverflow — confirm.
2506std::optional<SCEV::NoWrapFlags>
2508 const OverflowingBinaryOperator *OBO) {
2509 // Both nuw and nsw are already set; there is nothing left to strengthen.
2510 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2511 return std::nullopt;
2512
2513 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2514
2515 if (OBO->hasNoUnsignedWrap())
2517 if (OBO->hasNoSignedWrap())
2519
// Set once at least one flag beyond OBO's own has been proven.
2520 bool Deduced = false;
2521
// Only add, sub and mul are analyzed.
2522 if (OBO->getOpcode() != Instruction::Add &&
2523 OBO->getOpcode() != Instruction::Sub &&
2524 OBO->getOpcode() != Instruction::Mul)
2525 return std::nullopt;
2526
2527 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2528 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2529
2530 const Instruction *CtxI =
2532 if (!OBO->hasNoUnsignedWrap() &&
2534 /* Signed */ false, LHS, RHS, CtxI)) {
2536 Deduced = true;
2537 }
2538
2539 if (!OBO->hasNoSignedWrap() &&
2541 /* Signed */ true, LHS, RHS, CtxI)) {
2543 Deduced = true;
2544 }
2545
2546 if (Deduced)
2547 return Flags;
2548 return std::nullopt;
2549}
2550
2551// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2552// `Flags' as can't-wrap behavior. Infer a more aggressive set of
2553// can't-overflow flags for the operation if possible.
//
// Rules applied:
//  - nsw plus all-non-negative operands implies nuw;
//  - for a binary add/mul with a constant first operand, flags are added
//    when the range of the other operand lies in the corresponding no-wrap
//    region;
//  - <0,+,non-negative><nw> is nuw;
//  - (udiv X, Y) * Y is nuw in either operand order.
2557 SCEV::NoWrapFlags Flags) {
2558 using namespace std::placeholders;
2559
2560 using OBO = OverflowingBinaryOperator;
2561
2562 bool CanAnalyze =
2564 (void)CanAnalyze;
2565 assert(CanAnalyze && "don't call from other places!");
2566
2567 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2568 SCEV::NoWrapFlags SignOrUnsignWrap =
2569 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2570
2571 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2572 auto IsKnownNonNegative = [&](SCEVUse U) {
2573 return SE->isKnownNonNegative(U);
2574 };
2575
2576 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2577 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2578
// Recompute the nuw/nsw subset, since Flags may have just been strengthened.
2579 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2580
// Binary add/mul with a constant first operand C: derive a no-wrap region
// for the other operand from Opcode and C, and add each flag whose region
// contains that operand's range.
2581 if (SignOrUnsignWrap != SignOrUnsignMask &&
2582 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2583 isa<SCEVConstant>(Ops[0])) {
2584
2585 auto Opcode = [&] {
2586 switch (Type) {
2587 case scAddExpr:
2588 return Instruction::Add;
2589 case scMulExpr:
2590 return Instruction::Mul;
2591 default:
2592 llvm_unreachable("Unexpected SCEV op.");
2593 }
2594 }();
2595
2596 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2597
2598 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't signed overflow.
2599 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2601 Opcode, C, OBO::NoSignedWrap);
2602 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2604 }
2605
2606 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsigned overflow.
2607 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2609 Opcode, C, OBO::NoUnsignedWrap);
2610 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2612 }
2613 }
2614
2615 // <0,+,nonnegative><nw> is also nuw
2616 // TODO: Add corresponding nsw case
2618 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2619 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2621
2622 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2624 Ops.size() == 2) {
2625 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2626 if (UDiv->getOperand(1) == Ops[1])
2628 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2629 if (UDiv->getOperand(1) == Ops[0])
2631 }
2632
2633 return Flags;
2634}
2635
// S is usable on entry to L when it is invariant in L and properly dominates
// L's header block.
2637 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2638}
2639
2640/// Get a canonical add expression, or something simpler if possible.
2642 SCEV::NoWrapFlags OrigFlags,
2643 unsigned Depth) {
2644 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2645 "only nuw or nsw allowed");
2646 assert(!Ops.empty() && "Cannot get empty add!");
2647 if (Ops.size() == 1) return Ops[0];
2648#ifndef NDEBUG
2649 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2650 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2651 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2652 "SCEVAddExpr operand types don't match!");
2653 unsigned NumPtrs = count_if(
2654 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2655 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2656#endif
2657
2658 const SCEV *Folded = constantFoldAndGroupOps(
2659 *this, LI, DT, Ops,
2660 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2661 [](const APInt &C) { return C.isZero(); }, // identity
2662 [](const APInt &C) { return false; }); // absorber
2663 if (Folded)
2664 return Folded;
2665
2666 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2667
2668 // Delay expensive flag strengthening until necessary.
2669 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2670 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2671 };
2672
2673 // Limit recursion calls depth.
2675 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2676
2677 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2678 // Don't strengthen flags if we have no new information.
2679 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2680 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2681 Add->setNoWrapFlags(ComputeFlags(Ops));
2682 return S;
2683 }
2684
2685 // Okay, check to see if the same value occurs in the operand list more than
2686 // once. If so, merge them together into an multiply expression. Since we
2687 // sorted the list, these values are required to be adjacent.
2688 Type *Ty = Ops[0]->getType();
2689 bool FoundMatch = false;
2690 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2691 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2692 // Scan ahead to count how many equal operands there are.
2693 unsigned Count = 2;
2694 while (i+Count != e && Ops[i+Count] == Ops[i])
2695 ++Count;
2696 // Merge the values into a multiply.
2697 SCEVUse Scale = getConstant(Ty, Count);
2698 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2699 if (Ops.size() == Count)
2700 return Mul;
2701 Ops[i] = Mul;
2702 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2703 --i; e -= Count - 1;
2704 FoundMatch = true;
2705 }
2706 if (FoundMatch)
2707 return getAddExpr(Ops, OrigFlags, Depth + 1);
2708
2709 // Check for truncates. If all the operands are truncated from the same
2710 // type, see if factoring out the truncate would permit the result to be
2711 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2712 // if the contents of the resulting outer trunc fold to something simple.
2713 auto FindTruncSrcType = [&]() -> Type * {
2714 // We're ultimately looking to fold an addrec of truncs and muls of only
2715 // constants and truncs, so if we find any other types of SCEV
2716 // as operands of the addrec then we bail and return nullptr here.
2717 // Otherwise, we return the type of the operand of a trunc that we find.
2718 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2719 return T->getOperand()->getType();
2720 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2721 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2722 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2723 return T->getOperand()->getType();
2724 }
2725 return nullptr;
2726 };
2727 if (auto *SrcType = FindTruncSrcType()) {
2728 SmallVector<SCEVUse, 8> LargeOps;
2729 bool Ok = true;
2730 // Check all the operands to see if they can be represented in the
2731 // source type of the truncate.
2732 for (const SCEV *Op : Ops) {
2734 if (T->getOperand()->getType() != SrcType) {
2735 Ok = false;
2736 break;
2737 }
2738 LargeOps.push_back(T->getOperand());
2739 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2740 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2741 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2742 SmallVector<SCEVUse, 8> LargeMulOps;
2743 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2744 if (const SCEVTruncateExpr *T =
2745 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2746 if (T->getOperand()->getType() != SrcType) {
2747 Ok = false;
2748 break;
2749 }
2750 LargeMulOps.push_back(T->getOperand());
2751 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2752 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2753 } else {
2754 Ok = false;
2755 break;
2756 }
2757 }
2758 if (Ok)
2759 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2760 } else {
2761 Ok = false;
2762 break;
2763 }
2764 }
2765 if (Ok) {
2766 // Evaluate the expression in the larger type.
2767 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2768 // If it folds to something simple, use it. Otherwise, don't.
2769 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2770 return getTruncateExpr(Fold, Ty);
2771 }
2772 }
2773
2774 if (Ops.size() == 2) {
2775 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2776 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2777 // C1).
2778 const SCEV *A = Ops[0];
2779 const SCEV *B = Ops[1];
2780 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2781 auto *C = dyn_cast<SCEVConstant>(A);
2782 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2783 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2784 auto C2 = C->getAPInt();
2785 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2786
2787 APInt ConstAdd = C1 + C2;
2788 auto AddFlags = AddExpr->getNoWrapFlags();
2789 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2791 ConstAdd.ule(C1)) {
2792 PreservedFlags =
2794 }
2795
2796 // Adding a constant with the same sign and small magnitude is NSW, if the
2797 // original AddExpr was NSW.
2799 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2800 ConstAdd.abs().ule(C1.abs())) {
2801 PreservedFlags =
2803 }
2804
2805 if (PreservedFlags != SCEV::FlagAnyWrap) {
2806 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2807 NewOps[0] = getConstant(ConstAdd);
2808 return getAddExpr(NewOps, PreservedFlags);
2809 }
2810 }
2811
2812 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2813 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2814 const SCEVAddExpr *InnerAdd;
2815 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2816 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2817 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2818 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2819 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2821 SCEV::FlagNUW)) {
2822 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2823 }
2824 }
2825 }
2826
2827 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2828 const SCEV *Y;
2829 if (Ops.size() == 2 &&
2830 match(Ops[0],
2832 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2833 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2834
2835 // Skip past any other cast SCEVs.
2836 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2837 ++Idx;
2838
2839 // If there are add operands they would be next.
2840 if (Idx < Ops.size()) {
2841 bool DeletedAdd = false;
2842 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2843 // common NUW flag for expression after inlining. Other flags cannot be
2844 // preserved, because they may depend on the original order of operations.
2845 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2846 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2847 if (Ops.size() > AddOpsInlineThreshold ||
2848 Add->getNumOperands() > AddOpsInlineThreshold)
2849 break;
2850 // If we have an add, expand the add operands onto the end of the operands
2851 // list.
2852 Ops.erase(Ops.begin()+Idx);
2853 append_range(Ops, Add->operands());
2854 DeletedAdd = true;
2855 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2856 }
2857
2858 // If we deleted at least one add, we added operands to the end of the list,
2859 // and they are not necessarily sorted. Recurse to resort and resimplify
2860 // any operands we just acquired.
2861 if (DeletedAdd)
2862 return getAddExpr(Ops, CommonFlags, Depth + 1);
2863 }
2864
2865 // Skip over the add expression until we get to a multiply.
2866 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2867 ++Idx;
2868
2869 // Check to see if there are any folding opportunities present with
2870 // operands multiplied by constant values.
2871 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2875 APInt AccumulatedConstant(BitWidth, 0);
2876 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2877 Ops, APInt(BitWidth, 1), *this)) {
2878 struct APIntCompare {
2879 bool operator()(const APInt &LHS, const APInt &RHS) const {
2880 return LHS.ult(RHS);
2881 }
2882 };
2883
2884 // Some interesting folding opportunity is present, so its worthwhile to
2885 // re-generate the operands list. Group the operands by constant scale,
2886 // to avoid multiplying by the same constant scale multiple times.
2887 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2888 for (const SCEV *NewOp : NewOps)
2889 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2890 // Re-generate the operands list.
2891 Ops.clear();
2892 if (AccumulatedConstant != 0)
2893 Ops.push_back(getConstant(AccumulatedConstant));
2894 for (auto &MulOp : MulOpLists) {
2895 if (MulOp.first == 1) {
2896 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2897 } else if (MulOp.first != 0) {
2898 Ops.push_back(getMulExpr(
2899 getConstant(MulOp.first),
2900 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2901 SCEV::FlagAnyWrap, Depth + 1));
2902 }
2903 }
2904 if (Ops.empty())
2905 return getZero(Ty);
2906 if (Ops.size() == 1)
2907 return Ops[0];
2908 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2909 }
2910 }
2911
2912 // If we are adding something to a multiply expression, make sure the
2913 // something is not already an operand of the multiply. If so, merge it into
2914 // the multiply.
2915 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2916 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2917 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2918 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2919 if (isa<SCEVConstant>(MulOpSCEV))
2920 continue;
2921 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2922 if (MulOpSCEV == Ops[AddOp]) {
2923 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2924 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2925 if (Mul->getNumOperands() != 2) {
2926 // If the multiply has more than two operands, we must get the
2927 // Y*Z term.
2928 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2929 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2930 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2931 }
2932 const SCEV *AddOne =
2933 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2934 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2936 if (Ops.size() == 2) return OuterMul;
2937 if (AddOp < Idx) {
2938 Ops.erase(Ops.begin()+AddOp);
2939 Ops.erase(Ops.begin()+Idx-1);
2940 } else {
2941 Ops.erase(Ops.begin()+Idx);
2942 Ops.erase(Ops.begin()+AddOp-1);
2943 }
2944 Ops.push_back(OuterMul);
2945 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2946 }
2947
2948 // Check this multiply against other multiplies being added together.
2949 for (unsigned OtherMulIdx = Idx+1;
2950 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2951 ++OtherMulIdx) {
2952 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2953 // If MulOp occurs in OtherMul, we can fold the two multiplies
2954 // together.
2955 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2956 OMulOp != e; ++OMulOp)
2957 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2958 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2959 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2960 if (Mul->getNumOperands() != 2) {
2961 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2962 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2963 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2964 }
2965 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2966 if (OtherMul->getNumOperands() != 2) {
2968 OtherMul->operands().take_front(OMulOp));
2969 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2970 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2971 }
2972 const SCEV *InnerMulSum =
2973 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2974 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2976 if (Ops.size() == 2) return OuterMul;
2977 Ops.erase(Ops.begin()+Idx);
2978 Ops.erase(Ops.begin()+OtherMulIdx-1);
2979 Ops.push_back(OuterMul);
2980 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2981 }
2982 }
2983 }
2984 }
2985
2986 // If there are any add recurrences in the operands list, see if any other
2987 // added values are loop invariant. If so, we can fold them into the
2988 // recurrence.
2989 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2990 ++Idx;
2991
2992 // Scan over all recurrences, trying to fold loop invariants into them.
2993 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2994 // Scan all of the other operands to this add and add them to the vector if
2995 // they are loop invariant w.r.t. the recurrence.
2997 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2998 const Loop *AddRecLoop = AddRec->getLoop();
2999 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3000 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3001 LIOps.push_back(Ops[i]);
3002 Ops.erase(Ops.begin()+i);
3003 --i; --e;
3004 }
3005
3006 // If we found some loop invariants, fold them into the recurrence.
3007 if (!LIOps.empty()) {
3008 // Compute nowrap flags for the addition of the loop-invariant ops and
3009 // the addrec. Temporarily push it as an operand for that purpose. These
3010 // flags are valid in the scope of the addrec only.
3011 LIOps.push_back(AddRec);
3012 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3013 LIOps.pop_back();
3014
3015 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3016 LIOps.push_back(AddRec->getStart());
3017
3018 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3019
3020 // It is not in general safe to propagate flags valid on an add within
3021 // the addrec scope to one outside it. We must prove that the inner
3022 // scope is guaranteed to execute if the outer one does to be able to
3023 // safely propagate. We know the program is undefined if poison is
3024 // produced on the inner scoped addrec. We also know that *for this use*
3025 // the outer scoped add can't overflow (because of the flags we just
3026 // computed for the inner scoped add) without the program being undefined.
3027 // Proving that entry to the outer scope neccesitates entry to the inner
3028 // scope, thus proves the program undefined if the flags would be violated
3029 // in the outer scope.
3030 SCEV::NoWrapFlags AddFlags = Flags;
3031 if (AddFlags != SCEV::FlagAnyWrap) {
3032 auto *DefI = getDefiningScopeBound(LIOps);
3033 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3034 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3035 AddFlags = SCEV::FlagAnyWrap;
3036 }
3037 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3038
3039 // Build the new addrec. Propagate the NUW and NSW flags if both the
3040 // outer add and the inner addrec are guaranteed to have no overflow.
3041 // Always propagate NW.
3042 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3043 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3044
3045 // If all of the other operands were loop invariant, we are done.
3046 if (Ops.size() == 1) return NewRec;
3047
3048 // Otherwise, add the folded AddRec by the non-invariant parts.
3049 for (unsigned i = 0;; ++i)
3050 if (Ops[i] == AddRec) {
3051 Ops[i] = NewRec;
3052 break;
3053 }
3054 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3055 }
3056
3057 // Okay, if there weren't any loop invariants to be folded, check to see if
3058 // there are multiple AddRec's with the same loop induction variable being
3059 // added together. If so, we can fold them.
3060 for (unsigned OtherIdx = Idx+1;
3061 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3062 ++OtherIdx) {
3063 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3064 // so that the 1st found AddRecExpr is dominated by all others.
3065 assert(DT.dominates(
3066 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3067 AddRec->getLoop()->getHeader()) &&
3068 "AddRecExprs are not sorted in reverse dominance order?");
3069 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3070 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3071 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3072 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3073 ++OtherIdx) {
3074 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3075 if (OtherAddRec->getLoop() == AddRecLoop) {
3076 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3077 i != e; ++i) {
3078 if (i >= AddRecOps.size()) {
3079 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3080 break;
3081 }
3082 AddRecOps[i] =
3083 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3085 }
3086 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3087 }
3088 }
3089 // Step size has changed, so we cannot guarantee no self-wraparound.
3090 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3091 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3092 }
3093 }
3094
3095 // Otherwise couldn't fold anything into this recurrence. Move onto the
3096 // next one.
3097 }
3098
3099 // Okay, it looks like we really DO need an add expr. Check to see if we
3100 // already have one, otherwise create a new one.
3101 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3102}
3103
3104const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3105 SCEV::NoWrapFlags Flags) {
3107 ID.AddInteger(scAddExpr);
3108 for (const SCEV *Op : Ops)
3109 ID.AddPointer(Op);
3110 void *IP = nullptr;
3111 SCEVAddExpr *S =
3112 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3113 if (!S) {
3114 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3116 S = new (SCEVAllocator)
3117 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3118 UniqueSCEVs.InsertNode(S, IP);
3119 S->computeAndSetCanonical(*this);
3120 registerUser(S, Ops);
3121 }
3122 S->setNoWrapFlags(Flags);
3123 return S;
3124}
3125
3126const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3127 const Loop *L,
3128 SCEV::NoWrapFlags Flags) {
3129 FoldingSetNodeID ID;
3130 ID.AddInteger(scAddRecExpr);
3131 for (const SCEV *Op : Ops)
3132 ID.AddPointer(Op);
3133 ID.AddPointer(L);
3134 void *IP = nullptr;
3135 SCEVAddRecExpr *S =
3136 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3137 if (!S) {
3138 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3140 S = new (SCEVAllocator)
3141 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3142 UniqueSCEVs.InsertNode(S, IP);
3143 S->computeAndSetCanonical(*this);
3144 LoopUsers[L].push_back(S);
3145 registerUser(S, Ops);
3146 }
3147 setNoWrapFlags(S, Flags);
3148 return S;
3149}
3150
3151const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3152 SCEV::NoWrapFlags Flags) {
3153 FoldingSetNodeID ID;
3154 ID.AddInteger(scMulExpr);
3155 for (const SCEV *Op : Ops)
3156 ID.AddPointer(Op);
3157 void *IP = nullptr;
3158 SCEVMulExpr *S =
3159 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3160 if (!S) {
3161 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3163 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3164 O, Ops.size());
3165 UniqueSCEVs.InsertNode(S, IP);
3166 S->computeAndSetCanonical(*this);
3167 registerUser(S, Ops);
3168 }
3169 S->setNoWrapFlags(Flags);
3170 return S;
3171}
3172
/// Multiply \p i by \p j modulo 2^64. If the exact product does not fit in 64
/// bits, \p Overflow is set to true; it is never cleared.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t k = i*j;
  // Dividing back recovers i exactly unless the product wrapped. j == 0 and
  // j == 1 can never wrap (and j == 0 must not be used as a divisor).
  if (j > 1 && k / j != i) Overflow = true;
  return k;
}

/// Compute the result of "n choose k", the binomial coefficient. C(n, k) is 0
/// when k > n (including n == 0). If an intermediate computation overflows,
/// Overflow will be set and the return will be garbage. Overflow is not
/// cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //   n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At step i we multiply by the i-th factor of the numerator, n-(i-1), and
  // then divide by i. Absent overflow, the value before the division is
  // C(n, i-1) * (n-i+1) == C(n, i) * i, so the division is always exact, and
  // dividing eagerly keeps the intermediate values small. However, we can
  // still overflow even when the final result would fit.

  // Test k > n first so that C(0, k) with k > 0 correctly yields 0.
  if (k > n) return 0;
  if (k == 0 || k == n) return 1;

  // C(n, k) == C(n, n-k); use the smaller k to do fewer steps.
  if (k > n/2)
    k = n-k;

  uint64_t r = 1;
  for (uint64_t i = 1; i <= k; ++i) {
    r = umul_ov(r, n-(i-1), Overflow);
    r /= i;
  }
  return r;
}
3204
3205/// Determine if any of the operands in this SCEV are a constant or if
3206/// any of the add or multiply expressions in this SCEV contain a constant.
3207static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3208 struct FindConstantInAddMulChain {
3209 bool FoundConstant = false;
3210
3211 bool follow(const SCEV *S) {
3212 FoundConstant |= isa<SCEVConstant>(S);
3213 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3214 }
3215
3216 bool isDone() const {
3217 return FoundConstant;
3218 }
3219 };
3220
3221 FindConstantInAddMulChain F;
3223 ST.visitAll(StartExpr);
3224 return F.FoundConstant;
3225}
3226
3227/// Get a canonical multiply expression, or something simpler if possible.
3229 SCEV::NoWrapFlags OrigFlags,
3230 unsigned Depth) {
3231 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3232 "only nuw or nsw allowed");
3233 assert(!Ops.empty() && "Cannot get empty mul!");
3234 if (Ops.size() == 1) return Ops[0];
3235#ifndef NDEBUG
3236 Type *ETy = Ops[0]->getType();
3237 assert(!ETy->isPointerTy());
3238 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3239 assert(Ops[i]->getType() == ETy &&
3240 "SCEVMulExpr operand types don't match!");
3241#endif
3242
3243 const SCEV *Folded = constantFoldAndGroupOps(
3244 *this, LI, DT, Ops,
3245 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3246 [](const APInt &C) { return C.isOne(); }, // identity
3247 [](const APInt &C) { return C.isZero(); }); // absorber
3248 if (Folded)
3249 return Folded;
3250
3251 // Delay expensive flag strengthening until necessary.
3252 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3253 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3254 };
3255
3256 // Limit recursion calls depth.
3258 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3259
3260 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3261 // Don't strengthen flags if we have no new information.
3262 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3263 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3264 Mul->setNoWrapFlags(ComputeFlags(Ops));
3265 return S;
3266 }
3267
3268 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3269 if (Ops.size() == 2) {
3270 // C1*(C2+V) -> C1*C2 + C1*V
3271 // If any of Add's ops are Adds or Muls with a constant, apply this
3272 // transformation as well.
3273 //
3274 // TODO: There are some cases where this transformation is not
3275 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3276 // this transformation should be narrowed down.
3277 const SCEV *Op0, *Op1;
3278 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3280 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3281 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3282 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3283 }
3284
3285 if (Ops[0]->isAllOnesValue()) {
3286 // If we have a mul by -1 of an add, try distributing the -1 among the
3287 // add operands.
3288 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3290 bool AnyFolded = false;
3291 for (const SCEV *AddOp : Add->operands()) {
3292 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
3294 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3295 NewOps.push_back(Mul);
3296 }
3297 if (AnyFolded)
3298 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3299 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3300 // Negation preserves a recurrence's no self-wrap property.
3301 SmallVector<SCEVUse, 4> Operands;
3302 for (const SCEV *AddRecOp : AddRec->operands())
3303 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3304 SCEV::FlagAnyWrap, Depth + 1));
3305 // Let M be the minimum representable signed value. AddRec with nsw
3306 // multiplied by -1 can have signed overflow if and only if it takes a
3307 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3308 // maximum signed value. In all other cases signed overflow is
3309 // impossible.
3310 auto FlagsMask = SCEV::FlagNW;
3311 if (AddRec->hasNoSignedWrap()) {
3312 auto MinInt =
3313 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3314 if (getSignedRangeMin(AddRec) != MinInt)
3315 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3316 }
3317 return getAddRecExpr(Operands, AddRec->getLoop(),
3318 AddRec->getNoWrapFlags(FlagsMask));
3319 }
3320 }
3321
3322 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3323 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3324 const SCEVAddExpr *InnerAdd;
3325 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3326 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3327 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3328 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3329 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3331 SCEV::FlagNUW)) {
3332 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3333 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3334 };
3335 }
3336
3337 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3338 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3339 // of C1, fold to (D /u (C2 /u C1)).
3340 const SCEV *D;
3341 APInt C1V = LHSC->getAPInt();
3342 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3343 // as -1 * 1, as it won't enable additional folds.
3344 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3345 C1V = C1V.abs();
3346 const SCEVConstant *C2;
3347 if (C1V.isPowerOf2() &&
3349 C2->getAPInt().isPowerOf2() &&
3350 C1V.logBase2() <= getMinTrailingZeros(D)) {
3351 const SCEV *NewMul = nullptr;
3352 if (C1V.uge(C2->getAPInt())) {
3353 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3354 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3355 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3356 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3357 }
3358 if (NewMul)
3359 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3360 }
3361 }
3362 }
3363
3364 // Skip over the add expression until we get to a multiply.
3365 unsigned Idx = 0;
3366 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3367 ++Idx;
3368
3369 // If there are mul operands inline them all into this expression.
3370 if (Idx < Ops.size()) {
3371 bool DeletedMul = false;
3372 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3373 if (Ops.size() > MulOpsInlineThreshold)
3374 break;
3375 // If we have an mul, expand the mul operands onto the end of the
3376 // operands list.
3377 Ops.erase(Ops.begin()+Idx);
3378 append_range(Ops, Mul->operands());
3379 DeletedMul = true;
3380 }
3381
3382 // If we deleted at least one mul, we added operands to the end of the
3383 // list, and they are not necessarily sorted. Recurse to resort and
3384 // resimplify any operands we just acquired.
3385 if (DeletedMul)
3386 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3387 }
3388
3389 // If there are any add recurrences in the operands list, see if any other
3390 // added values are loop invariant. If so, we can fold them into the
3391 // recurrence.
3392 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3393 ++Idx;
3394
3395 // Scan over all recurrences, trying to fold loop invariants into them.
3396 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3397 // Scan all of the other operands to this mul and add them to the vector
3398 // if they are loop invariant w.r.t. the recurrence.
3400 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3401 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3402 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3403 LIOps.push_back(Ops[i]);
3404 Ops.erase(Ops.begin()+i);
3405 --i; --e;
3406 }
3407
3408 // If we found some loop invariants, fold them into the recurrence.
3409 if (!LIOps.empty()) {
3410 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3412 NewOps.reserve(AddRec->getNumOperands());
3413 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3414
3415 // If both the mul and addrec are nuw, we can preserve nuw.
3416 // If both the mul and addrec are nsw, we can only preserve nsw if either
3417 // a) they are also nuw, or
3418 // b) all multiplications of addrec operands with scale are nsw.
3419 SCEV::NoWrapFlags Flags =
3420 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3421
3422 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3423 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3424 SCEV::FlagAnyWrap, Depth + 1));
3425
3426 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3428 Instruction::Mul, getSignedRange(Scale),
3430 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3431 Flags = clearFlags(Flags, SCEV::FlagNSW);
3432 }
3433 }
3434
3435 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3436
3437 // If all of the other operands were loop invariant, we are done.
3438 if (Ops.size() == 1) return NewRec;
3439
3440 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3441 for (unsigned i = 0;; ++i)
3442 if (Ops[i] == AddRec) {
3443 Ops[i] = NewRec;
3444 break;
3445 }
3446 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3447 }
3448
3449 // Okay, if there weren't any loop invariants to be folded, check to see
3450 // if there are multiple AddRec's with the same loop induction variable
3451 // being multiplied together. If so, we can fold them.
3452
3453 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3454 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3455 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3456 // ]]],+,...up to x=2n}.
3457 // Note that the arguments to choose() are always integers with values
3458 // known at compile time, never SCEV objects.
3459 //
3460 // The implementation avoids pointless extra computations when the two
3461 // addrec's are of different length (mathematically, it's equivalent to
3462 // an infinite stream of zeros on the right).
3463 bool OpsModified = false;
3464 for (unsigned OtherIdx = Idx+1;
3465 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3466 ++OtherIdx) {
3467 const SCEVAddRecExpr *OtherAddRec =
3468 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3469 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3470 continue;
3471
3472 // Limit max number of arguments to avoid creation of unreasonably big
3473 // SCEVAddRecs with very complex operands.
3474 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3475 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3476 continue;
3477
3478 bool Overflow = false;
3479 Type *Ty = AddRec->getType();
3480 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3481 SmallVector<SCEVUse, 7> AddRecOps;
3482 for (int x = 0, xe = AddRec->getNumOperands() +
3483 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3485 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3486 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3487 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3488 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3489 z < ze && !Overflow; ++z) {
3490 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3491 uint64_t Coeff;
3492 if (LargerThan64Bits)
3493 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3494 else
3495 Coeff = Coeff1*Coeff2;
3496 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3497 const SCEV *Term1 = AddRec->getOperand(y-z);
3498 const SCEV *Term2 = OtherAddRec->getOperand(z);
3499 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3500 SCEV::FlagAnyWrap, Depth + 1));
3501 }
3502 }
3503 if (SumOps.empty())
3504 SumOps.push_back(getZero(Ty));
3505 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3506 }
3507 if (!Overflow) {
3508 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3510 if (Ops.size() == 2) return NewAddRec;
3511 Ops[Idx] = NewAddRec;
3512 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3513 OpsModified = true;
3514 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3515 if (!AddRec)
3516 break;
3517 }
3518 }
3519 if (OpsModified)
3520 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3521
3522 // Otherwise couldn't fold anything into this recurrence. Move onto the
3523 // next one.
3524 }
3525
3526 // Okay, it looks like we really DO need an mul expr. Check to see if we
3527 // already have one, otherwise create a new one.
3528 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3529}
3530
3531/// Represents an unsigned remainder expression based on unsigned division.
3533 assert(getEffectiveSCEVType(LHS->getType()) ==
3534 getEffectiveSCEVType(RHS->getType()) &&
3535 "SCEVURemExpr operand types don't match!");
3536
3537 // Short-circuit easy cases
3538 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3539 // If constant is one, the result is trivial
3540 if (RHSC->getValue()->isOne())
3541 return getZero(LHS->getType()); // X urem 1 --> 0
3542
3543 // If constant is a power of two, fold into a zext(trunc(LHS)).
3544 if (RHSC->getAPInt().isPowerOf2()) {
3545 Type *FullTy = LHS->getType();
3546 Type *TruncTy =
3547 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3548 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3549 }
3550 }
3551
3552 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3553 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3554 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3555 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3556}
3557
3558/// Get a canonical unsigned division expression, or something simpler if
3559/// possible.
3561 assert(!LHS->getType()->isPointerTy() &&
3562 "SCEVUDivExpr operand can't be pointer!");
3563 assert(LHS->getType() == RHS->getType() &&
3564 "SCEVUDivExpr operand types don't match!");
3565
3567 ID.AddInteger(scUDivExpr);
3568 ID.AddPointer(LHS);
3569 ID.AddPointer(RHS);
3570 void *IP = nullptr;
3571 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3572 return S;
3573
3574 // 0 udiv Y == 0
3575 if (match(LHS, m_scev_Zero()))
3576 return LHS;
3577
3578 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3579 if (RHSC->getValue()->isOne())
3580 return LHS; // X udiv 1 --> x
3581 // If the denominator is zero, the result of the udiv is undefined. Don't
3582 // try to analyze it, because the resolution chosen here may differ from
3583 // the resolution chosen in other parts of the compiler.
3584 if (!RHSC->getValue()->isZero()) {
3585 // Determine if the division can be folded into the operands of
3586 // its operands.
3587 // TODO: Generalize this to non-constants by using known-bits information.
3588 Type *Ty = LHS->getType();
3589 unsigned LZ = RHSC->getAPInt().countl_zero();
3590 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3591 // For non-power-of-two values, effectively round the value up to the
3592 // nearest power of two.
3593 if (!RHSC->getAPInt().isPowerOf2())
3594 ++MaxShiftAmt;
3595 IntegerType *ExtTy =
3596 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3597 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3598 if (const SCEVConstant *Step =
3599 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3600 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3601 const APInt &StepInt = Step->getAPInt();
3602 const APInt &DivInt = RHSC->getAPInt();
3603 if (!StepInt.urem(DivInt) &&
3604 getZeroExtendExpr(AR, ExtTy) ==
3605 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3606 getZeroExtendExpr(Step, ExtTy),
3607 AR->getLoop(), SCEV::FlagAnyWrap)) {
3608 SmallVector<SCEVUse, 4> Operands;
3609 for (const SCEV *Op : AR->operands())
3610 Operands.push_back(getUDivExpr(Op, RHS));
3611 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3612 }
3613 /// Get a canonical UDivExpr for a recurrence.
3614 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3615 const APInt *StartRem;
3616 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3617 m_scev_APInt(StartRem))) {
3618 bool NoWrap =
3619 getZeroExtendExpr(AR, ExtTy) ==
3620 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3621 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
3623
3624 // With N <= C and both N, C as powers-of-2, the transformation
3625 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3626 // if wrapping occurs, as the division results remain equivalent for
3627 // all offsets in [[(X - X%N), X).
3628 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3629 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3630 // Only fold if the subtraction can be folded in the start
3631 // expression.
3632 const SCEV *NewStart =
3633 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3634 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3635 !isa<SCEVAddExpr>(NewStart)) {
3636 const SCEV *NewLHS =
3637 getAddRecExpr(NewStart, Step, AR->getLoop(),
3638 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3639 if (LHS != NewLHS) {
3640 LHS = NewLHS;
3641
3642 // Reset the ID to include the new LHS, and check if it is
3643 // already cached.
3644 ID.clear();
3645 ID.AddInteger(scUDivExpr);
3646 ID.AddPointer(LHS);
3647 ID.AddPointer(RHS);
3648 IP = nullptr;
3649 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3650 return S;
3651 }
3652 }
3653 }
3654 }
3655 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3656 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3657 SmallVector<SCEVUse, 4> Operands;
3658 for (const SCEV *Op : M->operands())
3659 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3660 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) {
3661 // Find an operand that's safely divisible.
3662 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3663 const SCEV *Op = M->getOperand(i);
3664 const SCEV *Div = getUDivExpr(Op, RHSC);
3665 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3666 Operands = SmallVector<SCEVUse, 4>(M->operands());
3667 Operands[i] = Div;
3668 return getMulExpr(Operands);
3669 }
3670 }
3671
3672 // Even if it's not divisible, try to remove a common factor.
3673 if (const auto *LHSC = dyn_cast<SCEVConstant>(M->getOperand(0))) {
3674 APInt Factor = APIntOps::GreatestCommonDivisor(LHSC->getAPInt(),
3675 RHSC->getAPInt());
3676 if (!Factor.isIntN(1)) {
3677 SmallVector<SCEVUse, 2> NewOperands;
3678 NewOperands.push_back(getConstant(LHSC->getAPInt().udiv(Factor)));
3679 append_range(NewOperands, M->operands().drop_front());
3680 const SCEV *NewMul = getMulExpr(NewOperands);
3681 return getUDivExpr(NewMul,
3682 getConstant(RHSC->getAPInt().udiv(Factor)));
3683 }
3684 }
3685 }
3686 }
3687
3688 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3689 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3690 if (auto *DivisorConstant =
3691 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3692 bool Overflow = false;
3693 APInt NewRHS =
3694 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3695 if (Overflow) {
3696 return getConstant(RHSC->getType(), 0, false);
3697 }
3698 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3699 }
3700 }
3701
3702 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3703 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3704 SmallVector<SCEVUse, 4> Operands;
3705 for (const SCEV *Op : A->operands())
3706 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3707 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3708 Operands.clear();
3709 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3710 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3711 if (isa<SCEVUDivExpr>(Op) ||
3712 getMulExpr(Op, RHS) != A->getOperand(i))
3713 break;
3714 Operands.push_back(Op);
3715 }
3716 if (Operands.size() == A->getNumOperands())
3717 return getAddExpr(Operands);
3718 }
3719 }
3720
3721 // Fold if both operands are constant.
3722 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3723 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3724 }
3725 }
3726
3727 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3728 const APInt *NegC, *C;
3729 if (match(LHS,
3732 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3733 return getZero(LHS->getType());
3734
3735 // (%a * %b)<nuw> / %b -> %a
3736 const auto *Mul = dyn_cast<SCEVMulExpr>(LHS);
3737 if (Mul && Mul->hasNoUnsignedWrap()) {
3738 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3739 if (Mul->getOperand(i) == RHS) {
3740 SmallVector<SCEVUse, 2> Operands;
3741 append_range(Operands, Mul->operands().take_front(i));
3742 append_range(Operands, Mul->operands().drop_front(i + 1));
3743 return getMulExpr(Operands);
3744 }
3745 }
3746 }
3747
3748 // TODO: Generalize to handle any common factors.
3749 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3750 const SCEV *NewLHS, *NewRHS;
3751 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3752 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3753 return getUDivExpr(NewLHS, NewRHS);
3754
3755 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3756 // changes). Make sure we get a new one.
3757 IP = nullptr;
3758 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3759 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3760 LHS, RHS);
3761 UniqueSCEVs.InsertNode(S, IP);
3762 S->computeAndSetCanonical(*this);
3763 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3764 return S;
3765}
3766
3767APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3768 APInt A = C1->getAPInt().abs();
3769 APInt B = C2->getAPInt().abs();
3770 uint32_t ABW = A.getBitWidth();
3771 uint32_t BBW = B.getBitWidth();
3772
3773 if (ABW > BBW)
3774 B = B.zext(ABW);
3775 else if (ABW < BBW)
3776 A = A.zext(BBW);
3777
3778 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3779}
3780
3781/// Get a canonical unsigned division expression, or something simpler if
3782/// possible. There is no representation for an exact udiv in SCEV IR, but we
3783/// can attempt to optimize it prior to construction.
3785 // Currently there is no exact specific logic.
3786
3787 return getUDivExpr(LHS, RHS);
3788}
3789
3790/// Get an add recurrence expression for the specified loop. Simplify the
3791/// expression as much as possible.
3793 const Loop *L,
3794 SCEV::NoWrapFlags Flags) {
3795 SmallVector<SCEVUse, 4> Operands;
3796 Operands.push_back(Start);
3797 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3798 if (StepChrec->getLoop() == L) {
3799 append_range(Operands, StepChrec->operands());
3800 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3801 }
3802
3803 Operands.push_back(Step);
3804 return getAddRecExpr(Operands, L, Flags);
3805}
3806
3807/// Get an add recurrence expression for the specified loop. Simplify the
3808/// expression as much as possible.
3810 const Loop *L,
3811 SCEV::NoWrapFlags Flags) {
3812 if (Operands.size() == 1) return Operands[0];
3813#ifndef NDEBUG
3814 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3815 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3816 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3817 "SCEVAddRecExpr operand types don't match!");
3818 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3819 }
3820 for (const SCEV *Op : Operands)
3822 "SCEVAddRecExpr operand is not available at loop entry!");
3823#endif
3824
3825 if (Operands.back()->isZero()) {
3826 Operands.pop_back();
3827 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3828 }
3829
3830 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3831 // use that information to infer NUW and NSW flags. However, computing a
3832 // BE count requires calling getAddRecExpr, so we may not yet have a
3833 // meaningful BE count at this point (and if we don't, we'd be stuck
3834 // with a SCEVCouldNotCompute as the cached BE count).
3835
3836 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3837
3838 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3839 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3840 const Loop *NestedLoop = NestedAR->getLoop();
3841 if (L->contains(NestedLoop)
3842 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3843 : (!NestedLoop->contains(L) &&
3844 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3845 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3846 Operands[0] = NestedAR->getStart();
3847 // AddRecs require their operands be loop-invariant with respect to their
3848 // loops. Don't perform this transformation if it would break this
3849 // requirement.
3850 bool AllInvariant = all_of(
3851 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3852
3853 if (AllInvariant) {
3854 // Create a recurrence for the outer loop with the same step size.
3855 //
3856 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3857 // inner recurrence has the same property.
3858 SCEV::NoWrapFlags OuterFlags =
3859 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3860
3861 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3862 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3863 return isLoopInvariant(Op, NestedLoop);
3864 });
3865
3866 if (AllInvariant) {
3867 // Ok, both add recurrences are valid after the transformation.
3868 //
3869 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3870 // the outer recurrence has the same property.
3871 SCEV::NoWrapFlags InnerFlags =
3872 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3873 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3874 }
3875 }
3876 // Reset Operands to its original state.
3877 Operands[0] = NestedAR;
3878 }
3879 }
3880
3881 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3882 // already have one, otherwise create a new one.
3883 return getOrCreateAddRecExpr(Operands, L, Flags);
3884}
3885
3887 ArrayRef<SCEVUse> IndexExprs) {
3888 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3889 // getSCEV(Base)->getType() has the same address space as Base->getType()
3890 // because SCEV::getType() preserves the address space.
3891 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3892 if (NW != GEPNoWrapFlags::none()) {
3893 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3894 // but to do that, we have to ensure that said flag is valid in the entire
3895 // defined scope of the SCEV.
3896 // TODO: non-instructions have global scope. We might be able to prove
3897 // some global scope cases
3898 auto *GEPI = dyn_cast<Instruction>(GEP);
3899 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3900 NW = GEPNoWrapFlags::none();
3901 }
3902
3903 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3904}
3905
3907 ArrayRef<SCEVUse> IndexExprs,
3908 Type *SrcElementTy, GEPNoWrapFlags NW) {
3910 if (NW.hasNoUnsignedSignedWrap())
3911 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3912 if (NW.hasNoUnsignedWrap())
3913 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3914
3915 Type *CurTy = BaseExpr->getType();
3916 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3917 bool FirstIter = true;
3919 for (SCEVUse IndexExpr : IndexExprs) {
3920 // Compute the (potentially symbolic) offset in bytes for this index.
3921 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3922 // For a struct, add the member offset.
3923 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3924 unsigned FieldNo = Index->getZExtValue();
3925 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3926 Offsets.push_back(FieldOffset);
3927
3928 // Update CurTy to the type of the field at Index.
3929 CurTy = STy->getTypeAtIndex(Index);
3930 } else {
3931 // Update CurTy to its element type.
3932 if (FirstIter) {
3933 assert(isa<PointerType>(CurTy) &&
3934 "The first index of a GEP indexes a pointer");
3935 CurTy = SrcElementTy;
3936 FirstIter = false;
3937 } else {
3939 }
3940 // For an array, add the element offset, explicitly scaled.
3941 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3942 // Getelementptr indices are signed.
3943 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3944
3945 // Multiply the index by the element size to compute the element offset.
3946 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3947 Offsets.push_back(LocalOffset);
3948 }
3949 }
3950
3951 // Handle degenerate case of GEP without offsets.
3952 if (Offsets.empty())
3953 return BaseExpr;
3954
3955 // Add the offsets together, assuming nsw if inbounds.
3956 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3957 // Add the base address and the offset. We cannot use the nsw flag, as the
3958 // base address is unsigned. However, if we know that the offset is
3959 // non-negative, we can use nuw.
3960 bool NUW = NW.hasNoUnsignedWrap() ||
3963 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3964 assert(BaseExpr->getType() == GEPExpr->getType() &&
3965 "GEP should not change type mid-flight.");
3966 return GEPExpr;
3967}
3968
3969SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3972 ID.AddInteger(SCEVType);
3973 for (const SCEV *Op : Ops)
3974 ID.AddPointer(Op);
3975 void *IP = nullptr;
3976 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3977}
3978
3979SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3982 ID.AddInteger(SCEVType);
3983 for (const SCEV *Op : Ops)
3984 ID.AddPointer(Op);
3985 void *IP = nullptr;
3986 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3987}
3988
3989const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3991 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3992}
3993
3996 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3997 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3998 if (Ops.size() == 1) return Ops[0];
3999#ifndef NDEBUG
4000 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4001 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4002 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4003 "Operand types don't match!");
4004 assert(Ops[0]->getType()->isPointerTy() ==
4005 Ops[i]->getType()->isPointerTy() &&
4006 "min/max should be consistently pointerish");
4007 }
4008#endif
4009
4010 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4011 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4012
4013 const SCEV *Folded = constantFoldAndGroupOps(
4014 *this, LI, DT, Ops,
4015 [&](const APInt &C1, const APInt &C2) {
4016 switch (Kind) {
4017 case scSMaxExpr:
4018 return APIntOps::smax(C1, C2);
4019 case scSMinExpr:
4020 return APIntOps::smin(C1, C2);
4021 case scUMaxExpr:
4022 return APIntOps::umax(C1, C2);
4023 case scUMinExpr:
4024 return APIntOps::umin(C1, C2);
4025 default:
4026 llvm_unreachable("Unknown SCEV min/max opcode");
4027 }
4028 },
4029 [&](const APInt &C) {
4030 // identity
4031 if (IsMax)
4032 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4033 else
4034 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4035 },
4036 [&](const APInt &C) {
4037 // absorber
4038 if (IsMax)
4039 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4040 else
4041 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4042 });
4043 if (Folded)
4044 return Folded;
4045
4046 // Check if we have created the same expression before.
4047 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4048 return S;
4049 }
4050
4051 // Find the first operation of the same kind
4052 unsigned Idx = 0;
4053 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4054 ++Idx;
4055
4056 // Check to see if one of the operands is of the same kind. If so, expand its
4057 // operands onto our operand list, and recurse to simplify.
4058 if (Idx < Ops.size()) {
4059 bool DeletedAny = false;
4060 while (Ops[Idx]->getSCEVType() == Kind) {
4061 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4062 Ops.erase(Ops.begin()+Idx);
4063 append_range(Ops, SMME->operands());
4064 DeletedAny = true;
4065 }
4066
4067 if (DeletedAny)
4068 return getMinMaxExpr(Kind, Ops);
4069 }
4070
4071 // Okay, check to see if the same value occurs in the operand list twice. If
4072 // so, delete one. Since we sorted the list, these values are required to
4073 // be adjacent.
4078 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4079 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4080 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4081 if (Ops[i] == Ops[i + 1] ||
4082 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4083 // X op Y op Y --> X op Y
4084 // X op Y --> X, if we know X, Y are ordered appropriately
4085 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4086 --i;
4087 --e;
4088 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4089 Ops[i + 1])) {
4090 // X op Y --> Y, if we know X, Y are ordered appropriately
4091 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4092 --i;
4093 --e;
4094 }
4095 }
4096
4097 if (Ops.size() == 1) return Ops[0];
4098
4099 assert(!Ops.empty() && "Reduced smax down to nothing!");
4100
4101 // Okay, it looks like we really DO need an expr. Check to see if we
4102 // already have one, otherwise create a new one.
4104 ID.AddInteger(Kind);
4105 for (const SCEV *Op : Ops)
4106 ID.AddPointer(Op);
4107 void *IP = nullptr;
4108 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4109 if (ExistingSCEV)
4110 return ExistingSCEV;
4111 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4113 SCEV *S = new (SCEVAllocator)
4114 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4115
4116 UniqueSCEVs.InsertNode(S, IP);
4117 S->computeAndSetCanonical(*this);
4118 registerUser(S, Ops);
4119 return S;
4120}
4121
4122namespace {
4123
4124class SCEVSequentialMinMaxDeduplicatingVisitor final
4125 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4126 std::optional<const SCEV *>> {
4127 using RetVal = std::optional<const SCEV *>;
4129
4130 ScalarEvolution &SE;
4131 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4132 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4134
4135 bool canRecurseInto(SCEVTypes Kind) const {
4136 // We can only recurse into the SCEV expression of the same effective type
4137 // as the type of our root SCEV expression.
4138 return RootKind == Kind || NonSequentialRootKind == Kind;
4139 };
4140
4141 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4143 "Only for min/max expressions.");
4144 SCEVTypes Kind = S->getSCEVType();
4145
4146 if (!canRecurseInto(Kind))
4147 return S;
4148
4149 auto *NAry = cast<SCEVNAryExpr>(S);
4150 SmallVector<SCEVUse> NewOps;
4151 bool Changed = visit(Kind, NAry->operands(), NewOps);
4152
4153 if (!Changed)
4154 return S;
4155 if (NewOps.empty())
4156 return std::nullopt;
4157
4159 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4160 : SE.getMinMaxExpr(Kind, NewOps);
4161 }
4162
4163 RetVal visit(const SCEV *S) {
4164 // Has the whole operand been seen already?
4165 if (!SeenOps.insert(S).second)
4166 return std::nullopt;
4167 return Base::visit(S);
4168 }
4169
4170public:
4171 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4172 SCEVTypes RootKind)
4173 : SE(SE), RootKind(RootKind),
4174 NonSequentialRootKind(
4175 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4176 RootKind)) {}
4177
4178 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4179 SmallVectorImpl<SCEVUse> &NewOps) {
4180 bool Changed = false;
4182 Ops.reserve(OrigOps.size());
4183
4184 for (const SCEV *Op : OrigOps) {
4185 RetVal NewOp = visit(Op);
4186 if (NewOp != Op)
4187 Changed = true;
4188 if (NewOp)
4189 Ops.emplace_back(*NewOp);
4190 }
4191
4192 if (Changed)
4193 NewOps = std::move(Ops);
4194 return Changed;
4195 }
4196
4197 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4198
4199 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4200
4201 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4202
4203 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4204
4205 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4206
4207 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4208
4209 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4210
4211 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4212
4213 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4214
4215 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4216
4217 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4218
4219 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4220 return visitAnyMinMaxExpr(Expr);
4221 }
4222
4223 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4224 return visitAnyMinMaxExpr(Expr);
4225 }
4226
4227 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4228 return visitAnyMinMaxExpr(Expr);
4229 }
4230
4231 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4232 return visitAnyMinMaxExpr(Expr);
4233 }
4234
4235 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4236 return visitAnyMinMaxExpr(Expr);
4237 }
4238
4239 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4240
4241 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4242};
4243
4244} // namespace
4245
4247 switch (Kind) {
4248 case scConstant:
4249 case scVScale:
4250 case scTruncate:
4251 case scZeroExtend:
4252 case scSignExtend:
4253 case scPtrToAddr:
4254 case scPtrToInt:
4255 case scAddExpr:
4256 case scMulExpr:
4257 case scUDivExpr:
4258 case scAddRecExpr:
4259 case scUMaxExpr:
4260 case scSMaxExpr:
4261 case scUMinExpr:
4262 case scSMinExpr:
4263 case scUnknown:
4264 // If any operand is poison, the whole expression is poison.
4265 return true;
4267 // FIXME: if the *first* operand is poison, the whole expression is poison.
4268 return false; // Pessimistically, say that it does not propagate poison.
4269 case scCouldNotCompute:
4270 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4271 }
4272 llvm_unreachable("Unknown SCEV kind!");
4273}
4274
4275namespace {
4276// The only way poison may be introduced in a SCEV expression is from a
4277// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4278// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4279// introduce poison -- they encode guaranteed, non-speculated knowledge.
4280//
4281// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4282// with the notable exception of umin_seq, where only poison from the first
4283// operand is (unconditionally) propagated.
4284struct SCEVPoisonCollector {
4285 bool LookThroughMaybePoisonBlocking;
4286 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4287 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4288 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4289
4290 bool follow(const SCEV *S) {
4291 if (!LookThroughMaybePoisonBlocking &&
4293 return false;
4294
4295 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4296 if (!isGuaranteedNotToBePoison(SU->getValue()))
4297 MaybePoison.insert(SU);
4298 }
4299 return true;
4300 }
4301 bool isDone() const { return false; }
4302};
4303} // namespace
4304
4305/// Return true if V is poison given that AssumedPoison is already poison.
4306static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4307 // First collect all SCEVs that might result in AssumedPoison to be poison.
4308 // We need to look through potentially poison-blocking operations here,
4309 // because we want to find all SCEVs that *might* result in poison, not only
4310 // those that are *required* to.
4311 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4312 visitAll(AssumedPoison, PC1);
4313
4314 // AssumedPoison is never poison. As the assumption is false, the implication
4315 // is true. Don't bother walking the other SCEV in this case.
4316 if (PC1.MaybePoison.empty())
4317 return true;
4318
4319 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4320 // as well. We cannot look through potentially poison-blocking operations
4321 // here, as their arguments only *may* make the result poison.
4322 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4323 visitAll(S, PC2);
4324
4325 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4326 // it will also make S poison by being part of PC2.MaybePoison.
4327 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4328}
4329
4331 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4332 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4333 visitAll(S, PC);
4334 for (const SCEVUnknown *SU : PC.MaybePoison)
4335 Result.insert(SU->getValue());
4336}
4337
4339 const SCEV *S, Instruction *I,
4340 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4341 // If the instruction cannot be poison, it's always safe to reuse.
4343 return true;
4344
4345 // Otherwise, it is possible that I is more poisonous that S. Collect the
4346 // poison-contributors of S, and then check whether I has any additional
4347 // poison-contributors. Poison that is contributed through poison-generating
4348 // flags is handled by dropping those flags instead.
4350 getPoisonGeneratingValues(PoisonVals, S);
4351
4352 SmallVector<Value *> Worklist;
4354 Worklist.push_back(I);
4355 while (!Worklist.empty()) {
4356 Value *V = Worklist.pop_back_val();
4357 if (!Visited.insert(V).second)
4358 continue;
4359
4360 // Avoid walking large instruction graphs.
4361 if (Visited.size() > 16)
4362 return false;
4363
4364 // Either the value can't be poison, or the S would also be poison if it
4365 // is.
4366 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4367 continue;
4368
4369 auto *I = dyn_cast<Instruction>(V);
4370 if (!I)
4371 return false;
4372
4373 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4374 // can't replace an arbitrary add with disjoint or, even if we drop the
4375 // flag. We would need to convert the or into an add.
4376 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4377 if (PDI->isDisjoint())
4378 return false;
4379
4380 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4381 // because SCEV currently assumes it can't be poison. Remove this special
4382 // case once we proper model when vscale can be poison.
4383 if (auto *II = dyn_cast<IntrinsicInst>(I);
4384 II && II->getIntrinsicID() == Intrinsic::vscale)
4385 continue;
4386
4387 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4388 return false;
4389
4390 // If the instruction can't create poison, we can recurse to its operands.
4391 if (I->hasPoisonGeneratingAnnotations())
4392 DropPoisonGeneratingInsts.push_back(I);
4393
4394 llvm::append_range(Worklist, I->operands());
4395 }
4396 return true;
4397}
4398
4399const SCEV *
4402 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4403 "Not a SCEVSequentialMinMaxExpr!");
4404 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4405 if (Ops.size() == 1)
4406 return Ops[0];
4407#ifndef NDEBUG
4408 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4409 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4410 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4411 "Operand types don't match!");
4412 assert(Ops[0]->getType()->isPointerTy() ==
4413 Ops[i]->getType()->isPointerTy() &&
4414 "min/max should be consistently pointerish");
4415 }
4416#endif
4417
4418 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4419 // so we can *NOT* do any kind of sorting of the expressions!
4420
4421 // Check if we have created the same expression before.
4422 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4423 return S;
4424
4425 // FIXME: there are *some* simplifications that we can do here.
4426
4427 // Keep only the first instance of an operand.
4428 {
4429 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4430 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4431 if (Changed)
4432 return getSequentialMinMaxExpr(Kind, Ops);
4433 }
4434
4435 // Check to see if one of the operands is of the same kind. If so, expand its
4436 // operands onto our operand list, and recurse to simplify.
4437 {
4438 unsigned Idx = 0;
4439 bool DeletedAny = false;
4440 while (Idx < Ops.size()) {
4441 if (Ops[Idx]->getSCEVType() != Kind) {
4442 ++Idx;
4443 continue;
4444 }
4445 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4446 Ops.erase(Ops.begin() + Idx);
4447 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4448 SMME->operands().end());
4449 DeletedAny = true;
4450 }
4451
4452 if (DeletedAny)
4453 return getSequentialMinMaxExpr(Kind, Ops);
4454 }
4455
4456 const SCEV *SaturationPoint;
4458 switch (Kind) {
4460 SaturationPoint = getZero(Ops[0]->getType());
4461 Pred = ICmpInst::ICMP_ULE;
4462 break;
4463 default:
4464 llvm_unreachable("Not a sequential min/max type.");
4465 }
4466
4467 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4468 if (!isGuaranteedNotToCauseUB(Ops[i]))
4469 continue;
4470 // We can replace %x umin_seq %y with %x umin %y if either:
4471 // * %y being poison implies %x is also poison.
4472 // * %x cannot be the saturating value (e.g. zero for umin).
4473 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4474 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4475 SaturationPoint)) {
4476 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4477 Ops[i - 1] = getMinMaxExpr(
4479 SeqOps);
4480 Ops.erase(Ops.begin() + i);
4481 return getSequentialMinMaxExpr(Kind, Ops);
4482 }
4483 // Fold %x umin_seq %y to %x if %x ule %y.
4484 // TODO: We might be able to prove the predicate for a later operand.
4485 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4486 Ops.erase(Ops.begin() + i);
4487 return getSequentialMinMaxExpr(Kind, Ops);
4488 }
4489 }
4490
4491 // Okay, it looks like we really DO need an expr. Check to see if we
4492 // already have one, otherwise create a new one.
4494 ID.AddInteger(Kind);
4495 for (const SCEV *Op : Ops)
4496 ID.AddPointer(Op);
4497 void *IP = nullptr;
4498 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4499 if (ExistingSCEV)
4500 return ExistingSCEV;
4501
4502 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4504 SCEV *S = new (SCEVAllocator)
4505 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4506
4507 UniqueSCEVs.InsertNode(S, IP);
4508 S->computeAndSetCanonical(*this);
4509 registerUser(S, Ops);
4510 return S;
4511}
4512
4517
4521
4526
4530
4535
4539
4541 bool Sequential) {
4542 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4543 return getUMinExpr(Ops, Sequential);
4544}
4545
4551
4552const SCEV *
4554 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4555 if (Size.isScalable())
4556 Res = getMulExpr(Res, getVScale(IntTy));
4557 return Res;
4558}
4559
4561 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4562}
4563
4565 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4566}
4567
4569 StructType *STy,
4570 unsigned FieldNo) {
4571 // We can bypass creating a target-independent constant expression and then
4572 // folding it back into a ConstantInt. This is just a compile-time
4573 // optimization.
4574 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4575 assert(!SL->getSizeInBits().isScalable() &&
4576 "Cannot get offset for structure containing scalable vector types");
4577 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4578}
4579
4581 // Don't attempt to do anything other than create a SCEVUnknown object
4582 // here. createSCEV only calls getUnknown after checking for all other
4583 // interesting possibilities, and any other code that calls getUnknown
4584 // is doing so in order to hide a value from SCEV canonicalization.
4585
4587 ID.AddInteger(scUnknown);
4588 ID.AddPointer(V);
4589 void *IP = nullptr;
4590 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4591 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4592 "Stale SCEVUnknown in uniquing map!");
4593 return S;
4594 }
4595 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4596 FirstUnknown);
4597 FirstUnknown = cast<SCEVUnknown>(S);
4598 UniqueSCEVs.InsertNode(S, IP);
4599 S->computeAndSetCanonical(*this);
4600 return S;
4601}
4602
4603//===----------------------------------------------------------------------===//
4604// Basic SCEV Analysis and PHI Idiom Recognition Code
4605//
4606
4607/// Test if values of the given type are analyzable within the SCEV
4608/// framework. This primarily includes integer types, and it can optionally
4609/// include pointer types if the ScalarEvolution class has access to
4610/// target-specific information.
4612 // Integers and pointers are always SCEVable.
4613 return Ty->isIntOrPtrTy();
4614}
4615
4616/// Return the size in bits of the specified type, for which isSCEVable must
4617/// return true.
4619 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4620 if (Ty->isPointerTy())
4622 return getDataLayout().getTypeSizeInBits(Ty);
4623}
4624
4625/// Return a type with the same bitwidth as the given type and which represents
4626/// how SCEV will treat the given type, for which isSCEVable must return
4627/// true. For pointer types, this is the pointer index sized integer type.
4629 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4630
4631 if (Ty->isIntegerTy())
4632 return Ty;
4633
4634 // The only other support type is pointer.
4635 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4636 return getDataLayout().getIndexType(Ty);
4637}
4638
4640 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4641}
4642
4644 const SCEV *B) {
4645 /// For a valid use point to exist, the defining scope of one operand
4646 /// must dominate the other.
4647 bool PreciseA, PreciseB;
4648 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4649 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4650 if (!PreciseA || !PreciseB)
4651 // Can't tell.
4652 return false;
4653 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4654 DT.dominates(ScopeB, ScopeA);
4655}
4656
4658 return CouldNotCompute.get();
4659}
4660
4661bool ScalarEvolution::checkValidity(const SCEV *S) const {
4662 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4663 auto *SU = dyn_cast<SCEVUnknown>(S);
4664 return SU && SU->getValue() == nullptr;
4665 });
4666
4667 return !ContainsNulls;
4668}
4669
4671 HasRecMapType::iterator I = HasRecMap.find(S);
4672 if (I != HasRecMap.end())
4673 return I->second;
4674
4675 bool FoundAddRec =
4676 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4677 HasRecMap.insert({S, FoundAddRec});
4678 return FoundAddRec;
4679}
4680
4681/// Return the ValueOffsetPair set for \p S. \p S can be represented
4682/// by the value and offset from any ValueOffsetPair in the set.
4683ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4684 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4685 if (SI == ExprValueMap.end())
4686 return {};
4687 return SI->second.getArrayRef();
4688}
4689
4690/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4691/// cannot be used separately. eraseValueFromMap should be used to remove
4692/// V from ValueExprMap and ExprValueMap at the same time.
4693void ScalarEvolution::eraseValueFromMap(Value *V) {
4694 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4695 if (I != ValueExprMap.end()) {
4696 auto EVIt = ExprValueMap.find(I->second);
4697 bool Removed = EVIt->second.remove(V);
4698 (void) Removed;
4699 assert(Removed && "Value not in ExprValueMap?");
4700 ValueExprMap.erase(I);
4701 }
4702}
4703
4704void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4705 // A recursive query may have already computed the SCEV. It should be
4706 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4707 // inferred nowrap flags.
4708 auto It = ValueExprMap.find_as(V);
4709 if (It == ValueExprMap.end()) {
4710 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4711 ExprValueMap[S].insert(V);
4712 }
4713}
4714
4715/// Return an existing SCEV if it exists, otherwise analyze the expression and
4716/// create a new one.
4718 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4719
4720 if (const SCEV *S = getExistingSCEV(V))
4721 return S;
4722 return createSCEVIter(V);
4723}
4724
4726 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4727
4728 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4729 if (I != ValueExprMap.end()) {
4730 const SCEV *S = I->second;
4731 assert(checkValidity(S) &&
4732 "existing SCEV has not been properly invalidated");
4733 return S;
4734 }
4735 return nullptr;
4736}
4737
4738/// Return a SCEV corresponding to -V = -1*V
4740 SCEV::NoWrapFlags Flags) {
4741 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4742 return getConstant(
4743 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4744
4745 Type *Ty = V->getType();
4746 Ty = getEffectiveSCEVType(Ty);
4747 return getMulExpr(V, getMinusOne(Ty), Flags);
4748}
4749
4750/// If Expr computes ~A, return A else return nullptr
4751static const SCEV *MatchNotExpr(const SCEV *Expr) {
4752 const SCEV *MulOp;
4753 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4754 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4755 return MulOp;
4756 return nullptr;
4757}
4758
4759/// Return a SCEV corresponding to ~V = -1-V
4761 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4762
4763 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4764 return getConstant(
4765 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4766
4767 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4768 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4769 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4770 SmallVector<SCEVUse, 2> MatchedOperands;
4771 for (const SCEV *Operand : MME->operands()) {
4772 const SCEV *Matched = MatchNotExpr(Operand);
4773 if (!Matched)
4774 return (const SCEV *)nullptr;
4775 MatchedOperands.push_back(Matched);
4776 }
4777 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4778 MatchedOperands);
4779 };
4780 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4781 return Replaced;
4782 }
4783
4784 Type *Ty = V->getType();
4785 Ty = getEffectiveSCEVType(Ty);
4786 return getMinusSCEV(getMinusOne(Ty), V);
4787}
4788
4790 assert(P->getType()->isPointerTy());
4791
4792 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4793 // The base of an AddRec is the first operand.
4794 SmallVector<SCEVUse> Ops{AddRec->operands()};
4795 Ops[0] = removePointerBase(Ops[0]);
4796 // Don't try to transfer nowrap flags for now. We could in some cases
4797 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4798 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4799 }
4800 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4801 // The base of an Add is the pointer operand.
4802 SmallVector<SCEVUse> Ops{Add->operands()};
4803 SCEVUse *PtrOp = nullptr;
4804 for (SCEVUse &AddOp : Ops) {
4805 if (AddOp->getType()->isPointerTy()) {
4806 assert(!PtrOp && "Cannot have multiple pointer ops");
4807 PtrOp = &AddOp;
4808 }
4809 }
4810 *PtrOp = removePointerBase(*PtrOp);
4811 // Don't try to transfer nowrap flags for now. We could in some cases
4812 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4813 return getAddExpr(Ops);
4814 }
4815 // Any other expression must be a pointer base.
4816 return getZero(P->getType());
4817}
4818
4820 SCEV::NoWrapFlags Flags,
4821 unsigned Depth) {
4822 // Fast path: X - X --> 0.
4823 if (LHS == RHS)
4824 return getZero(LHS->getType());
4825
4826 // If we subtract two pointers with different pointer bases, bail.
4827 // Eventually, we're going to add an assertion to getMulExpr that we
4828 // can't multiply by a pointer.
4829 if (RHS->getType()->isPointerTy()) {
4830 if (!LHS->getType()->isPointerTy() ||
4831 getPointerBase(LHS) != getPointerBase(RHS))
4832 return getCouldNotCompute();
4833 LHS = removePointerBase(LHS);
4834 RHS = removePointerBase(RHS);
4835 }
4836
4837 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4838 // makes it so that we cannot make much use of NUW.
4839 auto AddFlags = SCEV::FlagAnyWrap;
4840 const bool RHSIsNotMinSigned =
4842 if (hasFlags(Flags, SCEV::FlagNSW)) {
4843 // Let M be the minimum representable signed value. Then (-1)*RHS
4844 // signed-wraps if and only if RHS is M. That can happen even for
4845 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4846 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4847 // (-1)*RHS, we need to prove that RHS != M.
4848 //
4849 // If LHS is non-negative and we know that LHS - RHS does not
4850 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4851 // either by proving that RHS > M or that LHS >= 0.
4852 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4853 AddFlags = SCEV::FlagNSW;
4854 }
4855 }
4856
4857 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4858 // RHS is NSW and LHS >= 0.
4859 //
4860 // The difficulty here is that the NSW flag may have been proven
4861 // relative to a loop that is to be found in a recurrence in LHS and
4862 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4863 // larger scope than intended.
4864 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4865
4866 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4867}
4868
4870 unsigned Depth) {
4871 Type *SrcTy = V->getType();
4872 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4873 "Cannot truncate or zero extend with non-integer arguments!");
4874 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4875 return V; // No conversion
4876 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4877 return getTruncateExpr(V, Ty, Depth);
4878 return getZeroExtendExpr(V, Ty, Depth);
4879}
4880
4882 unsigned Depth) {
4883 Type *SrcTy = V->getType();
4884 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4885 "Cannot truncate or zero extend with non-integer arguments!");
4886 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4887 return V; // No conversion
4888 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4889 return getTruncateExpr(V, Ty, Depth);
4890 return getSignExtendExpr(V, Ty, Depth);
4891}
4892
4893const SCEV *
4895 Type *SrcTy = V->getType();
4896 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4897 "Cannot noop or zero extend with non-integer arguments!");
4899 "getNoopOrZeroExtend cannot truncate!");
4900 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4901 return V; // No conversion
4902 return getZeroExtendExpr(V, Ty);
4903}
4904
4905const SCEV *
4907 Type *SrcTy = V->getType();
4908 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4909 "Cannot noop or sign extend with non-integer arguments!");
4911 "getNoopOrSignExtend cannot truncate!");
4912 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4913 return V; // No conversion
4914 return getSignExtendExpr(V, Ty);
4915}
4916
4917const SCEV *
4919 Type *SrcTy = V->getType();
4920 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4921 "Cannot noop or any extend with non-integer arguments!");
4923 "getNoopOrAnyExtend cannot truncate!");
4924 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4925 return V; // No conversion
4926 return getAnyExtendExpr(V, Ty);
4927}
4928
4929const SCEV *
4931 Type *SrcTy = V->getType();
4932 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4933 "Cannot truncate or noop with non-integer arguments!");
4935 "getTruncateOrNoop cannot extend!");
4936 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4937 return V; // No conversion
4938 return getTruncateExpr(V, Ty);
4939}
4940
4942 const SCEV *RHS) {
4943 const SCEV *PromotedLHS = LHS;
4944 const SCEV *PromotedRHS = RHS;
4945
4946 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4947 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4948 else
4949 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4950
4951 return getUMaxExpr(PromotedLHS, PromotedRHS);
4952}
4953
4955 const SCEV *RHS,
4956 bool Sequential) {
4957 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4958 return getUMinFromMismatchedTypes(Ops, Sequential);
4959}
4960
4961const SCEV *
4963 bool Sequential) {
4964 assert(!Ops.empty() && "At least one operand must be!");
4965 // Trivial case.
4966 if (Ops.size() == 1)
4967 return Ops[0];
4968
4969 // Find the max type first.
4970 Type *MaxType = nullptr;
4971 for (SCEVUse S : Ops)
4972 if (MaxType)
4973 MaxType = getWiderType(MaxType, S->getType());
4974 else
4975 MaxType = S->getType();
4976 assert(MaxType && "Failed to find maximum type!");
4977
4978 // Extend all ops to max type.
4979 SmallVector<SCEVUse, 2> PromotedOps;
4980 for (SCEVUse S : Ops)
4981 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4982
4983 // Generate umin.
4984 return getUMinExpr(PromotedOps, Sequential);
4985}
4986
4988 // A pointer operand may evaluate to a nonpointer expression, such as null.
4989 if (!V->getType()->isPointerTy())
4990 return V;
4991
4992 while (true) {
4993 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4994 V = AddRec->getStart();
4995 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
4996 const SCEV *PtrOp = nullptr;
4997 for (const SCEV *AddOp : Add->operands()) {
4998 if (AddOp->getType()->isPointerTy()) {
4999 assert(!PtrOp && "Cannot have multiple pointer ops");
5000 PtrOp = AddOp;
5001 }
5002 }
5003 assert(PtrOp && "Must have pointer op");
5004 V = PtrOp;
5005 } else // Not something we can look further into.
5006 return V;
5007 }
5008}
5009
5010/// Push users of the given Instruction onto the given Worklist.
5014 // Push the def-use children onto the Worklist stack.
5015 for (User *U : I->users()) {
5016 auto *UserInsn = cast<Instruction>(U);
5017 if (Visited.insert(UserInsn).second)
5018 Worklist.push_back(UserInsn);
5019 }
5020}
5021
5022namespace {
5023
5024/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5025/// expression in case its Loop is L. If it is not L then
5026/// if IgnoreOtherLoops is true then use AddRec itself
5027/// otherwise rewrite cannot be done.
5028/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5029class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5030public:
5031 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5032 bool IgnoreOtherLoops = true) {
5033 SCEVInitRewriter Rewriter(L, SE);
5034 const SCEV *Result = Rewriter.visit(S);
5035 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5036 return SE.getCouldNotCompute();
5037 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5038 ? SE.getCouldNotCompute()
5039 : Result;
5040 }
5041
5042 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5043 if (!SE.isLoopInvariant(Expr, L))
5044 SeenLoopVariantSCEVUnknown = true;
5045 return Expr;
5046 }
5047
5048 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5049 // Only re-write AddRecExprs for this loop.
5050 if (Expr->getLoop() == L)
5051 return Expr->getStart();
5052 SeenOtherLoops = true;
5053 return Expr;
5054 }
5055
5056 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5057
5058 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5059
5060private:
5061 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5062 : SCEVRewriteVisitor(SE), L(L) {}
5063
5064 const Loop *L;
5065 bool SeenLoopVariantSCEVUnknown = false;
5066 bool SeenOtherLoops = false;
5067};
5068
5069/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5070/// increment expression in case its Loop is L. If it is not L then
5071/// use AddRec itself.
5072/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5073class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5074public:
5075 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5076 SCEVPostIncRewriter Rewriter(L, SE);
5077 const SCEV *Result = Rewriter.visit(S);
5078 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5079 ? SE.getCouldNotCompute()
5080 : Result;
5081 }
5082
5083 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5084 if (!SE.isLoopInvariant(Expr, L))
5085 SeenLoopVariantSCEVUnknown = true;
5086 return Expr;
5087 }
5088
5089 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5090 // Only re-write AddRecExprs for this loop.
5091 if (Expr->getLoop() == L)
5092 return Expr->getPostIncExpr(SE);
5093 SeenOtherLoops = true;
5094 return Expr;
5095 }
5096
5097 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5098
5099 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5100
5101private:
5102 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5103 : SCEVRewriteVisitor(SE), L(L) {}
5104
5105 const Loop *L;
5106 bool SeenLoopVariantSCEVUnknown = false;
5107 bool SeenOtherLoops = false;
5108};
5109
5110/// This class evaluates the compare condition by matching it against the
5111/// condition of loop latch. If there is a match we assume a true value
5112/// for the condition while building SCEV nodes.
5113class SCEVBackedgeConditionFolder
5114 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5115public:
5116 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5117 ScalarEvolution &SE) {
5118 bool IsPosBECond = false;
5119 Value *BECond = nullptr;
5120 if (BasicBlock *Latch = L->getLoopLatch()) {
5121 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5122 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5123 "Both outgoing branches should not target same header!");
5124 BECond = BI->getCondition();
5125 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5126 } else {
5127 return S;
5128 }
5129 }
5130 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5131 return Rewriter.visit(S);
5132 }
5133
5134 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5135 const SCEV *Result = Expr;
5136 bool InvariantF = SE.isLoopInvariant(Expr, L);
5137
5138 if (!InvariantF) {
5140 switch (I->getOpcode()) {
5141 case Instruction::Select: {
5142 SelectInst *SI = cast<SelectInst>(I);
5143 std::optional<const SCEV *> Res =
5144 compareWithBackedgeCondition(SI->getCondition());
5145 if (Res) {
5146 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5147 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5148 }
5149 break;
5150 }
5151 default: {
5152 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5153 if (Res)
5154 Result = *Res;
5155 break;
5156 }
5157 }
5158 }
5159 return Result;
5160 }
5161
5162private:
5163 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5164 bool IsPosBECond, ScalarEvolution &SE)
5165 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5166 IsPositiveBECond(IsPosBECond) {}
5167
5168 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5169
5170 const Loop *L;
5171 /// Loop back condition.
5172 Value *BackedgeCond = nullptr;
5173 /// Set to true if loop back is on positive branch condition.
5174 bool IsPositiveBECond;
5175};
5176
5177std::optional<const SCEV *>
5178SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5179
5180 // If value matches the backedge condition for loop latch,
5181 // then return a constant evolution node based on loopback
5182 // branch taken.
5183 if (BackedgeCond == IC)
5184 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5186 return std::nullopt;
5187}
5188
5189class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5190public:
5191 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5192 ScalarEvolution &SE) {
5193 SCEVShiftRewriter Rewriter(L, SE);
5194 const SCEV *Result = Rewriter.visit(S);
5195 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5196 }
5197
5198 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5199 // Only allow AddRecExprs for this loop.
5200 if (!SE.isLoopInvariant(Expr, L))
5201 Valid = false;
5202 return Expr;
5203 }
5204
5205 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5206 if (Expr->getLoop() == L && Expr->isAffine())
5207 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5208 Valid = false;
5209 return Expr;
5210 }
5211
5212 bool isValid() { return Valid; }
5213
5214private:
5215 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5216 : SCEVRewriteVisitor(SE), L(L) {}
5217
5218 const Loop *L;
5219 bool Valid = true;
5220};
5221
5222} // end anonymous namespace
5223
5225ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5226 if (!AR->isAffine())
5227 return SCEV::FlagAnyWrap;
5228
5229 using OBO = OverflowingBinaryOperator;
5230
5232
5233 if (!AR->hasNoSelfWrap()) {
5234 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5235 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5236 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5237 const APInt &BECountAP = BECountMax->getAPInt();
5238 unsigned NoOverflowBitWidth =
5239 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5240 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5242 }
5243 }
5244
5245 if (!AR->hasNoSignedWrap()) {
5246 ConstantRange AddRecRange = getSignedRange(AR);
5247 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5248
5250 Instruction::Add, IncRange, OBO::NoSignedWrap);
5251 if (NSWRegion.contains(AddRecRange))
5253 }
5254
5255 if (!AR->hasNoUnsignedWrap()) {
5256 ConstantRange AddRecRange = getUnsignedRange(AR);
5257 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5258
5260 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5261 if (NUWRegion.contains(AddRecRange))
5263 }
5264
5265 return Result;
5266}
5267
5269ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5271
5272 if (AR->hasNoSignedWrap())
5273 return Result;
5274
5275 if (!AR->isAffine())
5276 return Result;
5277
5278 // This function can be expensive, only try to prove NSW once per AddRec.
5279 if (!SignedWrapViaInductionTried.insert(AR).second)
5280 return Result;
5281
5282 const SCEV *Step = AR->getStepRecurrence(*this);
5283 const Loop *L = AR->getLoop();
5284
5285 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5286 // Note that this serves two purposes: It filters out loops that are
5287 // simply not analyzable, and it covers the case where this code is
5288 // being called from within backedge-taken count analysis, such that
5289 // attempting to ask for the backedge-taken count would likely result
5290 // in infinite recursion. In the later case, the analysis code will
5291 // cope with a conservative value, and it will take care to purge
5292 // that value once it has finished.
5293 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5294
5295 // Normally, in the cases we can prove no-overflow via a
5296 // backedge guarding condition, we can also compute a backedge
5297 // taken count for the loop. The exceptions are assumptions and
5298 // guards present in the loop -- SCEV is not great at exploiting
5299 // these to compute max backedge taken counts, but can still use
5300 // these to prove lack of overflow. Use this fact to avoid
5301 // doing extra work that may not pay off.
5302
5303 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5304 AC.assumptions().empty())
5305 return Result;
5306
5307 // If the backedge is guarded by a comparison with the pre-inc value the
5308 // addrec is safe. Also, if the entry is guarded by a comparison with the
5309 // start value and the backedge is guarded by a comparison with the post-inc
5310 // value, the addrec is safe.
5312 const SCEV *OverflowLimit =
5313 getSignedOverflowLimitForStep(Step, &Pred, this);
5314 if (OverflowLimit &&
5315 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5316 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5317 Result = setFlags(Result, SCEV::FlagNSW);
5318 }
5319 return Result;
5320}
5322ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5324
5325 if (AR->hasNoUnsignedWrap())
5326 return Result;
5327
5328 if (!AR->isAffine())
5329 return Result;
5330
5331 // This function can be expensive, only try to prove NUW once per AddRec.
5332 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5333 return Result;
5334
5335 const SCEV *Step = AR->getStepRecurrence(*this);
5336 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5337 const Loop *L = AR->getLoop();
5338
5339 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5340 // Note that this serves two purposes: It filters out loops that are
5341 // simply not analyzable, and it covers the case where this code is
5342 // being called from within backedge-taken count analysis, such that
5343 // attempting to ask for the backedge-taken count would likely result
5344 // in infinite recursion. In the later case, the analysis code will
5345 // cope with a conservative value, and it will take care to purge
5346 // that value once it has finished.
5347 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5348
5349 // Normally, in the cases we can prove no-overflow via a
5350 // backedge guarding condition, we can also compute a backedge
5351 // taken count for the loop. The exceptions are assumptions and
5352 // guards present in the loop -- SCEV is not great at exploiting
5353 // these to compute max backedge taken counts, but can still use
5354 // these to prove lack of overflow. Use this fact to avoid
5355 // doing extra work that may not pay off.
5356
5357 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5358 AC.assumptions().empty())
5359 return Result;
5360
5361 // If the backedge is guarded by a comparison with the pre-inc value the
5362 // addrec is safe. Also, if the entry is guarded by a comparison with the
5363 // start value and the backedge is guarded by a comparison with the post-inc
5364 // value, the addrec is safe.
5365 if (isKnownPositive(Step)) {
5366 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5367 getUnsignedRangeMax(Step));
5370 Result = setFlags(Result, SCEV::FlagNUW);
5371 }
5372 }
5373
5374 return Result;
5375}
5376
5377namespace {
5378
5379/// Represents an abstract binary operation. This may exist as a
5380/// normal instruction or constant expression, or may have been
5381/// derived from an expression tree.
5382struct BinaryOp {
5383 unsigned Opcode;
5384 Value *LHS;
5385 Value *RHS;
5386 bool IsNSW = false;
5387 bool IsNUW = false;
5388
5389 /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
5390 /// constant expression.
5391 Operator *Op = nullptr;
5392
5393 explicit BinaryOp(Operator *Op)
5394 : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
5395 Op(Op) {
5396 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
5397 IsNSW = OBO->hasNoSignedWrap();
5398 IsNUW = OBO->hasNoUnsignedWrap();
5399 }
5400 }
5401
5402 explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
5403 bool IsNUW = false)
5404 : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
5405};
5406
5407} // end anonymous namespace
5408
5409/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
5410static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
5411 AssumptionCache &AC,
5412 const DominatorTree &DT,
5413 const Instruction *CxtI) {
5414 auto *Op = dyn_cast<Operator>(V);
5415 if (!Op)
5416 return std::nullopt;
5417
5418 // Implementation detail: all the cleverness here should happen without
5419 // creating new SCEV expressions -- our caller knowns tricks to avoid creating
5420 // SCEV expressions when possible, and we should not break that.
5421
5422 switch (Op->getOpcode()) {
5423 case Instruction::Add:
5424 case Instruction::Sub:
5425 case Instruction::Mul:
5426 case Instruction::UDiv:
5427 case Instruction::URem:
5428 case Instruction::And:
5429 case Instruction::AShr:
5430 case Instruction::Shl:
5431 return BinaryOp(Op);
5432
5433 case Instruction::Or: {
5434 // Convert or disjoint into add nuw nsw.
5435 if (cast<PossiblyDisjointInst>(Op)->isDisjoint()) {
5436 BinaryOp BinOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
5437 /*IsNSW=*/true, /*IsNUW=*/true);
5438 // Keep the reference to the original instruction so that we can later
5439 // check whether it can produce poison value or not.
5440 BinOp.Op = Op;
5441 return BinOp;
5442 }
5443 return BinaryOp(Op);
5444 }
5445
5446 case Instruction::Xor:
5447 if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
5448 // If the RHS of the xor is a signmask, then this is just an add.
5449 // Instcombine turns add of signmask into xor as a strength reduction step.
5450 if (RHSC->getValue().isSignMask())
5451 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5452 // Binary `xor` is a bit-wise `add`.
5453 if (V->getType()->isIntegerTy(1))
5454 return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
5455 return BinaryOp(Op);
5456
5457 case Instruction::LShr:
5458 // Turn logical shift right of a constant into a unsigned divide.
5459 if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
5460 uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
5461
5462 // If the shift count is not less than the bitwidth, the result of
5463 // the shift is undefined. Don't try to analyze it, because the
5464 // resolution chosen here may differ from the resolution chosen in
5465 // other parts of the compiler.
5466 if (SA->getValue().ult(BitWidth)) {
5467 Constant *X =
5468 ConstantInt::get(SA->getContext(),
5469 APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
5470 return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
5471 }
5472 }
5473 return BinaryOp(Op);
5474
5475 case Instruction::ExtractValue: {
5476 auto *EVI = cast<ExtractValueInst>(Op);
5477 if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
5478 break;
5479
5480 auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
5481 if (!WO)
5482 break;
5483
5484 Instruction::BinaryOps BinOp = WO->getBinaryOp();
5485 bool Signed = WO->isSigned();
5486 // TODO: Should add nuw/nsw flags for mul as well.
5487 if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
5488 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
5489
5490 // Now that we know that all uses of the arithmetic-result component of
5491 // WO are guarded by the overflow check, we can go ahead and pretend
5492 // that the arithmetic is non-overflowing.
5493 return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
5494 /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
5495 }
5496
5497 default:
5498 break;
5499 }
5500
5501 // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
5502 // semantics as a Sub, return a binary sub expression.
5503 if (auto *II = dyn_cast<IntrinsicInst>(V))
5504 if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
5505 return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
5506
5507 return std::nullopt;
5508}
5509
5510/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5511/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5512/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5513/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5514/// follows one of the following patterns:
5515/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5516/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5517/// If the SCEV expression of \p Op conforms with one of the expected patterns
5518/// we return the type of the truncation operation, and indicate whether the
5519/// truncated type should be treated as signed/unsigned by setting
5520/// \p Signed to true/false, respectively.
5521static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5522 bool &Signed, ScalarEvolution &SE) {
5523 // The case where Op == SymbolicPHI (that is, with no type conversions on
5524 // the way) is handled by the regular add recurrence creating logic and
5525 // would have already been triggered in createAddRecForPHI. Reaching it here
5526 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5527 // because one of the other operands of the SCEVAddExpr updating this PHI is
5528 // not invariant).
5529 //
5530 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5531 // this case predicates that allow us to prove that Op == SymbolicPHI will
5532 // be added.
5533 if (Op == SymbolicPHI)
5534 return nullptr;
5535
5536 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5537 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5538 if (SourceBits != NewBits)
5539 return nullptr;
5540
5541 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5542 Signed = true;
5543 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5544 }
5545 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5546 Signed = false;
5547 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5548 }
5549 return nullptr;
5550}
5551
5552static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5553 if (!PN->getType()->isIntegerTy())
5554 return nullptr;
5555 const Loop *L = LI.getLoopFor(PN->getParent());
5556 if (!L || L->getHeader() != PN->getParent())
5557 return nullptr;
5558 return L;
5559}
5560
5561// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5562// computation that updates the phi follows the following pattern:
5563// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5564// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5565// If so, try to see if it can be rewritten as an AddRecExpr under some
5566// Predicates. If successful, return them as a pair. Also cache the results
5567// of the analysis.
5568//
5569// Example usage scenario:
5570// Say the Rewriter is called for the following SCEV:
5571// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5572// where:
5573// %X = phi i64 (%Start, %BEValue)
5574// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5575// and call this function with %SymbolicPHI = %X.
5576//
5577// The analysis will find that the value coming around the backedge has
5578// the following SCEV:
5579// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5580// Upon concluding that this matches the desired pattern, the function
5581// will return the pair {NewAddRec, SmallPredsVec} where:
5582// NewAddRec = {%Start,+,%Step}
5583// SmallPredsVec = {P1, P2, P3} as follows:
5584// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5585// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5586// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5587// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5588// under the predicates {P1,P2,P3}.
5589// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5590// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5591//
5592// TODO's:
5593//
5594// 1) Extend the Induction descriptor to also support inductions that involve
5595// casts: When needed (namely, when we are called in the context of the
5596// vectorizer induction analysis), a Set of cast instructions will be
5597// populated by this method, and provided back to isInductionPHI. This is
5598// needed to allow the vectorizer to properly record them to be ignored by
5599// the cost model and to avoid vectorizing them (otherwise these casts,
5600// which are redundant under the runtime overflow checks, will be
5601// vectorized, which can be costly).
5602//
5603// 2) Support additional induction/PHISCEV patterns: We also want to support
5604// inductions where the sext-trunc / zext-trunc operations (partly) occur
5605// after the induction update operation (the induction increment):
5606//
5607// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5608// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5609//
5610// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5611// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5612//
5613// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5614std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5615ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5617
5618 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5619 // return an AddRec expression under some predicate.
5620
5621 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5622 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5623 assert(L && "Expecting an integer loop header phi");
5624
5625 // The loop may have multiple entrances or multiple exits; we can analyze
5626 // this phi as an addrec if it has a unique entry value and a unique
5627 // backedge value.
5628 Value *BEValueV = nullptr, *StartValueV = nullptr;
5629 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5630 Value *V = PN->getIncomingValue(i);
5631 if (L->contains(PN->getIncomingBlock(i))) {
5632 if (!BEValueV) {
5633 BEValueV = V;
5634 } else if (BEValueV != V) {
5635 BEValueV = nullptr;
5636 break;
5637 }
5638 } else if (!StartValueV) {
5639 StartValueV = V;
5640 } else if (StartValueV != V) {
5641 StartValueV = nullptr;
5642 break;
5643 }
5644 }
5645 if (!BEValueV || !StartValueV)
5646 return std::nullopt;
5647
5648 const SCEV *BEValue = getSCEV(BEValueV);
5649
5650 // If the value coming around the backedge is an add with the symbolic
5651 // value we just inserted, possibly with casts that we can ignore under
5652 // an appropriate runtime guard, then we found a simple induction variable!
5653 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5654 if (!Add)
5655 return std::nullopt;
5656
5657 // If there is a single occurrence of the symbolic value, possibly
5658 // casted, replace it with a recurrence.
5659 unsigned FoundIndex = Add->getNumOperands();
5660 Type *TruncTy = nullptr;
5661 bool Signed;
5662 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5663 if ((TruncTy =
5664 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5665 if (FoundIndex == e) {
5666 FoundIndex = i;
5667 break;
5668 }
5669
5670 if (FoundIndex == Add->getNumOperands())
5671 return std::nullopt;
5672
5673 // Create an add with everything but the specified operand.
5675 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5676 if (i != FoundIndex)
5677 Ops.push_back(Add->getOperand(i));
5678 const SCEV *Accum = getAddExpr(Ops);
5679
5680 // The runtime checks will not be valid if the step amount is
5681 // varying inside the loop.
5682 if (!isLoopInvariant(Accum, L))
5683 return std::nullopt;
5684
5685 // *** Part2: Create the predicates
5686
5687 // Analysis was successful: we have a phi-with-cast pattern for which we
5688 // can return an AddRec expression under the following predicates:
5689 //
5690 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5691 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5692 // P2: An Equal predicate that guarantees that
5693 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5694 // P3: An Equal predicate that guarantees that
5695 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5696 //
5697 // As we next prove, the above predicates guarantee that:
5698 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5699 //
5700 //
5701 // More formally, we want to prove that:
5702 // Expr(i+1) = Start + (i+1) * Accum
5703 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5704 //
5705 // Given that:
5706 // 1) Expr(0) = Start
5707 // 2) Expr(1) = Start + Accum
5708 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5709 // 3) Induction hypothesis (step i):
5710 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5711 //
5712 // Proof:
5713 // Expr(i+1) =
5714 // = Start + (i+1)*Accum
5715 // = (Start + i*Accum) + Accum
5716 // = Expr(i) + Accum
5717 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5718 // :: from step i
5719 //
5720 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5721 //
5722 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5723 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5724 // + Accum :: from P3
5725 //
5726 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5727 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5728 //
5729 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5730 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5731 //
5732 // By induction, the same applies to all iterations 1<=i<n:
5733 //
5734
5735 // Create a truncated addrec for which we will add a no overflow check (P1).
5736 const SCEV *StartVal = getSCEV(StartValueV);
5737 const SCEV *PHISCEV =
5738 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5739 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5740
5741 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5742 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5743 // will be constant.
5744 //
5745 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5746 // add P1.
5747 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5751 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5752 Predicates.push_back(AddRecPred);
5753 }
5754
5755 // Create the Equal Predicates P2,P3:
5756
5757 // It is possible that the predicates P2 and/or P3 are computable at
5758 // compile time due to StartVal and/or Accum being constants.
5759 // If either one is, then we can check that now and escape if either P2
5760 // or P3 is false.
5761
5762 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5763 // for each of StartVal and Accum
5764 auto getExtendedExpr = [&](const SCEV *Expr,
5765 bool CreateSignExtend) -> const SCEV * {
5766 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5767 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5768 const SCEV *ExtendedExpr =
5769 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5770 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5771 return ExtendedExpr;
5772 };
5773
5774 // Given:
5775 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy)
5776 // = getExtendedExpr(Expr)
5777 // Determine whether the predicate P: Expr == ExtendedExpr
5778 // is known to be false at compile time
5779 auto PredIsKnownFalse = [&](const SCEV *Expr,
5780 const SCEV *ExtendedExpr) -> bool {
5781 return Expr != ExtendedExpr &&
5782 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5783 };
5784
5785 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5786 if (PredIsKnownFalse(StartVal, StartExtended)) {
5787 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5788 return std::nullopt;
5789 }
5790
5791 // The Step is always Signed (because the overflow checks are either
5792 // NSSW or NUSW)
5793 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5794 if (PredIsKnownFalse(Accum, AccumExtended)) {
5795 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5796 return std::nullopt;
5797 }
5798
5799 auto AppendPredicate = [&](const SCEV *Expr,
5800 const SCEV *ExtendedExpr) -> void {
5801 if (Expr != ExtendedExpr &&
5802 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5803 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5804 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5805 Predicates.push_back(Pred);
5806 }
5807 };
5808
5809 AppendPredicate(StartVal, StartExtended);
5810 AppendPredicate(Accum, AccumExtended);
5811
5812 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5813 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5814 // into NewAR if it will also add the runtime overflow checks specified in
5815 // Predicates.
5816 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5817
5818 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5819 std::make_pair(NewAR, Predicates);
5820 // Remember the result of the analysis for this SCEV at this location.
5821 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5822 return PredRewrite;
5823}
5824
5825std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5827 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5828 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5829 if (!L)
5830 return std::nullopt;
5831
5832 // Check to see if we already analyzed this PHI.
5833 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5834 if (I != PredicatedSCEVRewrites.end()) {
5835 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5836 I->second;
5837 // Analysis was done before and failed to create an AddRec:
5838 if (Rewrite.first == SymbolicPHI)
5839 return std::nullopt;
5840 // Analysis was done before and succeeded to create an AddRec under
5841 // a predicate:
5842 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5843 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5844 return Rewrite;
5845 }
5846
5847 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5848 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5849
5850 // Record in the cache that the analysis failed
5851 if (!Rewrite) {
5853 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5854 return std::nullopt;
5855 }
5856
5857 return Rewrite;
5858}
5859
5860// FIXME: This utility is currently required because the Rewriter currently
5861// does not rewrite this expression:
5862// {0, +, (sext ix (trunc iy %step to ix) to iy)}
5863// into {0, +, %step},
5864// even when the following Equal predicate exists:
5865// "%step == (sext ix (trunc iy %step to ix) to iy)".
5867 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5868 if (AR1 == AR2)
5869 return true;
5870
5871 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5872 if (Expr1 != Expr2 &&
5873 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5874 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5875 return false;
5876 return true;
5877 };
5878
5879 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5880 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5881 return false;
5882 return true;
5883}
5884
5885/// A helper function for createAddRecFromPHI to handle simple cases.
5886///
5887/// This function tries to find an AddRec expression for the simplest (yet most
5888/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5889/// If it fails, createAddRecFromPHI will use a more general, but slow,
5890/// technique for finding the AddRec expression.
5891const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5892 Value *BEValueV,
5893 Value *StartValueV) {
5894 const Loop *L = LI.getLoopFor(PN->getParent());
5895 assert(L && L->getHeader() == PN->getParent());
5896 assert(BEValueV && StartValueV);
5897
5898 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5899 if (!BO)
5900 return nullptr;
5901
5902 if (BO->Opcode != Instruction::Add)
5903 return nullptr;
5904
5905 const SCEV *Accum = nullptr;
5906 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5907 Accum = getSCEV(BO->RHS);
5908 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5909 Accum = getSCEV(BO->LHS);
5910
5911 if (!Accum)
5912 return nullptr;
5913
5915 if (BO->IsNUW)
5916 Flags = setFlags(Flags, SCEV::FlagNUW);
5917 if (BO->IsNSW)
5918 Flags = setFlags(Flags, SCEV::FlagNSW);
5919
5920 const SCEV *StartVal = getSCEV(StartValueV);
5921 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5922 insertValueToMap(PN, PHISCEV);
5923
5924 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5925 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5926 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
5927 }
5928
5929 // We can add Flags to the post-inc expression only if we
5930 // know that it is *undefined behavior* for BEValueV to
5931 // overflow.
5932 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5933 assert(isLoopInvariant(Accum, L) &&
5934 "Accum is defined outside L, but is not invariant?");
5935 if (isAddRecNeverPoison(BEInst, L))
5936 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5937 }
5938
5939 return PHISCEV;
5940}
5941
5942const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5943 const Loop *L = LI.getLoopFor(PN->getParent());
5944 if (!L || L->getHeader() != PN->getParent())
5945 return nullptr;
5946
5947 // The loop may have multiple entrances or multiple exits; we can analyze
5948 // this phi as an addrec if it has a unique entry value and a unique
5949 // backedge value.
5950 Value *BEValueV = nullptr, *StartValueV = nullptr;
5951 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5952 Value *V = PN->getIncomingValue(i);
5953 if (L->contains(PN->getIncomingBlock(i))) {
5954 if (!BEValueV) {
5955 BEValueV = V;
5956 } else if (BEValueV != V) {
5957 BEValueV = nullptr;
5958 break;
5959 }
5960 } else if (!StartValueV) {
5961 StartValueV = V;
5962 } else if (StartValueV != V) {
5963 StartValueV = nullptr;
5964 break;
5965 }
5966 }
5967 if (!BEValueV || !StartValueV)
5968 return nullptr;
5969
5970 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5971 "PHI node already processed?");
5972
5973 // First, try to find AddRec expression without creating a fictitious symbolic
5974 // value for PN.
5975 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5976 return S;
5977
5978 // Handle PHI node value symbolically.
5979 const SCEV *SymbolicName = getUnknown(PN);
5980 insertValueToMap(PN, SymbolicName);
5981
5982 // Using this symbolic name for the PHI, analyze the value coming around
5983 // the back-edge.
5984 const SCEV *BEValue = getSCEV(BEValueV);
5985
5986 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5987 // has a special value for the first iteration of the loop.
5988
5989 // If the value coming around the backedge is an add with the symbolic
5990 // value we just inserted, then we found a simple induction variable!
5991 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5992 // If there is a single occurrence of the symbolic value, replace it
5993 // with a recurrence.
5994 unsigned FoundIndex = Add->getNumOperands();
5995 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5996 if (Add->getOperand(i) == SymbolicName)
5997 if (FoundIndex == e) {
5998 FoundIndex = i;
5999 break;
6000 }
6001
6002 if (FoundIndex != Add->getNumOperands()) {
6003 // Create an add with everything but the specified operand.
6005 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6006 if (i != FoundIndex)
6007 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6008 L, *this));
6009 const SCEV *Accum = getAddExpr(Ops);
6010
6011 // This is not a valid addrec if the step amount is varying each
6012 // loop iteration, but is not itself an addrec in this loop.
6013 if (isLoopInvariant(Accum, L) ||
6014 (isa<SCEVAddRecExpr>(Accum) &&
6015 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6017
6018 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6019 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6020 if (BO->IsNUW)
6021 Flags = setFlags(Flags, SCEV::FlagNUW);
6022 if (BO->IsNSW)
6023 Flags = setFlags(Flags, SCEV::FlagNSW);
6024 }
6025 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6026 if (GEP->getOperand(0) == PN) {
6027 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6028 // If the increment has any nowrap flags, then we know the address
6029 // space cannot be wrapped around.
6030 if (NW != GEPNoWrapFlags::none())
6031 Flags = setFlags(Flags, SCEV::FlagNW);
6032 // If the GEP is nuw or nusw with non-negative offset, we know that
6033 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6034 // offset is treated as signed, while the base is unsigned.
6035 if (NW.hasNoUnsignedWrap() ||
6037 Flags = setFlags(Flags, SCEV::FlagNUW);
6038 }
6039
6040 // We cannot transfer nuw and nsw flags from subtraction
6041 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6042 // for instance.
6043 }
6044
6045 const SCEV *StartVal = getSCEV(StartValueV);
6046 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6047
6048 // Okay, for the entire analysis of this edge we assumed the PHI
6049 // to be symbolic. We now need to go back and purge all of the
6050 // entries for the scalars that use the symbolic expression.
6051 forgetMemoizedResults({SymbolicName});
6052 insertValueToMap(PN, PHISCEV);
6053
6054 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
6056 const_cast<SCEVAddRecExpr *>(AR),
6057 (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
6058 }
6059
6060 // We can add Flags to the post-inc expression only if we
6061 // know that it is *undefined behavior* for BEValueV to
6062 // overflow.
6063 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6064 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6065 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6066
6067 return PHISCEV;
6068 }
6069 }
6070 } else {
6071 // Otherwise, this could be a loop like this:
6072 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6073 // In this case, j = {1,+,1} and BEValue is j.
6074 // Because the other in-value of i (0) fits the evolution of BEValue,
6075 // i really is an addrec evolution.
6076 //
6077 // We can generalize this saying that i is the shifted value of BEValue
6078 // by one iteration:
6079 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6080
6081 // Do not allow refinement in rewriting of BEValue.
6082 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6083 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6084 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6085 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6086 const SCEV *StartVal = getSCEV(StartValueV);
6087 if (Start == StartVal) {
6088 // Okay, for the entire analysis of this edge we assumed the PHI
6089 // to be symbolic. We now need to go back and purge all of the
6090 // entries for the scalars that use the symbolic expression.
6091 forgetMemoizedResults({SymbolicName});
6092 insertValueToMap(PN, Shifted);
6093 return Shifted;
6094 }
6095 }
6096 }
6097
6098 // Remove the temporary PHI node SCEV that has been inserted while intending
6099 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6100 // as it will prevent later (possibly simpler) SCEV expressions to be added
6101 // to the ValueExprMap.
6102 eraseValueFromMap(PN);
6103
6104 return nullptr;
6105}
6106
6107// Try to match a control flow sequence that branches out at BI and merges back
6108// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6109// match.
6111 Value *&C, Value *&LHS, Value *&RHS) {
6112 C = BI->getCondition();
6113
6114 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6115 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6116
6117 Use &LeftUse = Merge->getOperandUse(0);
6118 Use &RightUse = Merge->getOperandUse(1);
6119
6120 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6121 LHS = LeftUse;
6122 RHS = RightUse;
6123 return true;
6124 }
6125
6126 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6127 LHS = RightUse;
6128 RHS = LeftUse;
6129 return true;
6130 }
6131
6132 return false;
6133}
6134
6136 Value *&Cond, Value *&LHS,
6137 Value *&RHS) {
6138 auto IsReachable =
6139 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6140 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6141 // Try to match
6142 //
6143 // br %cond, label %left, label %right
6144 // left:
6145 // br label %merge
6146 // right:
6147 // br label %merge
6148 // merge:
6149 // V = phi [ %x, %left ], [ %y, %right ]
6150 //
6151 // as "select %cond, %x, %y"
6152
6153 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6154 assert(IDom && "At least the entry block should dominate PN");
6155
6156 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6157 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6158 }
6159 return false;
6160}
6161
6162const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6163 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6164 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6167 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6168
6169 return nullptr;
6170}
6171
6173 BinaryOperator *CommonInst = nullptr;
6174 // Check if instructions are identical.
6175 for (Value *Incoming : PN->incoming_values()) {
6176 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6177 if (!IncomingInst)
6178 return nullptr;
6179 if (CommonInst) {
6180 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6181 return nullptr; // Not identical, give up
6182 } else {
6183 // Remember binary operator
6184 CommonInst = IncomingInst;
6185 }
6186 }
6187 return CommonInst;
6188}
6189
6190/// Returns SCEV for the first operand of a phi if all phi operands have
6191/// identical opcodes and operands
6192/// eg.
6193/// a: %add = %a + %b
6194/// br %c
6195/// b: %add1 = %a + %b
6196/// br %c
6197/// c: %phi = phi [%add, a], [%add1, b]
6198/// scev(%phi) => scev(%add)
6199const SCEV *
6200ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6201 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6202 if (!CommonInst)
6203 return nullptr;
6204
6205 // Check if SCEV exprs for instructions are identical.
6206 const SCEV *CommonSCEV = getSCEV(CommonInst);
6207 bool SCEVExprsIdentical =
6209 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6210 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6211}
6212
6213const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6214 if (const SCEV *S = createAddRecFromPHI(PN))
6215 return S;
6216
6217 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6218 // phi node for X.
6219 if (Value *V = simplifyInstruction(
6220 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6221 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6222 return getSCEV(V);
6223
6224 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6225 return S;
6226
6227 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6228 return S;
6229
6230 // If it's not a loop phi, we can't handle it yet.
6231 return getUnknown(PN);
6232}
6233
6234bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6235 SCEVTypes RootKind) {
6236 struct FindClosure {
6237 const SCEV *OperandToFind;
6238 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6239 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6240
6241 bool Found = false;
6242
6243 bool canRecurseInto(SCEVTypes Kind) const {
6244 // We can only recurse into the SCEV expression of the same effective type
6245 // as the type of our root SCEV expression, and into zero-extensions.
6246 return RootKind == Kind || NonSequentialRootKind == Kind ||
6247 scZeroExtend == Kind;
6248 };
6249
6250 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6251 : OperandToFind(OperandToFind), RootKind(RootKind),
6252 NonSequentialRootKind(
6254 RootKind)) {}
6255
6256 bool follow(const SCEV *S) {
6257 Found = S == OperandToFind;
6258
6259 return !isDone() && canRecurseInto(S->getSCEVType());
6260 }
6261
6262 bool isDone() const { return Found; }
6263 };
6264
6265 FindClosure FC(OperandToFind, RootKind);
6266 visitAll(Root, FC);
6267 return FC.Found;
6268}
6269
6270std::optional<const SCEV *>
6271ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6272 ICmpInst *Cond,
6273 Value *TrueVal,
6274 Value *FalseVal) {
6275 // Try to match some simple smax or umax patterns.
6276 auto *ICI = Cond;
6277
6278 Value *LHS = ICI->getOperand(0);
6279 Value *RHS = ICI->getOperand(1);
6280
6281 switch (ICI->getPredicate()) {
6282 case ICmpInst::ICMP_SLT:
6283 case ICmpInst::ICMP_SLE:
6284 case ICmpInst::ICMP_ULT:
6285 case ICmpInst::ICMP_ULE:
6286 std::swap(LHS, RHS);
6287 [[fallthrough]];
6288 case ICmpInst::ICMP_SGT:
6289 case ICmpInst::ICMP_SGE:
6290 case ICmpInst::ICMP_UGT:
6291 case ICmpInst::ICMP_UGE:
6292 // a > b ? a+x : b+x -> max(a, b)+x
6293 // a > b ? b+x : a+x -> min(a, b)+x
6295 bool Signed = ICI->isSigned();
6296 const SCEV *LA = getSCEV(TrueVal);
6297 const SCEV *RA = getSCEV(FalseVal);
6298 const SCEV *LS = getSCEV(LHS);
6299 const SCEV *RS = getSCEV(RHS);
6300 if (LA->getType()->isPointerTy()) {
6301 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6302 // Need to make sure we can't produce weird expressions involving
6303 // negated pointers.
6304 if (LA == LS && RA == RS)
6305 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6306 if (LA == RS && RA == LS)
6307 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6308 }
6309 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6310 if (Op->getType()->isPointerTy()) {
6313 return Op;
6314 }
6315 if (Signed)
6316 Op = getNoopOrSignExtend(Op, Ty);
6317 else
6318 Op = getNoopOrZeroExtend(Op, Ty);
6319 return Op;
6320 };
6321 LS = CoerceOperand(LS);
6322 RS = CoerceOperand(RS);
6324 break;
6325 const SCEV *LDiff = getMinusSCEV(LA, LS);
6326 const SCEV *RDiff = getMinusSCEV(RA, RS);
6327 if (LDiff == RDiff)
6328 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6329 LDiff);
6330 LDiff = getMinusSCEV(LA, RS);
6331 RDiff = getMinusSCEV(RA, LS);
6332 if (LDiff == RDiff)
6333 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6334 LDiff);
6335 }
6336 break;
6337 case ICmpInst::ICMP_NE:
6338 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6339 std::swap(TrueVal, FalseVal);
6340 [[fallthrough]];
6341 case ICmpInst::ICMP_EQ:
6342 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6345 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6346 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6347 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6348 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6349 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6350 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6351 return getAddExpr(getUMaxExpr(X, C), Y);
6352 }
6353 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6354 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6355 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6356 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6358 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6359 const SCEV *X = getSCEV(LHS);
6360 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6361 X = ZExt->getOperand();
6362 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6363 const SCEV *FalseValExpr = getSCEV(FalseVal);
6364 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6365 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6366 /*Sequential=*/true);
6367 }
6368 }
6369 break;
6370 default:
6371 break;
6372 }
6373
6374 return std::nullopt;
6375}
6376
6377static std::optional<const SCEV *>
6379 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6380 assert(CondExpr->getType()->isIntegerTy(1) &&
6381 TrueExpr->getType() == FalseExpr->getType() &&
6382 TrueExpr->getType()->isIntegerTy(1) &&
6383 "Unexpected operands of a select.");
6384
6385 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6386 // --> C + (umin_seq cond, x - C)
6387 //
6388 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6389 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6390 // --> C + (umin_seq ~cond, x - C)
6391
6392 // FIXME: while we can't legally model the case where both of the hands
6393 // are fully variable, we only require that the *difference* is constant.
6394 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6395 return std::nullopt;
6396
6397 const SCEV *X, *C;
6398 if (isa<SCEVConstant>(TrueExpr)) {
6399 CondExpr = SE->getNotSCEV(CondExpr);
6400 X = FalseExpr;
6401 C = TrueExpr;
6402 } else {
6403 X = TrueExpr;
6404 C = FalseExpr;
6405 }
6406 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6407 /*Sequential=*/true));
6408}
6409
6410static std::optional<const SCEV *>
6412 Value *FalseVal) {
6413 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6414 return std::nullopt;
6415
6416 const auto *SECond = SE->getSCEV(Cond);
6417 const auto *SETrue = SE->getSCEV(TrueVal);
6418 const auto *SEFalse = SE->getSCEV(FalseVal);
6419 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6420}
6421
6422const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6423 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6424 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6425 assert(TrueVal->getType() == FalseVal->getType() &&
6426 V->getType() == TrueVal->getType() &&
6427 "Types of select hands and of the result must match.");
6428
6429 // For now, only deal with i1-typed `select`s.
6430 if (!V->getType()->isIntegerTy(1))
6431 return getUnknown(V);
6432
6433 if (std::optional<const SCEV *> S =
6434 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6435 return *S;
6436
6437 return getUnknown(V);
6438}
6439
6440const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6441 Value *TrueVal,
6442 Value *FalseVal) {
6443 // Handle "constant" branch or select. This can occur for instance when a
6444 // loop pass transforms an inner loop and moves on to process the outer loop.
6445 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6446 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6447
6448 if (auto *I = dyn_cast<Instruction>(V)) {
6449 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6450 if (std::optional<const SCEV *> S =
6451 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6452 TrueVal, FalseVal))
6453 return *S;
6454 }
6455 }
6456
6457 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6458}
6459
6460/// Expand GEP instructions into add and multiply operations. This allows them
6461/// to be analyzed by regular SCEV code.
6462const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6463 assert(GEP->getSourceElementType()->isSized() &&
6464 "GEP source element type must be sized");
6465
6466 SmallVector<SCEVUse, 4> IndexExprs;
6467 for (Value *Index : GEP->indices())
6468 IndexExprs.push_back(getSCEV(Index));
6469 return getGEPExpr(GEP, IndexExprs);
6470}
6471
6472APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6473 const Instruction *CtxI) {
6474 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6475 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6476 return TrailingZeros >= BitWidth
6478 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6479 };
6480 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6481 // The result is GCD of all operands results.
6482 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6483 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6485 Res, getConstantMultiple(N->getOperand(I), CtxI));
6486 return Res;
6487 };
6488
6489 switch (S->getSCEVType()) {
6490 case scConstant:
6491 return cast<SCEVConstant>(S)->getAPInt();
6492 case scPtrToAddr:
6493 case scPtrToInt:
6494 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6495 case scUDivExpr:
6496 case scVScale:
6497 return APInt(BitWidth, 1);
6498 case scTruncate: {
6499 // Only multiples that are a power of 2 will hold after truncation.
6500 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6501 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6502 return GetShiftedByZeros(TZ);
6503 }
6504 case scZeroExtend: {
6505 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6506 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6507 }
6508 case scSignExtend: {
6509 // Only multiples that are a power of 2 will hold after sext.
6510 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6511 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6512 return GetShiftedByZeros(TZ);
6513 }
6514 case scMulExpr: {
6515 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6516 if (M->hasNoUnsignedWrap()) {
6517 // The result is the product of all operand results.
6518 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6519 for (const SCEV *Operand : M->operands().drop_front())
6520 Res = Res * getConstantMultiple(Operand, CtxI);
6521 return Res;
6522 }
6523
6524 // If there are no wrap guarentees, find the trailing zeros, which is the
6525 // sum of trailing zeros for all its operands.
6526 uint32_t TZ = 0;
6527 for (const SCEV *Operand : M->operands())
6528 TZ += getMinTrailingZeros(Operand, CtxI);
6529 return GetShiftedByZeros(TZ);
6530 }
6531 case scAddExpr:
6532 case scAddRecExpr: {
6533 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6534 if (N->hasNoUnsignedWrap())
6535 return GetGCDMultiple(N);
6536 // Find the trailing bits, which is the minimum of its operands.
6537 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6538 for (const SCEV *Operand : N->operands().drop_front())
6539 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6540 return GetShiftedByZeros(TZ);
6541 }
6542 case scUMaxExpr:
6543 case scSMaxExpr:
6544 case scUMinExpr:
6545 case scSMinExpr:
6547 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6548 case scUnknown: {
6549 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6550 // the point their underlying IR instruction has been defined. If CtxI was
6551 // not provided, use:
6552 // * the first instruction in the entry block if it is an argument
6553 // * the instruction itself otherwise.
6554 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6555 if (!CtxI) {
6556 if (isa<Argument>(U->getValue()))
6557 CtxI = &*F.getEntryBlock().begin();
6558 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6559 CtxI = I;
6560 }
6561 unsigned Known =
6562 computeKnownBits(U->getValue(),
6563 SimplifyQuery(getDataLayout(), &DT, &AC, CtxI)
6564 .allowEphemerals(true))
6565 .countMinTrailingZeros();
6566 return GetShiftedByZeros(Known);
6567 }
6568 case scCouldNotCompute:
6569 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6570 }
6571 llvm_unreachable("Unknown SCEV kind!");
6572}
6573
6575 const Instruction *CtxI) {
6576 // Skip looking up and updating the cache if there is a context instruction,
6577 // as the result will only be valid in the specified context.
6578 if (CtxI)
6579 return getConstantMultipleImpl(S, CtxI);
6580
6581 auto I = ConstantMultipleCache.find(S);
6582 if (I != ConstantMultipleCache.end())
6583 return I->second;
6584
6585 APInt Result = getConstantMultipleImpl(S, CtxI);
6586 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6587 assert(InsertPair.second && "Should insert a new key");
6588 return InsertPair.first->second;
6589}
6590
6592 APInt Multiple = getConstantMultiple(S);
6593 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6594}
6595
6597 const Instruction *CtxI) {
6598 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6599 (unsigned)getTypeSizeInBits(S->getType()));
6600}
6601
6602/// Helper method to assign a range to V from metadata present in the IR.
6603static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6605 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6606 return getConstantRangeFromMetadata(*MD);
6607 if (const auto *CB = dyn_cast<CallBase>(V))
6608 if (std::optional<ConstantRange> Range = CB->getRange())
6609 return Range;
6610 }
6611 if (auto *A = dyn_cast<Argument>(V))
6612 if (std::optional<ConstantRange> Range = A->getRange())
6613 return Range;
6614
6615 return std::nullopt;
6616}
6617
6619 SCEV::NoWrapFlags Flags) {
6620 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6621 AddRec->setNoWrapFlags(Flags);
6622 UnsignedRanges.erase(AddRec);
6623 SignedRanges.erase(AddRec);
6624 ConstantMultipleCache.erase(AddRec);
6625 }
6626}
6627
6628ConstantRange ScalarEvolution::
6629getRangeForUnknownRecurrence(const SCEVUnknown *U) {
6630 const DataLayout &DL = getDataLayout();
6631
6632 unsigned BitWidth = getTypeSizeInBits(U->getType());
6633 const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);
6634
6635 // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
6636 // use information about the trip count to improve our available range. Note
6637 // that the trip count independent cases are already handled by known bits.
6638 // WARNING: The definition of recurrence used here is subtly different than
6639 // the one used by AddRec (and thus most of this file). Step is allowed to
6640 // be arbitrarily loop varying here, where AddRec allows only loop invariant
6641 // and other addrecs in the same loop (for non-affine addrecs). The code
6642 // below intentionally handles the case where step is not loop invariant.
6643 auto *P = dyn_cast<PHINode>(U->getValue());
6644 if (!P)
6645 return FullSet;
6646
6647 // Make sure that no Phi input comes from an unreachable block. Otherwise,
6648 // even the values that are not available in these blocks may come from them,
6649 // and this leads to false-positive recurrence test.
6650 for (auto *Pred : predecessors(P->getParent()))
6651 if (!DT.isReachableFromEntry(Pred))
6652 return FullSet;
6653
6654 BinaryOperator *BO;
6655 Value *Start, *Step;
6656 if (!matchSimpleRecurrence(P, BO, Start, Step))
6657 return FullSet;
6658
6659 // If we found a recurrence in reachable code, we must be in a loop. Note
6660 // that BO might be in some subloop of L, and that's completely okay.
6661 auto *L = LI.getLoopFor(P->getParent());
6662 assert(L && L->getHeader() == P->getParent());
6663 if (!L->contains(BO->getParent()))
6664 // NOTE: This bailout should be an assert instead. However, asserting
6665 // the condition here exposes a case where LoopFusion is querying SCEV
6666 // with malformed loop information during the midst of the transform.
6667 // There doesn't appear to be an obvious fix, so for the moment bailout
6668 // until the caller issue can be fixed. PR49566 tracks the bug.
6669 return FullSet;
6670
6671 // TODO: Extend to other opcodes such as mul, and div
6672 switch (BO->getOpcode()) {
6673 default:
6674 return FullSet;
6675 case Instruction::AShr:
6676 case Instruction::LShr:
6677 case Instruction::Shl:
6678 break;
6679 };
6680
6681 if (BO->getOperand(0) != P)
6682 // TODO: Handle the power function forms some day.
6683 return FullSet;
6684
6685 unsigned TC = getSmallConstantMaxTripCount(L);
6686 if (!TC || TC >= BitWidth)
6687 return FullSet;
6688
6689 auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
6690 auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
6691 assert(KnownStart.getBitWidth() == BitWidth &&
6692 KnownStep.getBitWidth() == BitWidth);
6693
6694 // Compute total shift amount, being careful of overflow and bitwidths.
6695 auto MaxShiftAmt = KnownStep.getMaxValue();
6696 APInt TCAP(BitWidth, TC-1);
6697 bool Overflow = false;
6698 auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
6699 if (Overflow)
6700 return FullSet;
6701
6702 switch (BO->getOpcode()) {
6703 default:
6704 llvm_unreachable("filtered out above");
6705 case Instruction::AShr: {
6706 // For each ashr, three cases:
6707 // shift = 0 => unchanged value
6708 // saturation => 0 or -1
6709 // other => a value closer to zero (of the same sign)
6710 // Thus, the end value is closer to zero than the start.
6711 auto KnownEnd = KnownBits::ashr(KnownStart,
6712 KnownBits::makeConstant(TotalShift));
6713 if (KnownStart.isNonNegative())
6714 // Analogous to lshr (simply not yet canonicalized)
6715 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6716 KnownStart.getMaxValue() + 1);
6717 if (KnownStart.isNegative())
6718 // End >=u Start && End <=s Start
6719 return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
6720 KnownEnd.getMaxValue() + 1);
6721 break;
6722 }
6723 case Instruction::LShr: {
6724 // For each lshr, three cases:
6725 // shift = 0 => unchanged value
6726 // saturation => 0
6727 // other => a smaller positive number
6728 // Thus, the low end of the unsigned range is the last value produced.
6729 auto KnownEnd = KnownBits::lshr(KnownStart,
6730 KnownBits::makeConstant(TotalShift));
6731 return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
6732 KnownStart.getMaxValue() + 1);
6733 }
6734 case Instruction::Shl: {
6735 // Iff no bits are shifted out, value increases on every shift.
6736 auto KnownEnd = KnownBits::shl(KnownStart,
6737 KnownBits::makeConstant(TotalShift));
6738 if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
6739 return ConstantRange(KnownStart.getMinValue(),
6740 KnownEnd.getMaxValue() + 1);
6741 break;
6742 }
6743 };
6744 return FullSet;
6745}
6746
6747// The goal of this function is to check if recursively visiting the operands
6748// of this PHI might lead to an infinite loop. If we do see such a loop,
6749// there's no good way to break it, so we avoid analyzing such cases.
6750//
6751// getRangeRef previously used a visited set to avoid infinite loops, but this
6752// caused other issues: the result was dependent on the order of getRangeRef
6753// calls, and the interaction with createSCEVIter could cause a stack overflow
6754// in some cases (see issue #148253).
6755//
6756// FIXME: The way this is implemented is overly conservative; this checks
6757// for a few obviously safe patterns, but anything that doesn't lead to
6758// recursion is fine.
6760 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6762 return true;
6763
6764 if (all_of(PHI->operands(),
6765 [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
6766 return true;
6767
6768 return false;
6769}
6770
6771const ConstantRange &
6772ScalarEvolution::getRangeRefIter(const SCEV *S,
6773 ScalarEvolution::RangeSignHint SignHint) {
6774 DenseMap<const SCEV *, ConstantRange> &Cache =
6775 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6776 : SignedRanges;
6777 SmallVector<SCEVUse> WorkList;
6778 SmallPtrSet<const SCEV *, 8> Seen;
6779
6780 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6781 // SCEVUnknown PHI node.
6782 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6783 if (!Seen.insert(Expr).second)
6784 return;
6785 if (Cache.contains(Expr))
6786 return;
6787 switch (Expr->getSCEVType()) {
6788 case scUnknown:
6790 break;
6791 [[fallthrough]];
6792 case scConstant:
6793 case scVScale:
6794 case scTruncate:
6795 case scZeroExtend:
6796 case scSignExtend:
6797 case scPtrToAddr:
6798 case scPtrToInt:
6799 case scAddExpr:
6800 case scMulExpr:
6801 case scUDivExpr:
6802 case scAddRecExpr:
6803 case scUMaxExpr:
6804 case scSMaxExpr:
6805 case scUMinExpr:
6806 case scSMinExpr:
6808 WorkList.push_back(Expr);
6809 break;
6810 case scCouldNotCompute:
6811 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6812 }
6813 };
6814 AddToWorklist(S);
6815
6816 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6817 for (unsigned I = 0; I != WorkList.size(); ++I) {
6818 const SCEV *P = WorkList[I];
6819 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6820 // If it is not a `SCEVUnknown`, just recurse into operands.
6821 if (!UnknownS) {
6822 for (const SCEV *Op : P->operands())
6823 AddToWorklist(Op);
6824 continue;
6825 }
6826 // `SCEVUnknown`'s require special treatment.
6827 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6828 if (!RangeRefPHIAllowedOperands(DT, P))
6829 continue;
6830 for (auto &Op : reverse(P->operands()))
6831 AddToWorklist(getSCEV(Op));
6832 }
6833 }
6834
6835 if (!WorkList.empty()) {
6836 // Use getRangeRef to compute ranges for items in the worklist in reverse
6837 // order. This will force ranges for earlier operands to be computed before
6838 // their users in most cases.
6839 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6840 getRangeRef(P, SignHint);
6841 }
6842 }
6843
6844 return getRangeRef(S, SignHint, 0);
6845}
6846
6847/// Determine the range for a particular SCEV. If SignHint is
6848/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6849/// with a "cleaner" unsigned (resp. signed) representation.
6850const ConstantRange &ScalarEvolution::getRangeRef(
6851 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6852 DenseMap<const SCEV *, ConstantRange> &Cache =
6853 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6854 : SignedRanges;
6856 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6858
6859 // See if we've computed this range already.
6860 auto I = Cache.find(S);
6861 if (I != Cache.end())
6862 return I->second;
6863
6864 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6865 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6866
6867 // Switch to iteratively computing the range for S, if it is part of a deeply
6868 // nested expression.
6870 return getRangeRefIter(S, SignHint);
6871
6872 unsigned BitWidth = getTypeSizeInBits(S->getType());
6873 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6874 using OBO = OverflowingBinaryOperator;
6875
6876 // If the value has known zeros, the maximum value will have those known zeros
6877 // as well.
6878 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6879 APInt Multiple = getNonZeroConstantMultiple(S);
6880 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6881 if (!Remainder.isZero())
6882 ConservativeResult =
6883 ConstantRange(APInt::getMinValue(BitWidth),
6884 APInt::getMaxValue(BitWidth) - Remainder + 1);
6885 }
6886 else {
6887 uint32_t TZ = getMinTrailingZeros(S);
6888 if (TZ != 0) {
6889 ConservativeResult = ConstantRange(
6891 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6892 }
6893 }
6894
6895 switch (S->getSCEVType()) {
6896 case scConstant:
6897 llvm_unreachable("Already handled above.");
6898 case scVScale:
6899 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6900 case scTruncate: {
6901 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6902 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6903 return setRange(
6904 Trunc, SignHint,
6905 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6906 }
6907 case scZeroExtend: {
6908 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6909 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6910 return setRange(
6911 ZExt, SignHint,
6912 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6913 }
6914 case scSignExtend: {
6915 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6916 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6917 return setRange(
6918 SExt, SignHint,
6919 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6920 }
6921 case scPtrToAddr:
6922 case scPtrToInt: {
6923 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6924 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6925 return setRange(Cast, SignHint, X);
6926 }
6927 case scAddExpr: {
6928 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6929 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6930 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6931 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6932 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6933 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6934 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6935 ConservativeResult =
6936 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6937 }
6938 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6939 unsigned WrapType = OBO::AnyWrap;
6940 if (Add->hasNoSignedWrap())
6941 WrapType |= OBO::NoSignedWrap;
6942 if (Add->hasNoUnsignedWrap())
6943 WrapType |= OBO::NoUnsignedWrap;
6944 for (const SCEV *Op : drop_begin(Add->operands()))
6945 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6946 RangeType);
6947 return setRange(Add, SignHint,
6948 ConservativeResult.intersectWith(X, RangeType));
6949 }
6950 case scMulExpr: {
6951 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6952 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6953 for (const SCEV *Op : drop_begin(Mul->operands()))
6954 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6955 return setRange(Mul, SignHint,
6956 ConservativeResult.intersectWith(X, RangeType));
6957 }
6958 case scUDivExpr: {
6959 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6960 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6961 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6962 return setRange(UDiv, SignHint,
6963 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6964 }
6965 case scAddRecExpr: {
6966 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6967 // If there's no unsigned wrap, the value will never be less than its
6968 // initial value.
6969 if (AddRec->hasNoUnsignedWrap()) {
6970 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6971 if (!UnsignedMinValue.isZero())
6972 ConservativeResult = ConservativeResult.intersectWith(
6973 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6974 }
6975
6976 // If there's no signed wrap, and all the operands except initial value have
6977 // the same sign or zero, the value won't ever be:
6978 // 1: smaller than initial value if operands are non negative,
6979 // 2: bigger than initial value if operands are non positive.
6980 // For both cases, value can not cross signed min/max boundary.
6981 if (AddRec->hasNoSignedWrap()) {
6982 bool AllNonNeg = true;
6983 bool AllNonPos = true;
6984 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6985 if (!isKnownNonNegative(AddRec->getOperand(i)))
6986 AllNonNeg = false;
6987 if (!isKnownNonPositive(AddRec->getOperand(i)))
6988 AllNonPos = false;
6989 }
6990 if (AllNonNeg)
6991 ConservativeResult = ConservativeResult.intersectWith(
6994 RangeType);
6995 else if (AllNonPos)
6996 ConservativeResult = ConservativeResult.intersectWith(
6998 getSignedRangeMax(AddRec->getStart()) +
6999 1),
7000 RangeType);
7001 }
7002
7003 // TODO: non-affine addrec
7004 if (AddRec->isAffine()) {
7005 const SCEV *MaxBEScev =
7007 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7008 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7009
7010 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7011 // MaxBECount's active bits are all <= AddRec's bit width.
7012 if (MaxBECount.getBitWidth() > BitWidth &&
7013 MaxBECount.getActiveBits() <= BitWidth)
7014 MaxBECount = MaxBECount.trunc(BitWidth);
7015 else if (MaxBECount.getBitWidth() < BitWidth)
7016 MaxBECount = MaxBECount.zext(BitWidth);
7017
7018 if (MaxBECount.getBitWidth() == BitWidth) {
7019 auto RangeFromAffine = getRangeForAffineAR(
7020 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7021 ConservativeResult =
7022 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7023
7024 auto RangeFromFactoring = getRangeViaFactoring(
7025 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7026 ConservativeResult =
7027 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7028 }
7029 }
7030
7031 // Now try symbolic BE count and more powerful methods.
7033 const SCEV *SymbolicMaxBECount =
7035 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7036 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7037 AddRec->hasNoSelfWrap()) {
7038 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7039 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7040 ConservativeResult =
7041 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7042 }
7043 }
7044 }
7045
7046 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7047 }
7048 case scUMaxExpr:
7049 case scSMaxExpr:
7050 case scUMinExpr:
7051 case scSMinExpr:
7052 case scSequentialUMinExpr: {
7054 switch (S->getSCEVType()) {
7055 case scUMaxExpr:
7056 ID = Intrinsic::umax;
7057 break;
7058 case scSMaxExpr:
7059 ID = Intrinsic::smax;
7060 break;
7061 case scUMinExpr:
7063 ID = Intrinsic::umin;
7064 break;
7065 case scSMinExpr:
7066 ID = Intrinsic::smin;
7067 break;
7068 default:
7069 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7070 }
7071
7072 const auto *NAry = cast<SCEVNAryExpr>(S);
7073 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7074 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7075 X = X.intrinsic(
7076 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7077 return setRange(S, SignHint,
7078 ConservativeResult.intersectWith(X, RangeType));
7079 }
7080 case scUnknown: {
7081 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7082 Value *V = U->getValue();
7083
7084 // Check if the IR explicitly contains !range metadata.
7085 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7086 if (MDRange)
7087 ConservativeResult =
7088 ConservativeResult.intersectWith(*MDRange, RangeType);
7089
7090 // Use facts about recurrences in the underlying IR. Note that add
7091 // recurrences are AddRecExprs and thus don't hit this path. This
7092 // primarily handles shift recurrences.
7093 auto CR = getRangeForUnknownRecurrence(U);
7094 ConservativeResult = ConservativeResult.intersectWith(CR);
7095
7096 // See if ValueTracking can give us a useful range.
7097 const DataLayout &DL = getDataLayout();
7098 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7099 if (Known.getBitWidth() != BitWidth)
7100 Known = Known.zextOrTrunc(BitWidth);
7101
7102 // ValueTracking may be able to compute a tighter result for the number of
7103 // sign bits than for the value of those sign bits.
7104 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7105 if (U->getType()->isPointerTy()) {
7106 // If the pointer size is larger than the index size type, this can cause
7107 // NS to be larger than BitWidth. So compensate for this.
7108 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7109 int ptrIdxDiff = ptrSize - BitWidth;
7110 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7111 NS -= ptrIdxDiff;
7112 }
7113
7114 if (NS > 1) {
7115 // If we know any of the sign bits, we know all of the sign bits.
7116 if (!Known.Zero.getHiBits(NS).isZero())
7117 Known.Zero.setHighBits(NS);
7118 if (!Known.One.getHiBits(NS).isZero())
7119 Known.One.setHighBits(NS);
7120 }
7121
7122 if (Known.getMinValue() != Known.getMaxValue() + 1)
7123 ConservativeResult = ConservativeResult.intersectWith(
7124 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7125 RangeType);
7126 if (NS > 1)
7127 ConservativeResult = ConservativeResult.intersectWith(
7128 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7129 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7130 RangeType);
7131
7132 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7133 // Strengthen the range if the underlying IR value is a
7134 // global/alloca/heap allocation using the size of the object.
7135 bool CanBeNull, CanBeFreed;
7136 uint64_t DerefBytes =
7137 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7138 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7139 // The highest address the object can start is DerefBytes bytes before
7140 // the end (unsigned max value). If this value is not a multiple of the
7141 // alignment, the last possible start value is the next lowest multiple
7142 // of the alignment. Note: The computations below cannot overflow,
7143 // because if they would there's no possible start address for the
7144 // object.
7145 APInt MaxVal =
7146 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7147 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7148 uint64_t Rem = MaxVal.urem(Align);
7149 MaxVal -= APInt(BitWidth, Rem);
7150 APInt MinVal = APInt::getZero(BitWidth);
7151 if (llvm::isKnownNonZero(V, DL))
7152 MinVal = Align;
7153 ConservativeResult = ConservativeResult.intersectWith(
7154 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7155 }
7156 }
7157
7158 // A range of Phi is a subset of union of all ranges of its input.
7159 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7160 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7161 // AddRecs; return the range for the corresponding AddRec.
7162 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7163 return getRangeRef(AR, SignHint, Depth + 1);
7164
7165 // Make sure that we do not run over cycled Phis.
7166 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7167 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7168
7169 for (const auto &Op : Phi->operands()) {
7170 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7171 RangeFromOps = RangeFromOps.unionWith(OpRange);
7172 // No point to continue if we already have a full set.
7173 if (RangeFromOps.isFullSet())
7174 break;
7175 }
7176 ConservativeResult =
7177 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7178 }
7179 }
7180
7181 // vscale can't be equal to zero
7182 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7183 if (II->getIntrinsicID() == Intrinsic::vscale) {
7184 ConstantRange Disallowed = APInt::getZero(BitWidth);
7185 ConservativeResult = ConservativeResult.difference(Disallowed);
7186 }
7187
7188 return setRange(U, SignHint, std::move(ConservativeResult));
7189 }
7190 case scCouldNotCompute:
7191 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7192 }
7193
7194 return setRange(S, SignHint, std::move(ConservativeResult));
7195}
7196
7197// Given a StartRange, Step and MaxBECount for an expression, compute a range of
7198// values that the expression can take. Initially, the expression has a value
7199// from StartRange and then is changed by Step up to MaxBECount times. Signed
7200// argument defines if we treat Step as signed or unsigned.
    //
    // All three inputs must have the same bit width. The result is conservative:
    // the full set is returned whenever Step * MaxBECount does not fit in
    // BitWidth bits, or when the moved boundary wraps back into StartRange.
7202 const ConstantRange &StartRange,
7203 const APInt &MaxBECount,
7204 bool Signed) {
7205 unsigned BitWidth = Step.getBitWidth();
7206 assert(BitWidth == StartRange.getBitWidth() &&
7207 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7208 // If either Step or MaxBECount is 0, then the expression won't change, and we
7209 // just need to return the initial range.
7210 if (Step == 0 || MaxBECount == 0)
7211 return StartRange;
7212
7213 // If we don't know anything about the initial value (i.e. StartRange is
7214 // FullRange), then we don't know anything about the final range either.
7215 // Return FullRange.
7216 if (StartRange.isFullSet())
7217 return ConstantRange::getFull(BitWidth);
7218
7219 // If Step is signed and negative, then we use its absolute value, but we also
7220 // note that we're moving in the opposite direction.
7221 bool Descending = Signed && Step.isNegative();
7222
7223 if (Signed)
7224 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7225 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7226 // These equations hold true due to the well-defined wrap-around behavior of
7227 // APInt.
7228 Step = Step.abs();
7229
7230 // Check whether Offset (Step * MaxBECount, computed below) would exceed the
7231 // full unsigned span of BitWidth, i.e. whether that product would overflow.
     // If it would, the expression can wrap and no useful range can be derived.
     // Step is non-zero here (zero returned early above), so the udiv is safe.
7232 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7233 return ConstantRange::getFull(BitWidth);
7234
7235 // Offset is by how much the expression can change. Checks above guarantee no
7236 // overflow here.
7237 APInt Offset = Step * MaxBECount;
7238
7239 // Minimum value of the final range will match the minimal value of StartRange
7240 // if the expression is increasing and will be decreased by Offset otherwise.
7241 // Maximum value of the final range will match the maximal value of StartRange
7242 // if the expression is decreasing and will be increased by Offset otherwise.
     // ConstantRange's upper bound is exclusive, so StartUpper is the inclusive
     // maximum of StartRange.
7243 APInt StartLower = StartRange.getLower();
7244 APInt StartUpper = StartRange.getUpper() - 1;
7245 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7246 : (StartUpper + std::move(Offset));
7247
7248 // It's possible that the new minimum/maximum value will fall into the initial
7249 // range (due to wrap around). This means that the expression can take any
7250 // value in this bitwidth, and we have to return full range.
7251 if (StartRange.contains(MovedBoundary))
7252 return ConstantRange::getFull(BitWidth);
7253
7254 APInt NewLower =
7255 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7256 APInt NewUpper =
7257 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7258 NewUpper += 1;
7259
7260 // No overflow detected: return [NewLower, NewUpper), which is
     // [StartLower, StartUpper + Offset + 1) when ascending and
     // [StartLower - Offset, StartUpper + 1) when descending.
7261 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7262}
7263
// Computes a conservative range for the affine recurrence {Start,+,Step}
// evaluated over at most MaxBECount steps. Start, Step and MaxBECount must all
// have the same bit width. One range is derived treating Step as signed (from
// the signed ranges of Step and Start) and one treating Step as unsigned (from
// the unsigned maximum of Step and the unsigned range of Start). Per the
// closing comment, the two are then intersected.
7264ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7265 const SCEV *Step,
7266 const APInt &MaxBECount) {
7267 assert(getTypeSizeInBits(Start->getType()) ==
7268 getTypeSizeInBits(Step->getType()) &&
7269 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7270 "mismatched bit widths");
7271
7272 // First, consider step signed.
7273 ConstantRange StartSRange = getSignedRange(Start);
7274 ConstantRange StepSRange = getSignedRange(Step);
7275
7276 // If Step can be both positive and negative, we need to find ranges for the
7277 // maximum absolute step values in both directions and union them.
7278 ConstantRange SR = getRangeForAffineARHelper(
7279 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7281 StartSRange, MaxBECount,
7282 /* Signed = */ true));
7283
7284 // Next, consider step unsigned.
     // Only the largest unsigned step is used, since it yields the widest range.
7285 ConstantRange UR = getRangeForAffineARHelper(
7286 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7287 /* Signed = */ false);
7288
7289 // Finally, intersect signed and unsigned ranges.
7291}
7292
// Computes a range for an affine AddRec that is known not to self-wrap by
// bounding it between its (loop-guard refined) start value and its value End
// after MaxBECount iterations. Returns the full set for BitWidth whenever that
// bound cannot be justified:
//   - the step is not a constant;
//   - MaxBECount is wider than the AddRec type;
//   - MaxBECount may exceed the number of |Step| steps that fit in the type;
//   - the Start/End union wraps in the signedness selected by SignHint;
//   - neither "step positive and Start <= End" nor "step negative and
//     Start >= End" can be shown via constant ranges.
// If the Start/End union is already the full set, that union is returned.
7293ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7294 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7295 ScalarEvolution::RangeSignHint SignHint) {
7296 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7297 assert(AddRec->hasNoSelfWrap() &&
7298 "This only works for non-self-wrapping AddRecs!");
7299 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7300 const SCEV *Step = AddRec->getStepRecurrence(*this);
7301 // Only deal with constant step to save compile time.
7302 if (!isa<SCEVConstant>(Step))
7303 return ConstantRange::getFull(BitWidth);
7304 // Let's make sure that we can prove that we do not self-wrap during
7305 // MaxBECount iterations. We need this because MaxBECount is a maximum
7306 // iteration count estimate, and we might infer nw from some exit for which we
7307 // do not know max exit count (or any other side reasoning).
7308 // TODO: Turn into assert at some point.
7309 if (getTypeSizeInBits(MaxBECount->getType()) >
7310 getTypeSizeInBits(AddRec->getType()))
7311 return ConstantRange::getFull(BitWidth);
7312 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
     // RangeWidth is the all-ones (unsigned maximum) value of the AddRec type.
7313 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
     // umin(Step, -Step) is the magnitude of the constant step.
7314 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
     // Number of |Step|-sized steps that fit in the type's unsigned span.
7315 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7316 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7317 MaxItersWithoutWrap))
7318 return ConstantRange::getFull(BitWidth);
7319
     // Predicates used below to order Start and End (<= and >= respectively).
7320 ICmpInst::Predicate LEPred =
7322 ICmpInst::Predicate GEPred =
7324 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7325
7326 // We know that there is no self-wrap. Let's take Start and End values and
7327 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7328 // the iteration. They either lie inside the range [Min(Start, End),
7329 // Max(Start, End)] or outside it:
7330 //
7331 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7332 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7333 //
7334 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7335 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7336 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7337 // Start <= End and step is positive, or Start >= End and step is negative.
7338 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7339 ConstantRange StartRange = getRangeRef(Start, SignHint);
7340 ConstantRange EndRange = getRangeRef(End, SignHint);
7341 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7342 // If they already cover full iteration space, we will know nothing useful
7343 // even if we prove what we want to prove.
7344 if (RangeBetween.isFullSet())
7345 return RangeBetween;
7346 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7347 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7348 : RangeBetween.isWrappedSet();
7349 if (IsWrappedSet)
7350 return ConstantRange::getFull(BitWidth);
7351
7352 if (isKnownPositive(Step) &&
7353 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7354 return RangeBetween;
7355 if (isKnownNegative(Step) &&
7356 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7357 return RangeBetween;
7358 return ConstantRange::getFull(BitWidth);
7359}
7360
// Computes a range for the recurrence {Start,+,Step} in the special case where
// both Start and Step are selects between two integer constants on the same
// condition. The selects may have a constant added and/or an integral cast
// applied. In that case the recurrence is one of two constant-start,
// constant-step recurrences, and the result is the union of their
// getRangeForAffineAR ranges. Returns the full set if either operand does not
// have this shape, or if the two selects use different conditions.
7361ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7362 const SCEV *Step,
7363 const APInt &MaxBECount) {
7364 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7365 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7366
7367 unsigned BitWidth = MaxBECount.getBitWidth();
7368 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7369 getTypeSizeInBits(Step->getType()) == BitWidth &&
7370 "mismatched bit widths");
7371
     // Decomposes S as [Offset +] [cast] (select Condition, TrueVal, FalseVal)
     // with constant arms. TrueValue/FalseValue hold the arms after re-applying
     // the peeled cast and offset.
7372 struct SelectPattern {
     // The select condition, or nullptr if S did not match the pattern.
7373 Value *Condition = nullptr;
7374 APInt TrueValue;
7375 APInt FalseValue;
7376
7377 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7378 const SCEV *S) {
7379 std::optional<unsigned> CastOp;
7380 APInt Offset(BitWidth, 0);
7381
7383 "Should be!");
7384
7385 // Peel off a constant offset. In the future we could consider being
7386 // smarter here and handle {Start+Step,+,Step} too.
     // On a match, S is rebound to the non-constant addend.
7387 const APInt *Off;
7388 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7389 Offset = *Off;
7390
7391 // Peel off a cast operation
7392 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7393 CastOp = SCast->getSCEVType();
7394 S = SCast->getOperand();
7395 }
7396
7397 using namespace llvm::PatternMatch;
7398
     // What remains must be an IR select whose two arms are integer constants.
7399 auto *SU = dyn_cast<SCEVUnknown>(S);
7400 const APInt *TrueVal, *FalseVal;
7401 if (!SU ||
7402 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7403 m_APInt(FalseVal)))) {
7404 Condition = nullptr;
7405 return;
7406 }
7407
7408 TrueValue = *TrueVal;
7409 FalseValue = *FalseVal;
7410
7411 // Re-apply the cast we peeled off earlier
7412 if (CastOp)
7413 switch (*CastOp) {
7414 default:
7415 llvm_unreachable("Unknown SCEV cast type!");
7416
7417 case scTruncate:
7418 TrueValue = TrueValue.trunc(BitWidth);
7419 FalseValue = FalseValue.trunc(BitWidth);
7420 break;
7421 case scZeroExtend:
7422 TrueValue = TrueValue.zext(BitWidth);
7423 FalseValue = FalseValue.zext(BitWidth);
7424 break;
7425 case scSignExtend:
7426 TrueValue = TrueValue.sext(BitWidth);
7427 FalseValue = FalseValue.sext(BitWidth);
7428 break;
7429 }
7430
7431 // Re-apply the constant offset we peeled off earlier
7432 TrueValue += Offset;
7433 FalseValue += Offset;
7434 }
7435
7436 bool isRecognized() { return Condition != nullptr; }
7437 };
7438
7439 SelectPattern StartPattern(*this, BitWidth, Start);
7440 if (!StartPattern.isRecognized())
7441 return ConstantRange::getFull(BitWidth);
7442
7443 SelectPattern StepPattern(*this, BitWidth, Step);
7444 if (!StepPattern.isRecognized())
7445 return ConstantRange::getFull(BitWidth);
7446
7447 if (StartPattern.Condition != StepPattern.Condition) {
7448 // We don't handle this case today; but we could, by considering four
7449 // possibilities below instead of two. I'm not sure if there are cases where
7450 // that will help over what getRange already does, though.
7451 return ConstantRange::getFull(BitWidth);
7452 }
7453
7454 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7455 // construct arbitrary general SCEV expressions here. This function is called
7456 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7457 // say) can end up caching a suboptimal value.
7458
7459 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7460 // C2352 and C2512 (otherwise it isn't needed).
7461
7462 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7463 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7464 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7465 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7466
7467 ConstantRange TrueRange =
7468 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7469 ConstantRange FalseRange =
7470 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7471
7472 return TrueRange.unionWith(FalseRange);
7473}
7474
// Returns the SCEV no-wrap flags that may be attached to the SCEV of V, derived
// from the IR-level flags of V's binary operator. Those are the disjoint flag
// (via PossiblyDisjointInst) or the nuw/nsw flags. Returns FlagAnyWrap if:
//   - V is a ConstantExpr;
//   - no flags are derived;
//   - isSCEVExprNeverPoison(BinOp) returns false.
// When V is not a ConstantExpr it must be a BinaryOperator (cast<> asserts it).
7475SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7476 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7477 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7478
7479 // Return early if there are no flags to propagate to the SCEV.
7481 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BinOp);
7482 PDI && PDI->isDisjoint()) {
7484 } else {
7485 if (BinOp->hasNoUnsignedWrap())
7487 if (BinOp->hasNoSignedWrap())
7489 }
7490 if (Flags == SCEV::FlagAnyWrap)
7491 return SCEV::FlagAnyWrap;
7492
     // The IR flags only hold where BinOp executes; keep them for the SCEV only
     // if isSCEVExprNeverPoison can justify that.
7493 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7494}
7495
7496const Instruction *
7497ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7498 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7499 return &*AddRec->getLoop()->getHeader()->begin();
7500 if (auto *U = dyn_cast<SCEVUnknown>(S))
7501 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7502 return I;
7503 return nullptr;
7504}
7505
7506const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7507 bool &Precise) {
7508 Precise = true;
7509 // Do a bounded search of the def relation of the requested SCEVs.
7510 SmallPtrSet<const SCEV *, 16> Visited;
7511 SmallVector<SCEVUse> Worklist;
7512 auto pushOp = [&](const SCEV *S) {
7513 if (!Visited.insert(S).second)
7514 return;
7515 // Threshold of 30 here is arbitrary.
7516 if (Visited.size() > 30) {
7517 Precise = false;
7518 return;
7519 }
7520 Worklist.push_back(S);
7521 };
7522
7523 for (SCEVUse S : Ops)
7524 pushOp(S);
7525
7526 const Instruction *Bound = nullptr;
7527 while (!Worklist.empty()) {
7528 SCEVUse S = Worklist.pop_back_val();
7529 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7530 if (!Bound || DT.dominates(Bound, DefI))
7531 Bound = DefI;
7532 } else {
7533 for (SCEVUse Op : S->operands())
7534 pushOp(Op);
7535 }
7536 }
7537 return Bound ? Bound : &*F.getEntryBlock().begin();
7538}
7539
7540const Instruction *
7541ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7542 bool Discard;
7543 return getDefiningScopeBound(Ops, Discard);
7544}
7545
// Conservatively checks whether executing A guarantees that B is executed
// afterwards. Only two shapes are recognized:
//   - A and B in the same basic block;
//   - A in the preheader of the loop whose header contains B.
// In the loop case, the instructions up to the end of A's block and from the
// start of B's block up to B must be guaranteed to transfer execution to their
// successors. Every other shape returns false.
7546bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7547 const Instruction *B) {
7548 if (A->getParent() == B->getParent() &&
7550 B->getIterator()))
7551 return true;
7552
7553 auto *BLoop = LI.getLoopFor(B->getParent());
7554 if (BLoop && BLoop->getHeader() == B->getParent() &&
7555 BLoop->getLoopPreheader() == A->getParent() &&
7557 A->getParent()->end()) &&
7558 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7559 B->getIterator()))
7560 return true;
7561 return false;
7562}
7563
7564bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7565 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7566 visitAll(Op, PC);
7567 return PC.MaybePoison.empty();
7568}
7569
7570bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7571 return !SCEVExprContains(Op, [this](const SCEV *S) {
7572 const SCEV *Op1;
7573 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7574 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7575 // is a non-zero constant, we have to assume the UDiv may be UB.
7576 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7577 });
7578}
7579
// Returns true if the wrap flags of I may be applied to the SCEV that I maps
// to. Returns false unless the early poison check below succeeds. Otherwise
// returns whether I is guaranteed to execute whenever the defining scope of
// its SCEVable operands (as bounded by getDefiningScopeBound) is entered.
7580bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7581 // Only proceed if we can prove that I does not yield poison.
7583 return false;
7584
7585 // At this point we know that if I is executed, then it does not wrap
7586 // according to at least one of NSW or NUW. If I is not executed, then we do
7587 // not know if the calculation that I represents would wrap. Multiple
7588 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7589 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7590 // derived from other instructions that map to the same SCEV. We cannot make
7591 // that guarantee for cases where I is not executed. So we need to find an
7592 // upper bound on the defining scope for the SCEV, and prove that I is
7593 // executed every time we enter that scope. When the bounding scope is a
7594 // loop (the common case), this is equivalent to proving I executes on every
7595 // iteration of that loop.
7596 SmallVector<SCEVUse> SCEVOps;
7597 for (const Use &Op : I->operands()) {
7598 // I could be an extractvalue from a call to an overflow intrinsic.
7599 // TODO: We can do better here in some cases.
7600 if (isSCEVable(Op->getType()))
7601 SCEVOps.push_back(getSCEV(Op));
7602 }
     // DefI marks where the scope of the operands' SCEVs begins; I must be
     // reached from it.
7603 auto *DefI = getDefiningScopeBound(SCEVOps);
7604 return isGuaranteedToTransferExecutionTo(DefI, I);
7605}
7606
// Returns true if the wrap flags of I, the post-increment of an add recurrence
// in loop L, can be trusted. Succeeds immediately if isSCEVExprNeverPoison(I).
// Otherwise all of the following must hold:
//   - L has a single exiting block and no abnormal exits;
//   - some instruction reached from I through poison-propagating uses inside
//     L must trigger UB if I were poison;
//   - that instruction lies in a block dominating the exiting block.
7607bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7608 // If we know that \c I can never be poison period, then that's enough.
7609 if (isSCEVExprNeverPoison(I))
7610 return true;
7611
7612 // If the loop only has one exit, then we know that, if the loop is entered,
7613 // any instruction dominating that exit will be executed. If any such
7614 // instruction would result in UB, the addrec cannot be poison.
7615 //
7616 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7617 // also handles uses outside the loop header (they just need to dominate the
7618 // single exit).
7619
7620 auto *ExitingBB = L->getExitingBlock();
7621 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7622 return false;
7623
7624 SmallPtrSet<const Value *, 16> KnownPoison;
7626
7627 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7628 // things that are known to be poison under that assumption go on the
7629 // Worklist.
7630 KnownPoison.insert(I);
7631 Worklist.push_back(I);
7632
7633 while (!Worklist.empty()) {
7634 const Instruction *Poison = Worklist.pop_back_val();
7635
7636 for (const Use &U : Poison->uses()) {
7637 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
     // A user that must trigger UB on poison and that always executes before
     // the loop exits rules out poison for I.
7638 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7639 DT.dominates(PoisonUser->getParent(), ExitingBB))
7640 return true;
7641
     // Follow poison only through propagating uses that stay inside L.
7642 if (propagatesPoison(U) && L->contains(PoisonUser))
7643 if (KnownPoison.insert(PoisonUser).second)
7644 Worklist.push_back(PoisonUser);
7645 }
7646 }
7647
7648 return false;
7649}
7650
// Returns the LoopProperties of L. They are computed by scanning every
// instruction of L on the first query and cached in LoopPropertiesCache.
// An instruction counts as a side effect if:
//   - it is a store that is not simple; or
//   - (not a store) it may throw; or
//   - (not a store) it may write to memory and is not a non-volatile
//     MemIntrinsic.
// So simple stores do not count. HasNoAbnormalExits is cleared by the
// per-instruction check at the top of the inner loop.
7651ScalarEvolution::LoopProperties
7652ScalarEvolution::getLoopProperties(const Loop *L) {
7653 using LoopProperties = ScalarEvolution::LoopProperties;
7654
7655 auto Itr = LoopPropertiesCache.find(L);
7656 if (Itr == LoopPropertiesCache.end()) {
7657 auto HasSideEffects = [](Instruction *I) {
7658 if (auto *SI = dyn_cast<StoreInst>(I))
7659 return !SI->isSimple();
7660
7661 if (I->mayThrow())
7662 return true;
7663
7664 // Non-volatile memset / memcpy do not count as side-effect for forward
7665 // progress.
7666 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7667 return false;
7668
7669 return I->mayWriteToMemory();
7670 };
7671
7672 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7673 /*HasNoSideEffects*/ true};
7674
7675 for (auto *BB : L->getBlocks())
7676 for (auto &I : *BB) {
7678 LP.HasNoAbnormalExits = false;
7679 if (HasSideEffects(&I))
7680 LP.HasNoSideEffects = false;
     // NOTE(review): this `break` only leaves the instruction loop of the
     // current block. The scan of the remaining blocks continues, although
     // it can no longer change LP since both fields are only ever cleared.
7681 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7682 break; // We're already as pessimistic as we can get.
7683 }
7684
7685 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7686 assert(InsertPair.second && "We just checked!");
7687 Itr = InsertPair.first;
7688 }
7689
7690 return Itr->second;
7691}
7692
     // Returns true if L is known finite (isFinite), or if L is mustprogress and
     // loopHasNoSideEffects(L) holds.
7694 // A mustprogress loop without side effects must be finite.
7695 // TODO: The check used here is very conservative. It's only *specific*
7696 // side effects which are well defined in infinite loops.
7697 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7698}
7699
// Builds the SCEV for V without deep recursion, using an explicit stack of
// (Value, operands-visited) entries. Values that already have a SCEV are
// skipped. Operands reported by getOperandsToCreate are pushed above their
// user, so they are created first. Returns the SCEV recorded for V.
7700const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7701 // Worklist item with a Value and a bool indicating whether all operands have
7702 // been visited already.
7705
     // The (V, false) entry is processed first. The (V, true) entry beneath it
     // is skipped if V's SCEV already exists by the time it is popped.
7706 Stack.emplace_back(V, true);
7707 Stack.emplace_back(V, false);
7708 while (!Stack.empty()) {
7709 auto E = Stack.pop_back_val();
7710 Value *CurV = E.getPointer();
7711
7712 if (getExistingSCEV(CurV))
7713 continue;
7714
7716 const SCEV *CreatedSCEV = nullptr;
7717 // If all operands have been visited already, create the SCEV.
7718 if (E.getInt()) {
7719 CreatedSCEV = createSCEV(CurV);
7720 } else {
7721 // Otherwise get the operands we need to create SCEV's for before creating
7722 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7723 // just use it.
7724 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7725 }
7726
7727 if (CreatedSCEV) {
7728 insertValueToMap(CurV, CreatedSCEV);
7729 } else {
7730 // Queue CurV for SCEV creation, followed by its operands which need to
7731 // be constructed first.
7732 Stack.emplace_back(CurV, true);
7733 for (Value *Op : Ops)
7734 Stack.emplace_back(Op, false);
7735 }
7736 }
7737
7738 return getExistingSCEV(V);
7739}
7740
// First phase of createSCEVIter's iterative construction for V.
// Returns a SCEV directly when no other value's SCEV has to be built first,
// for example:
//   - non-SCEVable types;
//   - instructions in unreachable blocks;
//   - integer constants;
//   - values modelled as SCEVUnknown.
// Otherwise appends to Ops the values whose SCEVs should be built before
// createSCEV(V) runs, and returns nullptr. Some paths return nullptr without
// appending anything.
7741const SCEV *
7742ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7743 if (!isSCEVable(V->getType()))
7744 return getUnknown(V);
7745
7746 if (Instruction *I = dyn_cast<Instruction>(V)) {
7747 // Don't attempt to analyze instructions in blocks that aren't
7748 // reachable. Such instructions don't matter, and they aren't required
7749 // to obey basic rules for definitions dominating uses which this
7750 // analysis depends on.
7751 if (!DT.isReachableFromEntry(I->getParent()))
7752 return getUnknown(PoisonValue::get(V->getType()));
7753 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7754 return getConstant(CI);
7755 else if (isa<GlobalAlias>(V))
7756 return getUnknown(V);
7757 else if (!isa<ConstantExpr>(V))
7758 return getUnknown(V);
7759
7761 if (auto BO =
7763 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7764 switch (BO->Opcode) {
7765 case Instruction::Add:
7766 case Instruction::Mul: {
7767 // For additions and multiplications, traverse add/mul chains for which we
7768 // can potentially create a single SCEV, to reduce the number of
7769 // get{Add,Mul}Expr calls.
7770 do {
7771 if (BO->Op) {
7772 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7773 Ops.push_back(BO->Op);
7774 break;
7775 }
7776 }
7777 Ops.push_back(BO->RHS);
7778 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7780 if (!NewBO ||
7781 (BO->Opcode == Instruction::Add &&
7782 (NewBO->Opcode != Instruction::Add &&
7783 NewBO->Opcode != Instruction::Sub)) ||
7784 (BO->Opcode == Instruction::Mul &&
7785 NewBO->Opcode != Instruction::Mul)) {
7786 Ops.push_back(BO->LHS);
7787 break;
7788 }
7789 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7790 // requires a SCEV for the LHS.
7791 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7792 auto *I = dyn_cast<Instruction>(BO->Op);
7793 if (I && programUndefinedIfPoison(I)) {
7794 Ops.push_back(BO->LHS);
7795 break;
7796 }
7797 }
7798 BO = NewBO;
7799 } while (true);
7800 return nullptr;
7801 }
7802 case Instruction::Sub:
7803 case Instruction::UDiv:
7804 case Instruction::URem:
7805 break;
     // With a non-constant RHS these return nullptr without queuing operands.
7806 case Instruction::AShr:
7807 case Instruction::Shl:
7808 case Instruction::Xor:
7809 if (!IsConstArg)
7810 return nullptr;
7811 break;
     // A constant RHS, or i1 operands, queue both operands below. Otherwise
     // this returns nullptr without queuing operands.
7812 case Instruction::And:
7813 case Instruction::Or:
7814 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7815 return nullptr;
7816 break;
7817 case Instruction::LShr:
7818 return getUnknown(V);
7819 default:
7820 llvm_unreachable("Unhandled binop");
7821 break;
7822 }
7823
7824 Ops.push_back(BO->LHS);
7825 Ops.push_back(BO->RHS);
7826 return nullptr;
7827 }
7828
7829 switch (U->getOpcode()) {
7830 case Instruction::Trunc:
7831 case Instruction::ZExt:
7832 case Instruction::SExt:
7833 case Instruction::PtrToAddr:
7834 case Instruction::PtrToInt:
7835 Ops.push_back(U->getOperand(0));
7836 return nullptr;
7837
7838 case Instruction::BitCast:
7839 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7840 Ops.push_back(U->getOperand(0));
7841 return nullptr;
7842 }
7843 return getUnknown(V);
7844
7845 case Instruction::SDiv:
7846 case Instruction::SRem:
7847 Ops.push_back(U->getOperand(0));
7848 Ops.push_back(U->getOperand(1));
7849 return nullptr;
7850
7851 case Instruction::GetElementPtr:
7852 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7853 "GEP source element type must be sized");
7854 llvm::append_range(Ops, U->operands());
7855 return nullptr;
7856
7857 case Instruction::IntToPtr:
7858 return getUnknown(V);
7859
7860 case Instruction::PHI:
7861 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7862 // relevant nodes for each of them.
7863 //
7864 // The first is just to call simplifyInstruction, and get something back
7865 // that isn't a PHI.
7866 if (Value *V = simplifyInstruction(
7867 cast<PHINode>(U),
7868 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7869 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7870 assert(V);
7871 Ops.push_back(V);
7872 return nullptr;
7873 }
7874 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7875 // operands which all perform the same operation, but haven't been
7876 // CSE'ed for whatever reason.
7877 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7878 assert(BO);
7879 Ops.push_back(BO);
7880 return nullptr;
7881 }
7882 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7883 // is equivalent to a select, and analyzes it like a select.
7884 {
7885 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7887 assert(Cond);
7888 assert(LHS);
7889 assert(RHS);
7890 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7891 Ops.push_back(CondICmp->getOperand(0));
7892 Ops.push_back(CondICmp->getOperand(1));
7893 }
7894 Ops.push_back(Cond);
7895 Ops.push_back(LHS);
7896 Ops.push_back(RHS);
7897 return nullptr;
7898 }
7899 }
7900 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7901 // so just construct it recursively.
7902 //
7903 // In addition to getNodeForPHI, also construct nodes which might be needed
7904 // by getRangeRef.
7906 for (Value *V : cast<PHINode>(U)->operands())
7907 Ops.push_back(V);
7908 return nullptr;
7909 }
7910 return nullptr;
7911
7912 case Instruction::Select: {
7913 // Check if U is a select that can be simplified to a SCEVUnknown.
7914 auto CanSimplifyToUnknown = [this, U]() {
7915 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7916 return false;
7917
7918 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7919 if (!ICI)
7920 return false;
7921 Value *LHS = ICI->getOperand(0);
7922 Value *RHS = ICI->getOperand(1);
7923 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7924 ICI->getPredicate() == CmpInst::ICMP_NE) {
7926 return true;
7927 } else if (getTypeSizeInBits(LHS->getType()) >
7928 getTypeSizeInBits(U->getType()))
7929 return true;
7930 return false;
7931 };
7932 if (CanSimplifyToUnknown())
7933 return getUnknown(U);
7934
7935 llvm::append_range(Ops, U->operands());
7936 return nullptr;
     // NOTE(review): the `break` below is unreachable after the return above.
7937 break;
7938 }
7939 case Instruction::Call:
7940 case Instruction::Invoke:
7941 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7942 Ops.push_back(RV);
7943 return nullptr;
7944 }
7945
7946 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7947 switch (II->getIntrinsicID()) {
7948 case Intrinsic::abs:
7949 Ops.push_back(II->getArgOperand(0));
7950 return nullptr;
7951 case Intrinsic::umax:
7952 case Intrinsic::umin:
7953 case Intrinsic::smax:
7954 case Intrinsic::smin:
7955 case Intrinsic::usub_sat:
7956 case Intrinsic::uadd_sat:
7957 Ops.push_back(II->getArgOperand(0));
7958 Ops.push_back(II->getArgOperand(1));
7959 return nullptr;
7960 case Intrinsic::start_loop_iterations:
7961 case Intrinsic::annotation:
7962 case Intrinsic::ptr_annotation:
7963 Ops.push_back(II->getArgOperand(0));
7964 return nullptr;
7965 default:
7966 break;
7967 }
7968 }
7969 break;
7970 }
7971
7972 return nullptr;
7973}
7974
7975const SCEV *ScalarEvolution::createSCEV(Value *V) {
7976 if (!isSCEVable(V->getType()))
7977 return getUnknown(V);
7978
7979 if (Instruction *I = dyn_cast<Instruction>(V)) {
7980 // Don't attempt to analyze instructions in blocks that aren't
7981 // reachable. Such instructions don't matter, and they aren't required
7982 // to obey basic rules for definitions dominating uses which this
7983 // analysis depends on.
7984 if (!DT.isReachableFromEntry(I->getParent()))
7985 return getUnknown(PoisonValue::get(V->getType()));
7986 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7987 return getConstant(CI);
7988 else if (isa<GlobalAlias>(V))
7989 return getUnknown(V);
7990 else if (!isa<ConstantExpr>(V))
7991 return getUnknown(V);
7992
7993 const SCEV *LHS;
7994 const SCEV *RHS;
7995
7997 if (auto BO =
7999 switch (BO->Opcode) {
8000 case Instruction::Add: {
8001 // The simple thing to do would be to just call getSCEV on both operands
8002 // and call getAddExpr with the result. However if we're looking at a
8003 // bunch of things all added together, this can be quite inefficient,
8004 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8005 // Instead, gather up all the operands and make a single getAddExpr call.
8006 // LLVM IR canonical form means we need only traverse the left operands.
8008 do {
8009 if (BO->Op) {
8010 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8011 AddOps.push_back(OpSCEV);
8012 break;
8013 }
8014
8015 // If a NUW or NSW flag can be applied to the SCEV for this
8016 // addition, then compute the SCEV for this addition by itself
8017 // with a separate call to getAddExpr. We need to do that
8018 // instead of pushing the operands of the addition onto AddOps,
8019 // since the flags are only known to apply to this particular
8020 // addition - they may not apply to other additions that can be
8021 // formed with operands from AddOps.
8022 const SCEV *RHS = getSCEV(BO->RHS);
8023 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8024 if (Flags != SCEV::FlagAnyWrap) {
8025 const SCEV *LHS = getSCEV(BO->LHS);
8026 if (BO->Opcode == Instruction::Sub)
8027 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8028 else
8029 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8030 break;
8031 }
8032 }
8033
8034 if (BO->Opcode == Instruction::Sub)
8035 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8036 else
8037 AddOps.push_back(getSCEV(BO->RHS));
8038
8039 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8041 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8042 NewBO->Opcode != Instruction::Sub)) {
8043 AddOps.push_back(getSCEV(BO->LHS));
8044 break;
8045 }
8046 BO = NewBO;
8047 } while (true);
8048
8049 return getAddExpr(AddOps);
8050 }
8051
8052 case Instruction::Mul: {
8054 do {
8055 if (BO->Op) {
8056 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8057 MulOps.push_back(OpSCEV);
8058 break;
8059 }
8060
8061 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8062 if (Flags != SCEV::FlagAnyWrap) {
8063 LHS = getSCEV(BO->LHS);
8064 RHS = getSCEV(BO->RHS);
8065 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8066 break;
8067 }
8068 }
8069
8070 MulOps.push_back(getSCEV(BO->RHS));
8071 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8073 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8074 MulOps.push_back(getSCEV(BO->LHS));
8075 break;
8076 }
8077 BO = NewBO;
8078 } while (true);
8079
8080 return getMulExpr(MulOps);
8081 }
8082 case Instruction::UDiv:
8083 LHS = getSCEV(BO->LHS);
8084 RHS = getSCEV(BO->RHS);
8085 return getUDivExpr(LHS, RHS);
8086 case Instruction::URem:
8087 LHS = getSCEV(BO->LHS);
8088 RHS = getSCEV(BO->RHS);
8089 return getURemExpr(LHS, RHS);
8090 case Instruction::Sub: {
8092 if (BO->Op)
8093 Flags = getNoWrapFlagsFromUB(BO->Op);
8094 LHS = getSCEV(BO->LHS);
8095 RHS = getSCEV(BO->RHS);
8096 return getMinusSCEV(LHS, RHS, Flags);
8097 }
8098 case Instruction::And:
8099 // For an expression like x&255 that merely masks off the high bits,
8100 // use zext(trunc(x)) as the SCEV expression.
8101 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8102 if (CI->isZero())
8103 return getSCEV(BO->RHS);
8104 if (CI->isMinusOne())
8105 return getSCEV(BO->LHS);
8106 const APInt &A = CI->getValue();
8107
8108 // Instcombine's ShrinkDemandedConstant may strip bits out of
8109 // constants, obscuring what would otherwise be a low-bits mask.
8110 // Use computeKnownBits to compute what ShrinkDemandedConstant
8111 // knew about to reconstruct a low-bits mask value.
8112 unsigned LZ = A.countl_zero();
8113 unsigned TZ = A.countr_zero();
8114 unsigned BitWidth = A.getBitWidth();
8115 KnownBits Known(BitWidth);
8116 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8117
8118 APInt EffectiveMask =
8119 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8120 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8121 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8122 const SCEV *LHS = getSCEV(BO->LHS);
8123 const SCEV *ShiftedLHS = nullptr;
8124 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8125 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8126 // For an expression like (x * 8) & 8, simplify the multiply.
8127 unsigned MulZeros = OpC->getAPInt().countr_zero();
8128 unsigned GCD = std::min(MulZeros, TZ);
8129 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8131 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8132 append_range(MulOps, LHSMul->operands().drop_front());
8133 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8134 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8135 }
8136 }
8137 if (!ShiftedLHS)
8138 ShiftedLHS = getUDivExpr(LHS, MulCount);
8139 return getMulExpr(
8141 getTruncateExpr(ShiftedLHS,
8142 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8143 BO->LHS->getType()),
8144 MulCount);
8145 }
8146 }
8147 // Binary `and` is a bit-wise `umin`.
8148 if (BO->LHS->getType()->isIntegerTy(1)) {
8149 LHS = getSCEV(BO->LHS);
8150 RHS = getSCEV(BO->RHS);
8151 return getUMinExpr(LHS, RHS);
8152 }
8153 break;
8154
8155 case Instruction::Or:
8156 // Binary `or` is a bit-wise `umax`.
8157 if (BO->LHS->getType()->isIntegerTy(1)) {
8158 LHS = getSCEV(BO->LHS);
8159 RHS = getSCEV(BO->RHS);
8160 return getUMaxExpr(LHS, RHS);
8161 }
8162 break;
8163
8164 case Instruction::Xor:
8165 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8166 // If the RHS of xor is -1, then this is a not operation.
8167 if (CI->isMinusOne())
8168 return getNotSCEV(getSCEV(BO->LHS));
8169
8170 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8171 // This is a variant of the check for xor with -1, and it handles
8172 // the case where instcombine has trimmed non-demanded bits out
8173 // of an xor with -1.
8174 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8175 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8176 if (LBO->getOpcode() == Instruction::And &&
8177 LCI->getValue() == CI->getValue())
8178 if (const SCEVZeroExtendExpr *Z =
8180 Type *UTy = BO->LHS->getType();
8181 const SCEV *Z0 = Z->getOperand();
8182 Type *Z0Ty = Z0->getType();
8183 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8184
8185 // If C is a low-bits mask, the zero extend is serving to
8186 // mask off the high bits. Complement the operand and
8187 // re-apply the zext.
8188 if (CI->getValue().isMask(Z0TySize))
8189 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8190
8191 // If C is a single bit, it may be in the sign-bit position
8192 // before the zero-extend. In this case, represent the xor
8193 // using an add, which is equivalent, and re-apply the zext.
8194 APInt Trunc = CI->getValue().trunc(Z0TySize);
8195 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8196 Trunc.isSignMask())
8197 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8198 UTy);
8199 }
8200 }
8201 break;
8202
8203 case Instruction::Shl:
8204 // Turn shift left of a constant amount into a multiply.
8205 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8206 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8207
8208 // If the shift count is not less than the bitwidth, the result of
8209 // the shift is undefined. Don't try to analyze it, because the
8210 // resolution chosen here may differ from the resolution chosen in
8211 // other parts of the compiler.
8212 if (SA->getValue().uge(BitWidth))
8213 break;
8214
8215 // We can safely preserve the nuw flag in all cases. It's also safe to
8216 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8217 // requires special handling. It can be preserved as long as we're not
8218 // left shifting by bitwidth - 1.
8219 auto Flags = SCEV::FlagAnyWrap;
8220 if (BO->Op) {
8221 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8222 if (any(MulFlags & SCEV::FlagNSW) &&
8223 (any(MulFlags & SCEV::FlagNUW) ||
8224 SA->getValue().ult(BitWidth - 1)))
8226 if (any(MulFlags & SCEV::FlagNUW))
8228 }
8229
8230 ConstantInt *X = ConstantInt::get(
8231 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8232 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8233 }
8234 break;
8235
8236 case Instruction::AShr:
8237 // AShr X, C, where C is a constant.
8238 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8239 if (!CI)
8240 break;
8241
8242 Type *OuterTy = BO->LHS->getType();
8243 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8244 // If the shift count is not less than the bitwidth, the result of
8245 // the shift is undefined. Don't try to analyze it, because the
8246 // resolution chosen here may differ from the resolution chosen in
8247 // other parts of the compiler.
8248 if (CI->getValue().uge(BitWidth))
8249 break;
8250
8251 if (CI->isZero())
8252 return getSCEV(BO->LHS); // shift by zero --> noop
8253
8254 uint64_t AShrAmt = CI->getZExtValue();
8255 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8256
8257 Operator *L = dyn_cast<Operator>(BO->LHS);
8258 const SCEV *AddTruncateExpr = nullptr;
8259 ConstantInt *ShlAmtCI = nullptr;
8260 const SCEV *AddConstant = nullptr;
8261
8262 if (L && L->getOpcode() == Instruction::Add) {
8263 // X = Shl A, n
8264 // Y = Add X, c
8265 // Z = AShr Y, m
8266 // n, c and m are constants.
8267
8268 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8269 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8270 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8271 if (AddOperandCI) {
8272 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8273 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8274 // since we truncate to TruncTy, the AddConstant should be of the
8275 // same type, so create a new Constant with type same as TruncTy.
8276 // Also, the Add constant should be shifted right by AShr amount.
8277 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8278 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8279 // we model the expression as sext(add(trunc(A), c << n)), since the
8280 // sext(trunc) part is already handled below, we create a
8281 // AddExpr(TruncExp) which will be used later.
8282 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8283 }
8284 }
8285 } else if (L && L->getOpcode() == Instruction::Shl) {
8286 // X = Shl A, n
8287 // Y = AShr X, m
8288 // Both n and m are constant.
8289
8290 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8291 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8292 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8293 }
8294
8295 if (AddTruncateExpr && ShlAmtCI) {
8296 // We can merge the two given cases into a single SCEV statement,
8297 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8298 // a simpler case. The following code handles the two cases:
8299 //
8300 // 1) For a two-shift sext-inreg, i.e. n = m,
8301 // use sext(trunc(x)) as the SCEV expression.
8302 //
8303 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8304 // expression. We already checked that ShlAmt < BitWidth, so
8305 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8306 // ShlAmt - AShrAmt < Amt.
8307 const APInt &ShlAmt = ShlAmtCI->getValue();
8308 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8309 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8310 ShlAmtCI->getZExtValue() - AShrAmt);
8311 const SCEV *CompositeExpr =
8312 getMulExpr(AddTruncateExpr, getConstant(Mul));
8313 if (L->getOpcode() != Instruction::Shl)
8314 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8315
8316 return getSignExtendExpr(CompositeExpr, OuterTy);
8317 }
8318 }
8319 break;
8320 }
8321 }
8322
8323 switch (U->getOpcode()) {
8324 case Instruction::Trunc:
8325 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8326
8327 case Instruction::ZExt:
8328 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8329
8330 case Instruction::SExt:
8331 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8333 // The NSW flag of a subtract does not always survive the conversion to
8334 // A + (-1)*B. By pushing sign extension onto its operands we are much
8335 // more likely to preserve NSW and allow later AddRec optimisations.
8336 //
8337 // NOTE: This is effectively duplicating this logic from getSignExtend:
8338 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8339 // but by that point the NSW information has potentially been lost.
8340 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8341 Type *Ty = U->getType();
8342 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8343 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8344 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8345 }
8346 }
8347 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8348
8349 case Instruction::BitCast:
8350 // BitCasts are no-op casts so we just eliminate the cast.
8351 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8352 return getSCEV(U->getOperand(0));
8353 break;
8354
8355 case Instruction::PtrToAddr: {
8356 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8357 if (isa<SCEVCouldNotCompute>(IntOp))
8358 return getUnknown(V);
8359 return IntOp;
8360 }
8361
8362 case Instruction::PtrToInt: {
8363 // Pointer to integer cast is straight-forward, so do model it.
8364 const SCEV *Op = getSCEV(U->getOperand(0));
8365 Type *DstIntTy = U->getType();
8366 // But only if effective SCEV (integer) type is wide enough to represent
8367 // all possible pointer values.
8368 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8369 if (isa<SCEVCouldNotCompute>(IntOp))
8370 return getUnknown(V);
8371 return IntOp;
8372 }
8373 case Instruction::IntToPtr:
8374 // Just don't deal with inttoptr casts.
8375 return getUnknown(V);
8376
8377 case Instruction::SDiv:
8378 // If both operands are non-negative, this is just an udiv.
8379 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8380 isKnownNonNegative(getSCEV(U->getOperand(1))))
8381 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8382 break;
8383
8384 case Instruction::SRem:
8385 // If both operands are non-negative, this is just an urem.
8386 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8387 isKnownNonNegative(getSCEV(U->getOperand(1))))
8388 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8389 break;
8390
8391 case Instruction::GetElementPtr:
8392 return createNodeForGEP(cast<GEPOperator>(U));
8393
8394 case Instruction::PHI:
8395 return createNodeForPHI(cast<PHINode>(U));
8396
8397 case Instruction::Select:
8398 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8399 U->getOperand(2));
8400
8401 case Instruction::Call:
8402 case Instruction::Invoke:
8403 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8404 return getSCEV(RV);
8405
8406 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8407 switch (II->getIntrinsicID()) {
8408 case Intrinsic::abs:
8409 return getAbsExpr(
8410 getSCEV(II->getArgOperand(0)),
8411 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8412 case Intrinsic::umax:
8413 LHS = getSCEV(II->getArgOperand(0));
8414 RHS = getSCEV(II->getArgOperand(1));
8415 return getUMaxExpr(LHS, RHS);
8416 case Intrinsic::umin:
8417 LHS = getSCEV(II->getArgOperand(0));
8418 RHS = getSCEV(II->getArgOperand(1));
8419 return getUMinExpr(LHS, RHS);
8420 case Intrinsic::smax:
8421 LHS = getSCEV(II->getArgOperand(0));
8422 RHS = getSCEV(II->getArgOperand(1));
8423 return getSMaxExpr(LHS, RHS);
8424 case Intrinsic::smin:
8425 LHS = getSCEV(II->getArgOperand(0));
8426 RHS = getSCEV(II->getArgOperand(1));
8427 return getSMinExpr(LHS, RHS);
8428 case Intrinsic::usub_sat: {
8429 const SCEV *X = getSCEV(II->getArgOperand(0));
8430 const SCEV *Y = getSCEV(II->getArgOperand(1));
8431 const SCEV *ClampedY = getUMinExpr(X, Y);
8432 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8433 }
8434 case Intrinsic::uadd_sat: {
8435 const SCEV *X = getSCEV(II->getArgOperand(0));
8436 const SCEV *Y = getSCEV(II->getArgOperand(1));
8437 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8438 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8439 }
8440 case Intrinsic::start_loop_iterations:
8441 case Intrinsic::annotation:
8442 case Intrinsic::ptr_annotation:
8443 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8444 // just eqivalent to the first operand for SCEV purposes.
8445 return getSCEV(II->getArgOperand(0));
8446 case Intrinsic::vscale:
8447 return getVScale(II->getType());
8448 default:
8449 break;
8450 }
8451 }
8452 break;
8453 }
8454
8455 return getUnknown(V);
8456}
8457
8458//===----------------------------------------------------------------------===//
8459// Iteration Count Computation Code
8460//
8461
8463 if (isa<SCEVCouldNotCompute>(ExitCount))
8464 return getCouldNotCompute();
8465
8466 auto *ExitCountType = ExitCount->getType();
8467 assert(ExitCountType->isIntegerTy());
8468 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8469 1 + ExitCountType->getScalarSizeInBits());
8470 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8471}
8472
8474 Type *EvalTy,
8475 const Loop *L) {
8476 if (isa<SCEVCouldNotCompute>(ExitCount))
8477 return getCouldNotCompute();
8478
8479 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8480 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8481
8482 auto CanAddOneWithoutOverflow = [&]() {
8483 ConstantRange ExitCountRange =
8484 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8485 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8486 return true;
8487
8488 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8489 getMinusOne(ExitCount->getType()));
8490 };
8491
8492 // If we need to zero extend the backedge count, check if we can add one to
8493 // it prior to zero extending without overflow. Provided this is safe, it
8494 // allows better simplification of the +1.
8495 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8496 return getZeroExtendExpr(
8497 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8498
8499 // Get the total trip count from the count by adding 1. This may wrap.
8500 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8501}
8502
8503static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8504 if (!ExitCount)
8505 return 0;
8506
8507 ConstantInt *ExitConst = ExitCount->getValue();
8508
8509 // Guard against huge trip counts.
8510 if (ExitConst->getValue().getActiveBits() > 32)
8511 return 0;
8512
8513 // In case of integer overflow, this returns 0, which is correct.
8514 return ((unsigned)ExitConst->getZExtValue()) + 1;
8515}
8516
8518 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8519 return getConstantTripCount(ExitCount);
8520}
8521
8522unsigned
8524 const BasicBlock *ExitingBlock) {
8525 assert(ExitingBlock && "Must pass a non-null exiting block!");
8526 assert(L->isLoopExiting(ExitingBlock) &&
8527 "Exiting block must actually branch out of the loop!");
8528 const SCEVConstant *ExitCount =
8529 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8530 return getConstantTripCount(ExitCount);
8531}
8532
8534 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8535
8536 const auto *MaxExitCount =
8537 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8539 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8540}
8541
8543 SmallVector<BasicBlock *, 8> ExitingBlocks;
8544 L->getExitingBlocks(ExitingBlocks);
8545
8546 std::optional<unsigned> Res;
8547 for (auto *ExitingBB : ExitingBlocks) {
8548 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8549 if (!Res)
8550 Res = Multiple;
8551 Res = std::gcd(*Res, Multiple);
8552 }
8553 return Res.value_or(1);
8554}
8555
8557 const SCEV *ExitCount) {
8558 if (isa<SCEVCouldNotCompute>(ExitCount))
8559 return 1;
8560
8561 // Get the trip count
8562 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8563
8564 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8565 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8566 // the greatest power of 2 divisor less than 2^32.
8567 return Multiple.getActiveBits() > 32
8568 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8569 : (unsigned)Multiple.getZExtValue();
8570}
8571
8572/// Returns the largest constant divisor of the trip count of this loop as a
8573/// normal unsigned value, if possible. This means that the actual trip count is
8574/// always a multiple of the returned value (don't forget the trip count could
8575/// very well be zero as well!).
8576///
/// Returns 1 if the trip count is unknown or not guaranteed to be a multiple
/// of a constant (which is also the case if the trip count is simply a
/// constant; use getSmallConstantTripCount for that case). If the multiple is
/// very large (>= 2^32), the largest power-of-two divisor of it that fits in
/// an unsigned value (at most 2^31) is returned instead.
8581///
8582/// As explained in the comments for getSmallConstantTripCount, this assumes
8583/// that control exits the loop via ExitingBlock.
8584unsigned
8586 const BasicBlock *ExitingBlock) {
8587 assert(ExitingBlock && "Must pass a non-null exiting block!");
8588 assert(L->isLoopExiting(ExitingBlock) &&
8589 "Exiting block must actually branch out of the loop!");
8590 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8591 return getSmallConstantTripMultiple(L, ExitCount);
8592}
8593
8595 const BasicBlock *ExitingBlock,
8596 ExitCountKind Kind) {
8597 switch (Kind) {
8598 case Exact:
8599 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8600 case SymbolicMaximum:
8601 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8602 case ConstantMaximum:
8603 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8604 };
8605 llvm_unreachable("Invalid ExitCountKind!");
8606}
8607
8609 const Loop *L, const BasicBlock *ExitingBlock,
8611 switch (Kind) {
8612 case Exact:
8613 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8614 Predicates);
8615 case SymbolicMaximum:
8616 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8617 Predicates);
8618 case ConstantMaximum:
8619 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8620 Predicates);
8621 };
8622 llvm_unreachable("Invalid ExitCountKind!");
8623}
8624
8627 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8628}
8629
8631 ExitCountKind Kind) {
8632 switch (Kind) {
8633 case Exact:
8634 return getBackedgeTakenInfo(L).getExact(L, this);
8635 case ConstantMaximum:
8636 return getBackedgeTakenInfo(L).getConstantMax(this);
8637 case SymbolicMaximum:
8638 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8639 };
8640 llvm_unreachable("Invalid ExitCountKind!");
8641}
8642
8645 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8646}
8647
8650 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8651}
8652
8654 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8655}
8656
8657/// Push PHI nodes in the header of the given loop onto the given Worklist.
8658static void PushLoopPHIs(const Loop *L,
8661 BasicBlock *Header = L->getHeader();
8662
8663 // Push all Loop-header PHIs onto the Worklist stack.
8664 for (PHINode &PN : Header->phis())
8665 if (Visited.insert(&PN).second)
8666 Worklist.push_back(&PN);
8667}
8668
8669ScalarEvolution::BackedgeTakenInfo &
8670ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8671 auto &BTI = getBackedgeTakenInfo(L);
8672 if (BTI.hasFullInfo())
8673 return BTI;
8674
8675 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8676
8677 if (!Pair.second)
8678 return Pair.first->second;
8679
8680 BackedgeTakenInfo Result =
8681 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8682
8683 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8684}
8685
8686ScalarEvolution::BackedgeTakenInfo &
8687ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8688 // Initially insert an invalid entry for this loop. If the insertion
8689 // succeeds, proceed to actually compute a backedge-taken count and
8690 // update the value. The temporary CouldNotCompute value tells SCEV
8691 // code elsewhere that it shouldn't attempt to request a new
8692 // backedge-taken count, which could result in infinite recursion.
8693 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8694 BackedgeTakenCounts.try_emplace(L);
8695 if (!Pair.second)
8696 return Pair.first->second;
8697
8698 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8699 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8700 // must be cleared in this scope.
8701 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8702
8703 // Now that we know more about the trip count for this loop, forget any
8704 // existing SCEV values for PHI nodes in this loop since they are only
8705 // conservative estimates made without the benefit of trip count
8706 // information. This invalidation is not necessary for correctness, and is
8707 // only done to produce more precise results.
8708 if (Result.hasAnyInfo()) {
8709 // Invalidate any expression using an addrec in this loop.
8710 SmallVector<SCEVUse, 8> ToForget;
8711 auto LoopUsersIt = LoopUsers.find(L);
8712 if (LoopUsersIt != LoopUsers.end())
8713 append_range(ToForget, LoopUsersIt->second);
8714 forgetMemoizedResults(ToForget);
8715
8716 // Invalidate constant-evolved loop header phis.
8717 for (PHINode &PN : L->getHeader()->phis())
8718 ConstantEvolutionLoopExitValue.erase(&PN);
8719 }
8720
8721 // Re-lookup the insert position, since the call to
8722 // computeBackedgeTakenCount above could result in a
8723 // recusive call to getBackedgeTakenInfo (on a different
8724 // loop), which would invalidate the iterator computed
8725 // earlier.
8726 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8727}
8728
8730 // This method is intended to forget all info about loops. It should
8731 // invalidate caches as if the following happened:
8732 // - The trip counts of all loops have changed arbitrarily
8733 // - Every llvm::Value has been updated in place to produce a different
8734 // result.
8735 BackedgeTakenCounts.clear();
8736 PredicatedBackedgeTakenCounts.clear();
8737 BECountUsers.clear();
8738 LoopPropertiesCache.clear();
8739 ConstantEvolutionLoopExitValue.clear();
8740 ValueExprMap.clear();
8741 ValuesAtScopes.clear();
8742 ValuesAtScopesUsers.clear();
8743 LoopDispositions.clear();
8744 BlockDispositions.clear();
8745 UnsignedRanges.clear();
8746 SignedRanges.clear();
8747 ExprValueMap.clear();
8748 HasRecMap.clear();
8749 ConstantMultipleCache.clear();
8750 PredicatedSCEVRewrites.clear();
8751 FoldCache.clear();
8752 FoldCacheUser.clear();
8753}
8754void ScalarEvolution::visitAndClearUsers(
8757 SmallVectorImpl<SCEVUse> &ToForget) {
8758 while (!Worklist.empty()) {
8759 Instruction *I = Worklist.pop_back_val();
8760 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8761 continue;
8762
8764 ValueExprMap.find_as(static_cast<Value *>(I));
8765 if (It != ValueExprMap.end()) {
8766 ToForget.push_back(It->second);
8767 eraseValueFromMap(It->first);
8768 if (PHINode *PN = dyn_cast<PHINode>(I))
8769 ConstantEvolutionLoopExitValue.erase(PN);
8770 }
8771
8772 PushDefUseChildren(I, Worklist, Visited);
8773 }
8774}
8775
8777 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8780 SmallVector<SCEVUse, 16> ToForget;
8781
8782 // Iterate over all the loops and sub-loops to drop SCEV information.
8783 while (!LoopWorklist.empty()) {
8784 auto *CurrL = LoopWorklist.pop_back_val();
8785
8786 // Drop any stored trip count value.
8787 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8788 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8789
8790 // Drop information about predicated SCEV rewrites for this loop.
8791 PredicatedSCEVRewrites.remove_if(
8792 [&](const auto &Entry) { return Entry.first.second == CurrL; });
8793
8794 auto LoopUsersItr = LoopUsers.find(CurrL);
8795 if (LoopUsersItr != LoopUsers.end())
8796 llvm::append_range(ToForget, LoopUsersItr->second);
8797
8798 // Drop information about expressions based on loop-header PHIs.
8799 PushLoopPHIs(CurrL, Worklist, Visited);
8800 visitAndClearUsers(Worklist, Visited, ToForget);
8801
8802 LoopPropertiesCache.erase(CurrL);
8803 // Forget all contained loops too, to avoid dangling entries in the
8804 // ValuesAtScopes map.
8805 LoopWorklist.append(CurrL->begin(), CurrL->end());
8806 }
8807 forgetMemoizedResults(ToForget);
8808}
8809
8811 forgetLoop(L->getOutermostLoop());
8812}
8813
8816 if (!I) return;
8817
8818 // Drop information about expressions based on loop-header PHIs.
8821 SmallVector<SCEVUse, 8> ToForget;
8822 Worklist.push_back(I);
8823 Visited.insert(I);
8824 visitAndClearUsers(Worklist, Visited, ToForget);
8825
8826 forgetMemoizedResults(ToForget);
8827}
8828
8830 if (!isSCEVable(V->getType()))
8831 return;
8832
8833 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8834 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8835 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8836 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8837 if (const SCEV *S = getExistingSCEV(V)) {
8838 struct InvalidationRootCollector {
8839 Loop *L;
8841
8842 InvalidationRootCollector(Loop *L) : L(L) {}
8843
8844 bool follow(const SCEV *S) {
8845 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8846 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8847 if (L->contains(I))
8848 Roots.push_back(S);
8849 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8850 if (L->contains(AddRec->getLoop()))
8851 Roots.push_back(S);
8852 }
8853 return true;
8854 }
8855 bool isDone() const { return false; }
8856 };
8857
8858 InvalidationRootCollector C(L);
8859 visitAll(S, C);
8860 forgetMemoizedResults(C.Roots);
8861 }
8862
8863 // Also perform the normal invalidation.
8864 forgetValue(V);
8865}
8866
8867void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8868
8870 // Unless a specific value is passed to invalidation, completely clear both
8871 // caches.
8872 if (!V) {
8873 BlockDispositions.clear();
8874 LoopDispositions.clear();
8875 return;
8876 }
8877
8878 if (!isSCEVable(V->getType()))
8879 return;
8880
8881 const SCEV *S = getExistingSCEV(V);
8882 if (!S)
8883 return;
8884
8885 // Invalidate the block and loop dispositions cached for S. Dispositions of
8886 // S's users may change if S's disposition changes (i.e. a user may change to
8887 // loop-invariant, if S changes to loop invariant), so also invalidate
8888 // dispositions of S's users recursively.
8889 SmallVector<SCEVUse, 8> Worklist = {S};
8891 while (!Worklist.empty()) {
8892 const SCEV *Curr = Worklist.pop_back_val();
8893 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8894 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8895 if (!LoopDispoRemoved && !BlockDispoRemoved)
8896 continue;
8897 auto Users = SCEVUsers.find(Curr);
8898 if (Users != SCEVUsers.end())
8899 for (const auto *User : Users->second)
8900 if (Seen.insert(User).second)
8901 Worklist.push_back(User);
8902 }
8903}
8904
8905/// Get the exact loop backedge taken count considering all loop exits. A
8906/// computable result can only be returned for loops with all exiting blocks
8907/// dominating the latch. howFarToZero assumes that the limit of each loop test
8908/// is never skipped. This is a valid assumption as long as the loop exits via
8909/// that test. For precise results, it is the caller's responsibility to specify
8910/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8911const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8912 const Loop *L, ScalarEvolution *SE,
8914 // If any exits were not computable, the loop is not computable.
8915 if (!isComplete() || ExitNotTaken.empty())
8916 return SE->getCouldNotCompute();
8917
8918 const BasicBlock *Latch = L->getLoopLatch();
8919 // All exiting blocks we have collected must dominate the only backedge.
8920 if (!Latch)
8921 return SE->getCouldNotCompute();
8922
8923 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8924 // count is simply a minimum out of all these calculated exit counts.
8926 for (const auto &ENT : ExitNotTaken) {
8927 const SCEV *BECount = ENT.ExactNotTaken;
8928 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8929 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8930 "We should only have known counts for exiting blocks that dominate "
8931 "latch!");
8932
8933 Ops.push_back(BECount);
8934
8935 if (Preds)
8936 append_range(*Preds, ENT.Predicates);
8937
8938 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8939 "Predicate should be always true!");
8940 }
8941
8942 // If an earlier exit exits on the first iteration (exit count zero), then
8943 // a later poison exit count should not propagate into the result. This are
8944 // exactly the semantics provided by umin_seq.
8945 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8946}
8947
8948const ScalarEvolution::ExitNotTakenInfo *
8949ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8950 const BasicBlock *ExitingBlock,
8951 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8952 for (const auto &ENT : ExitNotTaken)
8953 if (ENT.ExitingBlock == ExitingBlock) {
8954 if (ENT.hasAlwaysTruePredicate())
8955 return &ENT;
8956 else if (Predicates) {
8957 append_range(*Predicates, ENT.Predicates);
8958 return &ENT;
8959 }
8960 }
8961
8962 return nullptr;
8963}
8964
8965/// getConstantMax - Get the constant max backedge taken count for the loop.
8966const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8967 ScalarEvolution *SE,
8968 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8969 if (!getConstantMax())
8970 return SE->getCouldNotCompute();
8971
8972 for (const auto &ENT : ExitNotTaken)
8973 if (!ENT.hasAlwaysTruePredicate()) {
8974 if (!Predicates)
8975 return SE->getCouldNotCompute();
8976 append_range(*Predicates, ENT.Predicates);
8977 }
8978
8979 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8980 isa<SCEVConstant>(getConstantMax())) &&
8981 "No point in having a non-constant max backedge taken count!");
8982 return getConstantMax();
8983}
8984
8985const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8986 const Loop *L, ScalarEvolution *SE,
8987 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8988 if (!SymbolicMax) {
8989 // Form an expression for the maximum exit count possible for this loop. We
8990 // merge the max and exact information to approximate a version of
8991 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8992 // constants.
8993 SmallVector<SCEVUse, 4> ExitCounts;
8994
8995 for (const auto &ENT : ExitNotTaken) {
8996 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8997 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8998 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8999 "We should only have known counts for exiting blocks that "
9000 "dominate latch!");
9001 ExitCounts.push_back(ExitCount);
9002 if (Predicates)
9003 append_range(*Predicates, ENT.Predicates);
9004
9005 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9006 "Predicate should be always true!");
9007 }
9008 }
9009 if (ExitCounts.empty())
9010 SymbolicMax = SE->getCouldNotCompute();
9011 else
9012 SymbolicMax =
9013 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9014 }
9015 return SymbolicMax;
9016}
9017
9018bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9019 ScalarEvolution *SE) const {
9020 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9021 return !ENT.hasAlwaysTruePredicate();
9022 };
9023 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9024}
9025
9028
9030 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9031 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9035 // If we prove the max count is zero, so are the exact count and the symbolic
9036 // bound. This happens in practice due to differences in a) how context
9037 // sensitive we've chosen to be and b) how we reason about bounds implied by UB.
9038 if (ConstantMaxNotTaken->isZero()) {
9039 this->ExactNotTaken = E = ConstantMaxNotTaken;
9040 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9041 }
9042
9045 "Exact is not allowed to be less precise than Constant Max");
9048 "Exact is not allowed to be less precise than Symbolic Max");
9051 "Symbolic Max is not allowed to be less precise than Constant Max");
9054 "No point in having a non-constant max backedge taken count!");
// Flatten the incoming predicate lists into Predicates. Duplicates are
// skipped via SeenPreds, so each leaf predicate is kept once, in first-seen
// order.
9056 for (const auto PredList : PredLists)
9057 for (const auto *P : PredList) {
9058 if (SeenPreds.contains(P))
9059 continue;
9060 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9061 SeenPreds.insert(P);
9062 Predicates.push_back(P);
9063 }
// Exit counts are expected to be integer-typed, never pointers.
9064 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9065 "Backedge count should be int");
9067 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9068 "Max backedge count should be int");
9069}
9070
9078
9079/// Build the per-loop summary from the per-exit limits in ExitCounts, recording
9080/// one ExitNotTakenInfo (block, exact/constant-max/symbolic-max, predicates).
/// Entries are recorded in the same order as ExitCounts. IsComplete and
/// MaxOrZero are stored as given. ConstantMax must be a SCEVConstant or
/// SCEVCouldNotCompute (asserted below).
9081ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9083 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9084 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9085 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9086
9087 ExitNotTaken.reserve(ExitCounts.size());
// Convert each (exiting block, ExitLimit) pair into an ExitNotTakenInfo.
9088 std::transform(ExitCounts.begin(), ExitCounts.end(),
9089 std::back_inserter(ExitNotTaken),
9090 [&](const EdgeExitInfo &EEI) {
9091 BasicBlock *ExitBB = EEI.first;
9092 const ExitLimit &EL = EEI.second;
9093 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9094 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9095 EL.Predicates);
9096 });
9097 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9098 isa<SCEVConstant>(ConstantMax)) &&
9099 "No point in having a non-constant max backedge taken count!");
9100}
9101
9102/// Compute the number of times the backedge of the specified loop will execute.
/// Exits whose terminator branches on a constant that never selects the exit
/// edge are skipped. AllowPredicates permits per-exit limits that depend on
/// SCEV predicates.
9103ScalarEvolution::BackedgeTakenInfo
9104ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9105 bool AllowPredicates) {
9106 SmallVector<BasicBlock *, 8> ExitingBlocks;
9107 L->getExitingBlocks(ExitingBlocks);
9108
9109 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9110
9112 bool CouldComputeBECount = true;
9113 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9114 const SCEV *MustExitMaxBECount = nullptr;
9115 const SCEV *MayExitMaxBECount = nullptr;
9116 bool MustExitMaxOrZero = false;
// NOTE: this counts every exiting block, including ones skipped below as
// constant-untaken exits.
9117 bool IsOnlyExit = ExitingBlocks.size() == 1;
9118
9119 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9120 // and compute maxBECount.
9121 // Do a union of all the predicates here.
9122 for (BasicBlock *ExitBB : ExitingBlocks) {
9123 // We canonicalize untaken exits to br (constant), ignore them so that
9124 // proving an exit untaken doesn't negatively impact our ability to reason
9125 // about the loop as whole.
9126 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9127 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9128 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
// The constant selects the in-loop successor exactly when ExitIfTrue
// equals isZero().
9129 if (ExitIfTrue == CI->isZero())
9130 continue;
9131 }
9132
9133 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9134
9135 assert((AllowPredicates || EL.Predicates.empty()) &&
9136 "Predicated exit limit when predicates are not allowed!");
9137
9138 // 1. For each exit that can be computed, add an entry to ExitCounts.
9139 // CouldComputeBECount is true only if all exits can be computed.
9140 if (EL.ExactNotTaken != getCouldNotCompute())
9141 ++NumExitCountsComputed;
9142 else
9143 // We couldn't compute an exact value for this exit, so
9144 // we won't be able to compute an exact value for the loop.
9145 CouldComputeBECount = false;
9146 // Remember exit count if either exact or symbolic is known. Because
9147 // Exact always implies symbolic, only check symbolic.
9148 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9149 ExitCounts.emplace_back(ExitBB, EL);
9150 else {
9151 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9152 "Exact is known but symbolic isn't?");
9153 ++NumExitCountsNotComputed;
9154 }
9155
9156 // 2. Derive the loop's MaxBECount from each exit's max number of
9157 // non-exiting iterations. Partition the loop exits into two kinds:
9158 // LoopMustExits and LoopMayExits.
9159 //
9160 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9161 // is a LoopMayExit. If any computable LoopMustExit is found, then
9162 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9163 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9164 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9165 // any computable EL.ConstantMaxNotTaken. (Every exit that is not a
9166 // computable LoopMustExit is folded into the LoopMayExit maximum.)
9167 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9168 DT.dominates(ExitBB, Latch)) {
9169 if (!MustExitMaxBECount) {
9170 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9171 MustExitMaxOrZero = EL.MaxOrZero;
9172 } else {
9173 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9174 EL.ConstantMaxNotTaken);
9175 }
// Once the may-exit maximum is CouldNotCompute it stays that way.
9176 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9177 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9178 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9179 else {
9180 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9181 EL.ConstantMaxNotTaken);
9182 }
9183 }
9184 }
9185 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9186 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9187 // The loop backedge will be taken the maximum or zero times if there's
9188 // a single exit that must be taken the maximum or zero times.
9189 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9190
9191 // Remember which SCEVs are used in exit limits for invalidation purposes.
9192 // We only care about non-constant SCEVs here, so we can ignore
9193 // EL.ConstantMaxNotTaken and MaxBECount, which must be SCEVConstant (or
9194 // SCEVCouldNotCompute).
9195 for (const auto &Pair : ExitCounts) {
9196 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9197 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9198 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9199 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9200 {L, AllowPredicates});
9201 }
9202 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9203 MaxBECount, MaxOrZero);
9204}
9205
9206ScalarEvolution::ExitLimit
9207ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9208 bool IsOnlyExit, bool AllowPredicates) {
9209 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9210 // If our exiting block does not dominate the latch, then its connection with
9211 // loop's exit limit may be far from trivial.
9212 const BasicBlock *Latch = L->getLoopLatch();
9213 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9214 return getCouldNotCompute();
9215
9216 Instruction *Term = ExitingBlock->getTerminator();
9217 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9218 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9219 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9220 "It should have one successor in loop and one exit block!");
9221 // Proceed to the next level to examine the exit condition expression.
9222 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9223 /*ControlsOnlyExit=*/IsOnlyExit,
9224 AllowPredicates);
9225 }
9226
9227 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9228 // For switch, make sure that there is a single exit from the loop.
9229 BasicBlock *Exit = nullptr;
9230 for (auto *SBB : successors(ExitingBlock))
9231 if (!L->contains(SBB)) {
9232 if (Exit) // Multiple exit successors.
9233 return getCouldNotCompute();
9234 Exit = SBB;
9235 }
9236 assert(Exit && "Exiting block must have at least one exit");
9237 return computeExitLimitFromSingleExitSwitch(
9238 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9239 }
9240
9241 return getCouldNotCompute();
9242}
9243
9245 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9246 bool AllowPredicates) {
// Each top-level query gets a fresh cache. The loop, the exit polarity and
// the predicate permission are fixed for the cache's lifetime (its
// find/insert assert this), so only sub-conditions reached during this one
// query are memoized.
9247 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9248 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9249 ControlsOnlyExit, AllowPredicates);
9250}
9251
9252std::optional<ScalarEvolution::ExitLimit>
9253ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9254 bool ExitIfTrue, bool ControlsOnlyExit,
9255 bool AllowPredicates) {
9256 (void)this->L;
9257 (void)this->ExitIfTrue;
9258 (void)this->AllowPredicates;
9259
9260 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9261 this->AllowPredicates == AllowPredicates &&
9262 "Variance in assumed invariant key components!");
9263 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9264 if (Itr == TripCountMap.end())
9265 return std::nullopt;
9266 return Itr->second;
9267}
9268
9269void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9270 bool ExitIfTrue,
9271 bool ControlsOnlyExit,
9272 bool AllowPredicates,
9273 const ExitLimit &EL) {
9274 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9275 this->AllowPredicates == AllowPredicates &&
9276 "Variance in assumed invariant key components!");
9277
9278 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9279 assert(InsertResult.second && "Expected successful insertion!");
9280 (void)InsertResult;
9281 (void)ExitIfTrue;
9282}
9283
9284ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9285 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9286 bool ControlsOnlyExit, bool AllowPredicates) {
9287
9288 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9289 AllowPredicates))
9290 return *MaybeEL;
9291
9292 ExitLimit EL = computeExitLimitFromCondImpl(
9293 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9294 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9295 return EL;
9296}
9297
/// Compute the exit limit implied by ExitCond (uncached). The strategies are
/// tried in this order:
///   1. logical and/or decomposition;
///   2. icmp analysis, retried with SCEV predicates when they are allowed and
///      the first attempt lacks full info;
///   3. constant conditions;
///   4. the overflow bit of an x.with.overflow intrinsic with a constant RHS;
///   5. exhaustive evaluation.
9298ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9299 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9300 bool ControlsOnlyExit, bool AllowPredicates) {
9301 // Handle BinOp conditions (And, Or).
9302 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9303 Cache, L, ExitCond, ExitIfTrue, AllowPredicates))
9304 return *LimitFromBinOp;
9305
9306 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9307 // Proceed to the next level to examine the icmp.
9308 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9309 ExitLimit EL =
9310 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9311 if (EL.hasFullInfo() || !AllowPredicates)
9312 return EL;
9313
9314 // Try again, but use SCEV predicates this time.
9315 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9316 ControlsOnlyExit,
9317 /*AllowPredicates=*/true);
9318 }
9319
9320 // Check for a constant condition. These are normally stripped out by
9321 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9322 // preserve the CFG and is temporarily leaving constant conditions
9323 // in place.
9324 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9325 if (ExitIfTrue == !CI->getZExtValue())
9326 // The backedge is always taken.
9327 return getCouldNotCompute();
9328 // The backedge is never taken.
9329 return getZero(CI->getType());
9330 }
9331
9332 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9333 // with a constant step, we can form an equivalent icmp predicate and figure
9334 // out how many iterations will be taken before we exit.
9335 const WithOverflowInst *WO;
9336 const APInt *C;
9337 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9338 match(WO->getRHS(), m_APInt(C))) {
9339 ConstantRange NWR =
9341 WO->getNoWrapKind());
9342 CmpInst::Predicate Pred;
9343 APInt NewRHSC, Offset;
9344 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
// Pred holds exactly when no overflow occurs. Invert it when the loop exits
// on "no overflow", so that Pred holds while the loop keeps iterating.
9345 if (!ExitIfTrue)
9346 Pred = ICmpInst::getInversePredicate(Pred);
9347 auto *LHS = getSCEV(WO->getLHS());
9348 if (Offset != 0)
9350 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9351 ControlsOnlyExit, AllowPredicates);
9352 if (EL.hasAnyInfo())
9353 return EL;
9354 }
9355
9356 // If it's not an integer or pointer comparison then compute it the hard way.
9357 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9358}
9359
9360std::optional<ScalarEvolution::ExitLimit>
9361ScalarEvolution::computeExitLimitFromCondFromBinOp(ExitLimitCacheTy &Cache,
9362 const Loop *L,
9363 Value *ExitCond,
9364 bool ExitIfTrue,
9365 bool AllowPredicates) {
9366 // Check if the controlling expression for this loop is an And or Or.
9367 Value *Op0, *Op1;
9368 bool IsAnd;
9369 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9370 IsAnd = true;
9371 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9372 IsAnd = false;
9373 else
9374 return std::nullopt;
9375
9376 // A sub-condition of a non-trivial binop never solely controls the exit,
9377 // whether we exit always depends on both conditions.
9378 ExitLimit EL0 = computeExitLimitFromCondCached(
9379 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9380 ExitLimit EL1 = computeExitLimitFromCondCached(
9381 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9382
9383 // EitherMayExit is true in these two cases:
9384 // br (and Op0 Op1), loop, exit
9385 // br (or Op0 Op1), exit, loop
9386 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9387
9388 const SCEV *BECount = getCouldNotCompute();
9389 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9390 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9391 if (EitherMayExit) {
9392 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9393 // Both conditions must be same for the loop to continue executing.
9394 // Choose the less conservative count.
9395 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9396 EL1.ExactNotTaken != getCouldNotCompute()) {
9397 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9398 UseSequentialUMin);
9399 }
9400 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9401 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9402 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9403 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9404 else
9405 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9406 EL1.ConstantMaxNotTaken);
9407 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9408 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9409 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9410 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9411 else
9412 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9413 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9414 } else {
9415 // Both conditions must be same at the same time for the loop to exit.
9416 // For now, be conservative.
9417 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9418 BECount = EL0.ExactNotTaken;
9419 }
9420
9421 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9422 // to be more aggressive when computing BECount than when computing
9423 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9424 // and
9425 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9426 // EL1.ConstantMaxNotTaken to not.
9427 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9428 !isa<SCEVCouldNotCompute>(BECount))
9429 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9430 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9431 SymbolicMaxBECount =
9432 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9433 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9434 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9435}
9436
9437ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9438 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9439 bool AllowPredicates) {
9440 // If the condition was exit on true, convert the condition to exit on false
9441 CmpPredicate Pred;
9442 if (!ExitIfTrue)
9443 Pred = ExitCond->getCmpPredicate();
9444 else
9445 Pred = ExitCond->getInverseCmpPredicate();
9446 const ICmpInst::Predicate OriginalPred = Pred;
9447
9448 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9449 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9450
9451 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9452 AllowPredicates);
9453 if (EL.hasAnyInfo())
9454 return EL;
9455
9456 auto *ExhaustiveCount =
9457 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9458
9459 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9460 return ExhaustiveCount;
9461
9462 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9463 ExitCond->getOperand(1), L, OriginalPred);
9464}
/// Compute the exit limit for a loop that keeps iterating while
/// "LHS Pred RHS" holds. The callers in this file pass Pred already
/// normalized that way. As a side effect, may strengthen the no-wrap flags of
/// an affine AddRec in LHS when this test controls the only exit of a loop
/// with no abnormal exits. Returns SCEVCouldNotCompute if no case applies.
9465ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9466 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9467 bool ControlsOnlyExit, bool AllowPredicates) {
9468
9469 // Try to evaluate any dependencies out of the loop.
9470 LHS = getSCEVAtScope(LHS, L);
9471 RHS = getSCEVAtScope(RHS, L);
9472
9473 // At this point, we would like to compute how many iterations of the
9474 // loop the predicate will return true for these inputs.
9475 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9476 // If there is a loop-invariant, force it into the RHS.
9477 std::swap(LHS, RHS);
9479 }
9480
9481 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9483 // Simplify the operands before analyzing them.
9484 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9485
9486 // If we have a comparison of a chrec against a constant, try to use value
9487 // ranges to answer this query.
9488 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9489 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9490 if (AddRec->getLoop() == L) {
9491 // Form the constant range.
9492 ConstantRange CompRange =
9493 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9494
9495 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9496 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9497 }
9498
9499 // If this loop must exit based on this condition (or execute undefined
9500 // behaviour), see if we can improve wrap flags. This is essentially
9501 // a must execute style proof.
9502 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9503 // If we can prove the test sequence produced must repeat the same values
9504 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9505 // because if it did, we'd have an infinite (undefined) loop.
9506 // TODO: We can peel off any functions which are invertible *in L*. Loop
9507 // invariant terms are effectively constants for our purposes here.
9508 SCEVUse InnerLHS = LHS;
9509 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9510 InnerLHS = ZExt->getOperand();
9511 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9512 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9513 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9514 /*OrNegative=*/true)) {
9515 auto Flags = AR->getNoWrapFlags();
9516 Flags = setFlags(Flags, SCEV::FlagNW);
9517 SmallVector<SCEVUse> Operands{AR->operands()};
9518 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9519 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9520 }
9521
9522 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9523 // From no-self-wrap, this follows trivially from the fact that every
9524 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9525 // last value before (un)signed wrap. Since we know that last value
9526 // didn't exit, nor will any smaller one.
9527 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9528 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9529 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9530 AR && AR->getLoop() == L && AR->isAffine() &&
9531 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9532 isKnownPositive(AR->getStepRecurrence(*this))) {
9533 auto Flags = AR->getNoWrapFlags();
9534 Flags = setFlags(Flags, WrapType);
9535 SmallVector<SCEVUse> Operands{AR->operands()};
9536 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9537 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9538 }
9539 }
9540 }
9541
// Dispatch on the (possibly simplified) predicate that keeps the loop running.
9542 switch (Pred) {
9543 case ICmpInst::ICMP_NE: { // while (X != Y)
9544 // Convert to: while (X-Y != 0)
9545 if (LHS->getType()->isPointerTy()) {
9548 return LHS;
9549 }
9550 if (RHS->getType()->isPointerTy()) {
9553 return RHS;
9554 }
9555 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9556 AllowPredicates);
9557 if (EL.hasAnyInfo())
9558 return EL;
9559 break;
9560 }
9561 case ICmpInst::ICMP_EQ: { // while (X == Y)
9562 // Convert to: while (X-Y == 0)
9563 if (LHS->getType()->isPointerTy()) {
9566 return LHS;
9567 }
9568 if (RHS->getType()->isPointerTy()) {
9571 return RHS;
9572 }
9573 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9574 if (EL.hasAnyInfo()) return EL;
9575 break;
9576 }
9577 case ICmpInst::ICMP_SLE:
9578 case ICmpInst::ICMP_ULE:
9579 // Since the loop is finite, an invariant RHS cannot include the boundary
9580 // value, otherwise it would loop forever.
9581 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9582 !isLoopInvariant(RHS, L)) {
9583 // Otherwise, perform the addition in a wider type, to avoid overflow.
9584 // If the LHS is an addrec with the appropriate nowrap flag, the
9585 // extension will be sunk into it and the exit count can be analyzed.
9586 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9587 if (!OldType)
9588 break;
9589 // Prefer doubling the bitwidth over adding a single bit to make it more
9590 // likely that we use a legal type.
9591 auto *NewType =
9592 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9593 if (ICmpInst::isSigned(Pred)) {
9594 LHS = getSignExtendExpr(LHS, NewType);
9595 RHS = getSignExtendExpr(RHS, NewType);
9596 } else {
9597 LHS = getZeroExtendExpr(LHS, NewType);
9598 RHS = getZeroExtendExpr(RHS, NewType);
9599 }
9600 }
9602 [[fallthrough]];
9603 case ICmpInst::ICMP_SLT:
9604 case ICmpInst::ICMP_ULT: { // while (X < Y)
9605 bool IsSigned = ICmpInst::isSigned(Pred);
9606 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9607 AllowPredicates);
9608 if (EL.hasAnyInfo())
9609 return EL;
9610 break;
9611 }
9612 case ICmpInst::ICMP_SGE:
9613 case ICmpInst::ICMP_UGE:
9614 // Since the loop is finite, an invariant RHS cannot include the boundary
9615 // value, otherwise it would loop forever.
9616 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9617 !isLoopInvariant(RHS, L))
9618 break;
9620 [[fallthrough]];
9621 case ICmpInst::ICMP_SGT:
9622 case ICmpInst::ICMP_UGT: { // while (X > Y)
9623 bool IsSigned = ICmpInst::isSigned(Pred);
9624 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9625 AllowPredicates);
9626 if (EL.hasAnyInfo())
9627 return EL;
9628 break;
9629 }
9630 default:
9631 break;
9632 }
9633
9634 return getCouldNotCompute();
9635}
9636
9637ScalarEvolution::ExitLimit
9638ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9639 SwitchInst *Switch,
9640 BasicBlock *ExitingBlock,
9641 bool ControlsOnlyExit) {
9642 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9643
9644 // Give up if the exit is the default dest of a switch.
9645 if (Switch->getDefaultDest() == ExitingBlock)
9646 return getCouldNotCompute();
9647
9648 assert(L->contains(Switch->getDefaultDest()) &&
9649 "Default case must not exit the loop!");
9650 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9651 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9652
9653 // while (X != Y) --> while (X-Y != 0)
9654 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9655 if (EL.hasAnyInfo())
9656 return EL;
9657
9658 return getCouldNotCompute();
9659}
9660
/// Evaluate AddRec at the constant iteration number C and return the folded
/// value. The result is asserted to fold to a SCEVConstant.
9661static ConstantInt *
9663 ScalarEvolution &SE) {
9664 const SCEV *InVal = SE.getConstant(C);
9665 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9667 "Evaluation of SCEV at constant didn't fold correctly?");
9668 return cast<SCEVConstant>(Val)->getValue();
9669}
9670
/// Bound the backedge-taken count of a loop whose exit test compares a
/// "shift recurrence" (a header PHI repeatedly shifted by a positive constant)
/// against the constant RHSV. Pred is the predicate that keeps the loop
/// running. The recurrence reaches a fixed value (0, or -1 for ashr of a
/// negative start) in a finite number of steps. If Pred is false for that
/// value, only a max bound is returned; the exact count stays
/// SCEVCouldNotCompute.
9671ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9672 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9673 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9674 if (!RHS)
9675 return getCouldNotCompute();
9676
9677 const BasicBlock *Latch = L->getLoopLatch();
9678 if (!Latch)
9679 return getCouldNotCompute();
9680
9681 const BasicBlock *Predecessor = L->getLoopPredecessor();
9682 if (!Predecessor)
9683 return getCouldNotCompute();
9684
9685 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9686 // Return LHS in OutLHS and shift_op in OutOpCode.
9687 auto MatchPositiveShift =
9688 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9689
9690 using namespace PatternMatch;
9691
9692 ConstantInt *ShiftAmt;
9693 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9694 OutOpCode = Instruction::LShr;
9695 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9696 OutOpCode = Instruction::AShr;
9697 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9698 OutOpCode = Instruction::Shl;
9699 else
9700 return false;
9701
9702 return ShiftAmt->getValue().isStrictlyPositive();
9703 };
9704
9705 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9706 //
9707 // loop:
9708 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9709 // %iv.shifted = lshr i32 %iv, <positive constant>
9710 //
9711 // Return true on a successful match. Return the corresponding PHI node (%iv
9712 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
// NOTE(review): the lambda below never reads its parameter V. The inner
// `Value *V` shadows it, and the body inspects and rewrites the captured LHS
// instead. Its only call passes LHS, so the match result is the same, but
// LHS is left pointing at the peeled operand afterwards.
9713 auto MatchShiftRecurrence =
9714 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9715 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9716
9717 {
9719 Value *V;
9720
9721 // If we encounter a shift instruction, "peel off" the shift operation,
9722 // and remember that we did so. Later when we inspect %iv's backedge
9723 // value, we will make sure that the backedge value uses the same
9724 // operation.
9725 //
9726 // Note: the peeled shift operation does not have to be the same
9727 // instruction as the one feeding into the PHI's backedge value. We only
9728 // really care about it being the same *kind* of shift instruction --
9729 // that's all that is required for our later inferences to hold.
9730 if (MatchPositiveShift(LHS, V, OpC)) {
9731 PostShiftOpCode = OpC;
9732 LHS = V;
9733 }
9734 }
9735
9736 PNOut = dyn_cast<PHINode>(LHS);
9737 if (!PNOut || PNOut->getParent() != L->getHeader())
9738 return false;
9739
9740 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9741 Value *OpLHS;
9742
9743 return
9744 // The backedge value for the PHI node must be a shift by a positive
9745 // amount
9746 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9747
9748 // of the PHI node itself
9749 OpLHS == PNOut &&
9750
9751 // and the kind of shift should match the kind of shift we peeled
9752 // off, if any.
9753 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9754 };
9755
9756 PHINode *PN;
9758 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9759 return getCouldNotCompute();
9760
9761 const DataLayout &DL = getDataLayout();
9762
9763 // The key rationale for this optimization is that for some kinds of shift
9764 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9765 // within a finite number of iterations. If the condition guarding the
9766 // backedge (in the sense that the backedge is taken if the condition is true)
9767 // is false for the value the shift recurrence stabilizes to, then we know
9768 // that the backedge is taken only a finite number of times.
9769
9770 ConstantInt *StableValue = nullptr;
9771 switch (OpCode) {
9772 default:
9773 llvm_unreachable("Impossible case!");
9774
9775 case Instruction::AShr: {
9776 // {K,ashr,<positive-constant>} stabilizes to 0 if K is non-negative and to
9777 // -1 if K is negative, in at most bitwidth(K) iterations.
// If the sign of the start value is unknown, so is the stable value.
9778 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9779 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9780 Predecessor->getTerminator(), &DT);
9781 auto *Ty = cast<IntegerType>(RHS->getType());
9782 if (Known.isNonNegative())
9783 StableValue = ConstantInt::get(Ty, 0);
9784 else if (Known.isNegative())
9785 StableValue = ConstantInt::get(Ty, -1, true);
9786 else
9787 return getCouldNotCompute();
9788
9789 break;
9790 }
9791 case Instruction::LShr:
9792 case Instruction::Shl:
9793 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9794 // stabilize to 0 in at most bitwidth(K) iterations.
9795 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9796 break;
9797 }
9798
// Evaluate the loop-continuation predicate at the stable value. A false result
// means the loop must eventually exit.
9799 auto *Result =
9800 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9801 assert(Result->getType()->isIntegerTy(1) &&
9802 "Otherwise cannot be an operand to a branch instruction");
9803
9804 if (Result->isNullValue()) {
9805 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9806 const SCEV *UpperBound =
9808 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9809 }
9810
9811 return getCouldNotCompute();
9812}
9813
9814/// Return true if we can constant fold an instruction of the specified type,
9815/// assuming that all operands were constants.
9816static bool CanConstantFold(const Instruction *I) {
9820 return true;
9821
9822 if (const CallInst *CI = dyn_cast<CallInst>(I))
9823 if (const Function *F = CI->getCalledFunction())
9824 return canConstantFoldCallTo(CI, F);
9825 return false;
9826}
9827
9828/// Determine whether this instruction can constant evolve within this loop
9829/// assuming its operands can all constant evolve.
9830static bool canConstantEvolve(Instruction *I, const Loop *L) {
9831 // An instruction outside of the loop can't be derived from a loop PHI.
9832 if (!L->contains(I)) return false;
9833
9834 if (isa<PHINode>(I)) {
9835 // We don't currently keep track of the control flow needed to evaluate
9836 // PHIs, so we cannot handle PHIs inside of loops.
9837 return L->getHeader() == I->getParent();
9838 }
9839
9840 // If we won't be able to constant fold this expression even if the operands
9841 // are constants, bail early.
9842 return CanConstantFold(I);
9843}
9844
9845/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9846/// recursing through each instruction operand until reaching a loop header phi.
9847static PHINode *
9850 unsigned Depth) {
9852 return nullptr;
9853
9854 // Otherwise, we can evaluate this instruction if all of its operands are
9855 // constant or derived from a PHI node themselves.
9856 PHINode *PHI = nullptr;
9857 for (Value *Op : UseInst->operands()) {
9858 if (isa<Constant>(Op)) continue;
9859
9861 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9862
9863 PHINode *P = dyn_cast<PHINode>(OpInst);
9864 if (!P)
9865 // If this operand is already visited, reuse the prior result.
9866 // We may have P != PHI if this is the deepest point at which the
9867 // inconsistent paths meet.
9868 P = PHIMap.lookup(OpInst);
9869 if (!P) {
9870 // Recurse and memoize the results, whether a phi is found or not.
9871 // This recursive call invalidates pointers into PHIMap.
9872 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9873 PHIMap[OpInst] = P;
9874 }
9875 if (!P)
9876 return nullptr; // Not evolving from PHI
9877 if (PHI && PHI != P)
9878 return nullptr; // Evolving from multiple different PHIs.
9879 PHI = P;
9880 }
9881 // This is a expression evolving from a constant PHI!
9882 return PHI;
9883}
9884
9885/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9886/// in the loop that V is derived from. We allow arbitrary operations along the
9887/// way, but the operands of an operation must either be constants or a value
9888/// derived from a constant PHI. If this expression does not fit with these
9889/// constraints, return null.
9892 if (!I || !canConstantEvolve(I, L)) return nullptr;
9893
9894 if (PHINode *PN = dyn_cast<PHINode>(I))
9895 return PN;
9896
9897 // Record non-constant instructions contained by the loop.
9899 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9900}
9901
9902/// EvaluateExpression - Given an expression that passes the
9903/// getConstantEvolvingPHI predicate, evaluate its value using the constants
9904/// already recorded in Vals (e.g. for the loop's header PHIs). If we can't
9905/// fold this expression for some reason, return null.
9908 const DataLayout &DL,
9909 const TargetLibraryInfo *TLI) {
9910 // Convenient constant check, but redundant for recursive calls.
9911 if (Constant *C = dyn_cast<Constant>(V)) return C;
9913 if (!I) return nullptr;
9914
9915 if (Constant *C = Vals.lookup(I)) return C;
9916
9917 // An instruction inside the loop depends on a value outside the loop that we
9918 // weren't given a mapping for, or a value such as a call inside the loop.
9919 if (!canConstantEvolve(I, L)) return nullptr;
9920
9921 // An unmapped PHI can be due to a branch or another loop inside this loop,
9922 // or due to this not being the initial iteration through a loop where we
9923 // couldn't compute the evolution of this particular PHI last time.
9924 if (isa<PHINode>(I)) return nullptr;
9925
9926 std::vector<Constant*> Operands(I->getNumOperands());
9927
9928 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9929 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9930 if (!Operand) {
9931 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9932 if (!Operands[i]) return nullptr;
9933 continue;
9934 }
9935 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
9936 Vals[Operand] = C;
9937 if (!C) return nullptr;
9938 Operands[i] = C;
9939 }
9940
9941 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9942 /*AllowNonDeterministic=*/false);
9943}
9944
9945
9946// If every incoming value to PN except the one for BB is a specific Constant,
9947// return that, else return nullptr.
9949 Constant *IncomingVal = nullptr;
9950
9951 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
9952 if (PN->getIncomingBlock(i) == BB)
9953 continue;
9954
9955 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9956 if (!CurrentVal)
9957 return nullptr;
9958
9959 if (IncomingVal != CurrentVal) {
9960 if (IncomingVal)
9961 return nullptr;
9962 IncomingVal = CurrentVal;
9963 }
9964 }
9965
9966 return IncomingVal;
9967}
9968
9969/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9970/// in the header of its containing loop, we know the loop executes a
9971/// constant number of times, and the PHI node is just a recurrence
9972/// involving constants, fold it.
9973Constant *
9974ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9975 const APInt &BEs,
9976 const Loop *L) {
9977 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9978 if (!Inserted)
9979 return I->second;
9980
9982 return nullptr; // Not going to evaluate it.
9983
9984 Constant *&RetVal = I->second;
9985
9986 DenseMap<Instruction *, Constant *> CurrentIterVals;
9987 BasicBlock *Header = L->getHeader();
9988 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9989
9990 BasicBlock *Latch = L->getLoopLatch();
9991 if (!Latch)
9992 return nullptr;
9993
9994 for (PHINode &PHI : Header->phis()) {
9995 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
9996 CurrentIterVals[&PHI] = StartCST;
9997 }
9998 if (!CurrentIterVals.count(PN))
9999 return RetVal = nullptr;
10000
10001 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10002
10003 // Execute the loop symbolically to determine the exit value.
10004 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10005 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10006
10007 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10008 unsigned IterationNum = 0;
10009 const DataLayout &DL = getDataLayout();
10010 for (; ; ++IterationNum) {
10011 if (IterationNum == NumIterations)
10012 return RetVal = CurrentIterVals[PN]; // Got exit value!
10013
10014 // Compute the value of the PHIs for the next iteration.
10015 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10016 DenseMap<Instruction *, Constant *> NextIterVals;
10017 Constant *NextPHI =
10018 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10019 if (!NextPHI)
10020 return nullptr; // Couldn't evaluate!
10021 NextIterVals[PN] = NextPHI;
10022
10023 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10024
10025 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10026 // cease to be able to evaluate one of them or if they stop evolving,
10027 // because that doesn't necessarily prevent us from computing PN.
10029 for (const auto &I : CurrentIterVals) {
10030 PHINode *PHI = dyn_cast<PHINode>(I.first);
10031 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10032 PHIsToCompute.emplace_back(PHI, I.second);
10033 }
10034 // We use two distinct loops because EvaluateExpression may invalidate any
10035 // iterators into CurrentIterVals.
10036 for (const auto &I : PHIsToCompute) {
10037 PHINode *PHI = I.first;
10038 Constant *&NextPHI = NextIterVals[PHI];
10039 if (!NextPHI) { // Not already computed.
10040 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10041 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10042 }
10043 if (NextPHI != I.second)
10044 StoppedEvolving = false;
10045 }
10046
10047 // If all entries in CurrentIterVals == NextIterVals then we can stop
10048 // iterating, the loop can't continue to change.
10049 if (StoppedEvolving)
10050 return RetVal = CurrentIterVals[PN];
10051
10052 CurrentIterVals.swap(NextIterVals);
10053 }
10054}
10055
10056const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10057 Value *Cond,
10058 bool ExitWhen) {
10059 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10060 if (!PN) return getCouldNotCompute();
10061
10062 // If the loop is canonicalized, the PHI will have exactly two entries.
10063 // That's the only form we support here.
10064 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10065
10066 DenseMap<Instruction *, Constant *> CurrentIterVals;
10067 BasicBlock *Header = L->getHeader();
10068 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10069
10070 BasicBlock *Latch = L->getLoopLatch();
10071 assert(Latch && "Should follow from NumIncomingValues == 2!");
10072
10073 for (PHINode &PHI : Header->phis()) {
10074 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10075 CurrentIterVals[&PHI] = StartCST;
10076 }
10077 if (!CurrentIterVals.count(PN))
10078 return getCouldNotCompute();
10079
10080 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10081 // the loop symbolically to determine when the condition gets a value of
10082 // "ExitWhen".
10083 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10084 const DataLayout &DL = getDataLayout();
10085 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10086 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10087 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10088
10089 // Couldn't symbolically evaluate.
10090 if (!CondVal) return getCouldNotCompute();
10091
10092 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10093 ++NumBruteForceTripCountsComputed;
10094 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10095 }
10096
10097 // Update all the PHI nodes for the next iteration.
10098 DenseMap<Instruction *, Constant *> NextIterVals;
10099
10100 // Create a list of which PHIs we need to compute. We want to do this before
10101 // calling EvaluateExpression on them because that may invalidate iterators
10102 // into CurrentIterVals.
10103 SmallVector<PHINode *, 8> PHIsToCompute;
10104 for (const auto &I : CurrentIterVals) {
10105 PHINode *PHI = dyn_cast<PHINode>(I.first);
10106 if (!PHI || PHI->getParent() != Header) continue;
10107 PHIsToCompute.push_back(PHI);
10108 }
10109 for (PHINode *PHI : PHIsToCompute) {
10110 Constant *&NextPHI = NextIterVals[PHI];
10111 if (NextPHI) continue; // Already computed!
10112
10113 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10114 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10115 }
10116 CurrentIterVals.swap(NextIterVals);
10117 }
10118
10119 // Too many iterations were needed to evaluate.
10120 return getCouldNotCompute();
10121}
10122
10123const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
10125 ValuesAtScopes[V];
10126 // Check to see if we've folded this expression at this loop before.
10127 for (auto &LS : Values)
10128 if (LS.first == L)
10129 return LS.second ? LS.second : V;
10130
10131 Values.emplace_back(L, nullptr);
10132
10133 // Otherwise compute it.
10134 const SCEV *C = computeSCEVAtScope(V, L);
10135 for (auto &LS : reverse(ValuesAtScopes[V]))
10136 if (LS.first == L) {
10137 LS.second = C;
10138 if (!isa<SCEVConstant>(C))
10139 ValuesAtScopesUsers[C].push_back({L, V});
10140 break;
10141 }
10142 return C;
10143}
10144
10145/// This builds up a Constant using the ConstantExpr interface. That way, we
10146/// will return Constants for objects which aren't represented by a
10147/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10148/// Returns NULL if the SCEV isn't representable as a Constant.
10150 switch (V->getSCEVType()) {
10151 case scCouldNotCompute:
10152 case scAddRecExpr:
10153 case scVScale:
10154 return nullptr;
10155 case scConstant:
10156 return cast<SCEVConstant>(V)->getValue();
10157 case scUnknown:
10159 case scPtrToAddr: {
10161 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10162 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10163
10164 return nullptr;
10165 }
10166 case scPtrToInt: {
10168 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10169 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10170
10171 return nullptr;
10172 }
10173 case scTruncate: {
10175 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10176 return ConstantExpr::getTrunc(CastOp, ST->getType());
10177 return nullptr;
10178 }
10179 case scAddExpr: {
10180 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10181 Constant *C = nullptr;
10182 for (const SCEV *Op : SA->operands()) {
10184 if (!OpC)
10185 return nullptr;
10186 if (!C) {
10187 C = OpC;
10188 continue;
10189 }
10190 assert(!C->getType()->isPointerTy() &&
10191 "Can only have one pointer, and it must be last");
10192 if (OpC->getType()->isPointerTy()) {
10193 // The offsets have been converted to bytes. We can add bytes using
10194 // an i8 GEP.
10195 C = ConstantExpr::getPtrAdd(OpC, C);
10196 } else {
10197 C = ConstantExpr::getAdd(C, OpC);
10198 }
10199 }
10200 return C;
10201 }
10202 case scMulExpr:
10203 case scSignExtend:
10204 case scZeroExtend:
10205 case scUDivExpr:
10206 case scSMaxExpr:
10207 case scUMaxExpr:
10208 case scSMinExpr:
10209 case scUMinExpr:
10211 return nullptr;
10212 }
10213 llvm_unreachable("Unknown SCEV kind!");
10214}
10215
10216const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10217 SmallVectorImpl<SCEVUse> &NewOps) {
10218 switch (S->getSCEVType()) {
10219 case scTruncate:
10220 case scZeroExtend:
10221 case scSignExtend:
10222 case scPtrToAddr:
10223 case scPtrToInt:
10224 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10225 case scAddRecExpr: {
10226 auto *AddRec = cast<SCEVAddRecExpr>(S);
10227 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10228 }
10229 case scAddExpr:
10230 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10231 case scMulExpr:
10232 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10233 case scUDivExpr:
10234 return getUDivExpr(NewOps[0], NewOps[1]);
10235 case scUMaxExpr:
10236 case scSMaxExpr:
10237 case scUMinExpr:
10238 case scSMinExpr:
10239 return getMinMaxExpr(S->getSCEVType(), NewOps);
10241 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
10242 case scConstant:
10243 case scVScale:
10244 case scUnknown:
10245 return S;
10246 case scCouldNotCompute:
10247 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10248 }
10249 llvm_unreachable("Unknown SCEV kind!");
10250}
10251
10252const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10253 switch (V->getSCEVType()) {
10254 case scConstant:
10255 case scVScale:
10256 return V;
10257 case scAddRecExpr: {
10258 // If this is a loop recurrence for a loop that does not contain L, then we
10259 // are dealing with the final value computed by the loop.
10260 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10261 // First, attempt to evaluate each operand.
10262 // Avoid performing the look-up in the common case where the specified
10263 // expression has no loop-variant portions.
10264 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10265 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10266 if (OpAtScope == AddRec->getOperand(i))
10267 continue;
10268
10269 // Okay, at least one of these operands is loop variant but might be
10270 // foldable. Build a new instance of the folded commutative expression.
10272 NewOps.reserve(AddRec->getNumOperands());
10273 append_range(NewOps, AddRec->operands().take_front(i));
10274 NewOps.push_back(OpAtScope);
10275 for (++i; i != e; ++i)
10276 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10277
10278 const SCEV *FoldedRec = getAddRecExpr(
10279 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10280 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10281 // The addrec may be folded to a nonrecurrence, for example, if the
10282 // induction variable is multiplied by zero after constant folding. Go
10283 // ahead and return the folded value.
10284 if (!AddRec)
10285 return FoldedRec;
10286 break;
10287 }
10288
10289 // If the scope is outside the addrec's loop, evaluate it by using the
10290 // loop exit value of the addrec.
10291 if (!AddRec->getLoop()->contains(L)) {
10292 // To evaluate this recurrence, we need to know how many times the AddRec
10293 // loop iterates. Compute this now.
10294 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10295 if (BackedgeTakenCount == getCouldNotCompute())
10296 return AddRec;
10297
10298 // Then, evaluate the AddRec.
10299 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10300 }
10301
10302 return AddRec;
10303 }
10304 case scTruncate:
10305 case scZeroExtend:
10306 case scSignExtend:
10307 case scPtrToAddr:
10308 case scPtrToInt:
10309 case scAddExpr:
10310 case scMulExpr:
10311 case scUDivExpr:
10312 case scUMaxExpr:
10313 case scSMaxExpr:
10314 case scUMinExpr:
10315 case scSMinExpr:
10316 case scSequentialUMinExpr: {
10317 ArrayRef<SCEVUse> Ops = V->operands();
10318 // Avoid performing the look-up in the common case where the specified
10319 // expression has no loop-variant portions.
10320 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10321 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10322 if (OpAtScope != Ops[i].getPointer()) {
10323 // Okay, at least one of these operands is loop variant but might be
10324 // foldable. Build a new instance of the folded commutative expression.
10326 NewOps.reserve(Ops.size());
10327 append_range(NewOps, Ops.take_front(i));
10328 NewOps.push_back(OpAtScope);
10329
10330 for (++i; i != e; ++i) {
10331 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10332 NewOps.push_back(OpAtScope);
10333 }
10334
10335 return getWithOperands(V, NewOps);
10336 }
10337 }
10338 // If we got here, all operands are loop invariant.
10339 return V;
10340 }
10341 case scUnknown: {
10342 // If this instruction is evolved from a constant-evolving PHI, compute the
10343 // exit value from the loop without using SCEVs.
10344 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
10346 if (!I)
10347 return V; // This is some other type of SCEVUnknown, just return it.
10348
10349 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10350 const Loop *CurrLoop = this->LI[I->getParent()];
10351 // Looking for loop exit value.
10352 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10353 PN->getParent() == CurrLoop->getHeader()) {
10354 // Okay, there is no closed form solution for the PHI node. Check
10355 // to see if the loop that contains it has a known backedge-taken
10356 // count. If so, we may be able to force computation of the exit
10357 // value.
10358 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10359 // This trivial case can show up in some degenerate cases where
10360 // the incoming IR has not yet been fully simplified.
10361 if (BackedgeTakenCount->isZero()) {
10362 Value *InitValue = nullptr;
10363 bool MultipleInitValues = false;
10364 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10365 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10366 if (!InitValue)
10367 InitValue = PN->getIncomingValue(i);
10368 else if (InitValue != PN->getIncomingValue(i)) {
10369 MultipleInitValues = true;
10370 break;
10371 }
10372 }
10373 }
10374 if (!MultipleInitValues && InitValue)
10375 return getSCEV(InitValue);
10376 }
10377 // Do we have a loop invariant value flowing around the backedge
10378 // for a loop which must execute the backedge?
10379 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10380 isKnownNonZero(BackedgeTakenCount) &&
10381 PN->getNumIncomingValues() == 2) {
10382
10383 unsigned InLoopPred =
10384 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10385 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10386 if (CurrLoop->isLoopInvariant(BackedgeVal))
10387 return getSCEV(BackedgeVal);
10388 }
10389 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10390 // Okay, we know how many times the containing loop executes. If
10391 // this is a constant evolving PHI node, get the final value at
10392 // the specified iteration number.
10393 Constant *RV =
10394 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10395 if (RV)
10396 return getSCEV(RV);
10397 }
10398 }
10399 }
10400
10401 // Okay, this is an expression that we cannot symbolically evaluate
10402 // into a SCEV. Check to see if it's possible to symbolically evaluate
10403 // the arguments into constants, and if so, try to constant propagate the
10404 // result. This is particularly useful for computing loop exit values.
10405 if (!CanConstantFold(I))
10406 return V; // This is some other type of SCEVUnknown, just return it.
10407
10408 SmallVector<Constant *, 4> Operands;
10409 Operands.reserve(I->getNumOperands());
10410 bool MadeImprovement = false;
10411 for (Value *Op : I->operands()) {
10412 if (Constant *C = dyn_cast<Constant>(Op)) {
10413 Operands.push_back(C);
10414 continue;
10415 }
10416
10417 // If any of the operands is non-constant and if they are
10418 // non-integer and non-pointer, don't even try to analyze them
10419 // with scev techniques.
10420 if (!isSCEVable(Op->getType()))
10421 return V;
10422
10423 const SCEV *OrigV = getSCEV(Op);
10424 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10425 MadeImprovement |= OrigV != OpV;
10426
10428 if (!C)
10429 return V;
10430 assert(C->getType() == Op->getType() && "Type mismatch");
10431 Operands.push_back(C);
10432 }
10433
10434 // Check to see if getSCEVAtScope actually made an improvement.
10435 if (!MadeImprovement)
10436 return V; // This is some other type of SCEVUnknown, just return it.
10437
10438 Constant *C = nullptr;
10439 const DataLayout &DL = getDataLayout();
10440 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10441 /*AllowNonDeterministic=*/false);
10442 if (!C)
10443 return V;
10444 return getSCEV(C);
10445 }
10446 case scCouldNotCompute:
10447 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10448 }
10449 llvm_unreachable("Unknown SCEV type!");
10450}
10451
10453 return getSCEVAtScope(getSCEV(V), L);
10454}
10455
10456const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
10458 return stripInjectiveFunctions(ZExt->getOperand());
10460 return stripInjectiveFunctions(SExt->getOperand());
10461 return S;
10462}
10463
10464/// Finds the minimum unsigned root of the following equation:
10465///
10466/// A * X = B (mod N)
10467///
10468/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10469/// A and B isn't important.
10470///
10471/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10472static const SCEV *
10475 ScalarEvolution &SE, const Loop *L) {
10476 uint32_t BW = A.getBitWidth();
10477 assert(BW == SE.getTypeSizeInBits(B->getType()));
10478 assert(A != 0 && "A must be non-zero.");
10479
10480 // 1. D = gcd(A, N)
10481 //
10482 // The gcd of A and N may have only one prime factor: 2. The number of
10483 // trailing zeros in A is its multiplicity
10484 uint32_t Mult2 = A.countr_zero();
10485 // D = 2^Mult2
10486
10487 // 2. Check if B is divisible by D.
10488 //
10489 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10490 // is not less than multiplicity of this prime factor for D.
10491 unsigned MinTZ = SE.getMinTrailingZeros(B);
10492 // Try again with the terminator of the loop predecessor for context-specific
10493 // result, if MinTZ s too small.
10494 if (MinTZ < Mult2 && L->getLoopPredecessor())
10495 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10496 if (MinTZ < Mult2) {
10497 // Check if we can prove there's no remainder using URem.
10498 const SCEV *URem =
10499 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10500 const SCEV *Zero = SE.getZero(B->getType());
10501 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10502 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10503 if (!Predicates)
10504 return SE.getCouldNotCompute();
10505
10506 // Avoid adding a predicate that is known to be false.
10507 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10508 return SE.getCouldNotCompute();
10509 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10510 }
10511 }
10512
10513 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10514 // modulo (N / D).
10515 //
10516 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10517 // (N / D) in general. The inverse itself always fits into BW bits, though,
10518 // so we immediately truncate it.
10519 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10520 APInt I = AD.multiplicativeInverse().zext(BW);
10521
10522 // 4. Compute the minimum unsigned root of the equation:
10523 // I * (B / D) mod (N / D)
10524 // To simplify the computation, we factor out the divide by D:
10525 // (I * B mod N) / D
10526 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10527 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10528}
10529
10530/// For a given quadratic addrec, generate coefficients of the corresponding
10531/// quadratic equation, multiplied by a common value to ensure that they are
10532/// integers.
10533/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10534/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10535/// were multiplied by, and BitWidth is the bit width of the original addrec
10536/// coefficients.
10537/// This function returns std::nullopt if the addrec coefficients are not
10538/// compile-time constants.
10539static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
10541 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10542 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10543 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10544 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10545 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10546 << *AddRec << '\n');
10547
10548 // We currently can only solve this if the coefficients are constants.
10549 if (!LC || !MC || !NC) {
10550 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10551 return std::nullopt;
10552 }
10553
10554 APInt L = LC->getAPInt();
10555 APInt M = MC->getAPInt();
10556 APInt N = NC->getAPInt();
10557 assert(!N.isZero() && "This is not a quadratic addrec");
10558
10559 unsigned BitWidth = LC->getAPInt().getBitWidth();
10560 unsigned NewWidth = BitWidth + 1;
10561 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10562 << BitWidth << '\n');
10563 // The sign-extension (as opposed to a zero-extension) here matches the
10564 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10565 N = N.sext(NewWidth);
10566 M = M.sext(NewWidth);
10567 L = L.sext(NewWidth);
10568
10569 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10570 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10571 // L+M, L+2M+N, L+3M+3N, ...
10572 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10573 //
10574 // The equation Acc = 0 is then
10575 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10576 // In a quadratic form it becomes:
10577 // N n^2 + (2M-N) n + 2L = 0.
10578
10579 APInt A = N;
10580 APInt B = 2 * M - A;
10581 APInt C = 2 * L;
10582 APInt T = APInt(NewWidth, 2);
10583 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10584 << "x + " << C << ", coeff bw: " << NewWidth
10585 << ", multiplied by " << T << '\n');
10586 return std::make_tuple(A, B, C, T, BitWidth);
10587}
10588
10589/// Helper function to compare optional APInts:
10590/// (a) if X and Y both exist, return min(X, Y),
10591/// (b) if neither X nor Y exist, return std::nullopt,
10592/// (c) if exactly one of X and Y exists, return that value.
10593static std::optional<APInt> MinOptional(std::optional<APInt> X,
10594 std::optional<APInt> Y) {
10595 if (X && Y) {
10596 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10597 APInt XW = X->sext(W);
10598 APInt YW = Y->sext(W);
10599 return XW.slt(YW) ? *X : *Y;
10600 }
10601 if (!X && !Y)
10602 return std::nullopt;
10603 return X ? *X : *Y;
10604}
10605
10606/// Helper function to truncate an optional APInt to a given BitWidth.
10607/// When solving addrec-related equations, it is preferable to return a value
10608/// that has the same bit width as the original addrec's coefficients. If the
10609/// solution fits in the original bit width, truncate it (except for i1).
10610/// Returning a value of a different bit width may inhibit some optimizations.
10611///
10612/// In general, a solution to a quadratic equation generated from an addrec
10613/// may require BW+1 bits, where BW is the bit width of the addrec's
10614/// coefficients. The reason is that the coefficients of the quadratic
10615/// equation are BW+1 bits wide (to avoid truncation when converting from
10616/// the addrec to the equation).
10617static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10618 unsigned BitWidth) {
10619 if (!X)
10620 return std::nullopt;
10621 unsigned W = X->getBitWidth();
10623 return X->trunc(BitWidth);
10624 return X;
10625}
10626
10627/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10628/// iterations. The values L, M, N are assumed to be signed, and they
10629/// should all have the same bit widths.
10630/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10631/// where BW is the bit width of the addrec's coefficients.
10632/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10633/// returned as such, otherwise the bit width of the returned value may
10634/// be greater than BW.
10635///
10636/// This function returns std::nullopt if
10637/// (a) the addrec coefficients are not constant, or
10638/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10639/// like x^2 = 5, no integer solutions exist, in other cases an integer
10640/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10641static std::optional<APInt>
10643 APInt A, B, C, M;
10644 unsigned BitWidth;
10645 auto T = GetQuadraticEquation(AddRec);
10646 if (!T)
10647 return std::nullopt;
10648
10649 std::tie(A, B, C, M, BitWidth) = *T;
10650 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10651 std::optional<APInt> X =
10653 if (!X)
10654 return std::nullopt;
10655
10656 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10657 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10658 if (!V->isZero())
10659 return std::nullopt;
10660
10661 return TruncIfPossible(X, BitWidth);
10662}
10663
10664/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10665/// iterations. The values M, N are assumed to be signed, and they
10666/// should all have the same bit widths.
10667/// Find the least n such that c(n) does not belong to the given range,
10668/// while c(n-1) does.
10669///
10670/// This function returns std::nullopt if
10671/// (a) the addrec coefficients are not constant, or
10672/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10673/// bounds of the range.
10674static std::optional<APInt>
10676 const ConstantRange &Range, ScalarEvolution &SE) {
10677 assert(AddRec->getOperand(0)->isZero() &&
10678 "Starting value of addrec should be 0");
10679 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10680 << Range << ", addrec " << *AddRec << '\n');
10681 // This case is handled in getNumIterationsInRange. Here we can assume that
10682 // we start in the range.
10683 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10684 "Addrec's initial value should be in range");
10685
10686 APInt A, B, C, M;
10687 unsigned BitWidth;
10688 auto T = GetQuadraticEquation(AddRec);
10689 if (!T)
10690 return std::nullopt;
10691
10692 // Be careful about the return value: there can be two reasons for not
10693 // returning an actual number. First, if no solutions to the equations
10694 // were found, and second, if the solutions don't leave the given range.
10695 // The first case means that the actual solution is "unknown", the second
10696 // means that it's known, but not valid. If the solution is unknown, we
10697 // cannot make any conclusions.
10698 // Return a pair: the optional solution and a flag indicating if the
10699 // solution was found.
10700 auto SolveForBoundary =
10701 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10702 // Solve for signed overflow and unsigned overflow, pick the lower
10703 // solution.
10704 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10705 << Bound << " (before multiplying by " << M << ")\n");
10706 Bound *= M; // The quadratic equation multiplier.
10707
10708 std::optional<APInt> SO;
10709 if (BitWidth > 1) {
10710 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10711 "signed overflow\n");
10713 }
10714 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10715 "unsigned overflow\n");
10716 std::optional<APInt> UO =
10718
10719 auto LeavesRange = [&] (const APInt &X) {
10720 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10721 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10722 if (Range.contains(V0->getValue()))
10723 return false;
10724 // X should be at least 1, so X-1 is non-negative.
10725 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10726 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10727 if (Range.contains(V1->getValue()))
10728 return true;
10729 return false;
10730 };
10731
10732 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10733 // can be a solution, but the function failed to find it. We cannot treat it
10734 // as "no solution".
10735 if (!SO || !UO)
10736 return {std::nullopt, false};
10737
10738 // Check the smaller value first to see if it leaves the range.
10739 // At this point, both SO and UO must have values.
10740 std::optional<APInt> Min = MinOptional(SO, UO);
10741 if (LeavesRange(*Min))
10742 return { Min, true };
10743 std::optional<APInt> Max = Min == SO ? UO : SO;
10744 if (LeavesRange(*Max))
10745 return { Max, true };
10746
10747 // Solutions were found, but were eliminated, hence the "true".
10748 return {std::nullopt, true};
10749 };
10750
10751 std::tie(A, B, C, M, BitWidth) = *T;
10752 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10753 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10754 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10755 auto SL = SolveForBoundary(Lower);
10756 auto SU = SolveForBoundary(Upper);
10757 // If any of the solutions was unknown, no meaninigful conclusions can
10758 // be made.
10759 if (!SL.second || !SU.second)
10760 return std::nullopt;
10761
10762 // Claim: The correct solution is not some value between Min and Max.
10763 //
10764 // Justification: Assuming that Min and Max are different values, one of
10765 // them is when the first signed overflow happens, the other is when the
10766 // first unsigned overflow happens. Crossing the range boundary is only
10767 // possible via an overflow (treating 0 as a special case of it, modeling
10768 // an overflow as crossing k*2^W for some k).
10769 //
10770 // The interesting case here is when Min was eliminated as an invalid
10771 // solution, but Max was not. The argument is that if there was another
10772 // overflow between Min and Max, it would also have been eliminated if
10773 // it was considered.
10774 //
10775 // For a given boundary, it is possible to have two overflows of the same
10776 // type (signed/unsigned) without having the other type in between: this
10777 // can happen when the vertex of the parabola is between the iterations
10778 // corresponding to the overflows. This is only possible when the two
10779 // overflows cross k*2^W for the same k. In such case, if the second one
10780 // left the range (and was the first one to do so), the first overflow
10781 // would have to enter the range, which would mean that either we had left
10782 // the range before or that we started outside of it. Both of these cases
10783 // are contradictions.
10784 //
10785 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10786 // solution is not some value between the Max for this boundary and the
10787 // Min of the other boundary.
10788 //
10789 // Justification: Assume that we had such Max_A and Min_B corresponding
10790 // to range boundaries A and B and such that Max_A < Min_B. If there was
10791 // a solution between Max_A and Min_B, it would have to be caused by an
10792 // overflow corresponding to either A or B. It cannot correspond to B,
10793 // since Min_B is the first occurrence of such an overflow. If it
10794 // corresponded to A, it would have to be either a signed or an unsigned
10795 // overflow that is larger than both eliminated overflows for A. But
10796 // between the eliminated overflows and this overflow, the values would
10797 // cover the entire value space, thus crossing the other boundary, which
10798 // is a contradiction.
10799
10800 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10801}
10802
10803ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10804 const Loop *L,
10805 bool ControlsOnlyExit,
10806 bool AllowPredicates) {
10807
10808 // This is only used for loops with a "x != y" exit test. The exit condition
10809 // is now expressed as a single expression, V = x-y. So the exit test is
10810 // effectively V != 0. We know and take advantage of the fact that this
10811 // expression only being used in a comparison by zero context.
10812
10814 // If the value is a constant
10815 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10816 // If the value is already zero, the branch will execute zero times.
10817 if (C->getValue()->isZero()) return C;
10818 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10819 }
10820
10821 const SCEVAddRecExpr *AddRec =
10822 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10823
10824 if (!AddRec && AllowPredicates)
10825 // Try to make this an AddRec using runtime tests, in the first X
10826 // iterations of this loop, where X is the SCEV expression found by the
10827 // algorithm below.
10828 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10829
10830 if (!AddRec || AddRec->getLoop() != L)
10831 return getCouldNotCompute();
10832
10833 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10834 // the quadratic equation to solve it.
10835 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10836 // We can only use this value if the chrec ends up with an exact zero
10837 // value at this index. When solving for "X*X != 5", for example, we
10838 // should not accept a root of 2.
10839 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10840 const auto *R = cast<SCEVConstant>(getConstant(*S));
10841 return ExitLimit(R, R, R, false, Predicates);
10842 }
10843 return getCouldNotCompute();
10844 }
10845
10846 // Otherwise we can only handle this if it is affine.
10847 if (!AddRec->isAffine())
10848 return getCouldNotCompute();
10849
10850 // If this is an affine expression, the execution count of this branch is
10851 // the minimum unsigned root of the following equation:
10852 //
10853 // Start + Step*N = 0 (mod 2^BW)
10854 //
10855 // equivalent to:
10856 //
10857 // Step*N = -Start (mod 2^BW)
10858 //
10859 // where BW is the common bit width of Start and Step.
10860
10861 // Get the initial value for the loop.
10862 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10863 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10864
10865 if (!isLoopInvariant(Step, L))
10866 return getCouldNotCompute();
10867
10868 LoopGuards Guards = LoopGuards::collect(L, *this);
10869 // Specialize step for this loop so we get context sensitive facts below.
10870 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10871
10872 // For positive steps (counting up until unsigned overflow):
10873 // N = -Start/Step (as unsigned)
10874 // For negative steps (counting down to zero):
10875 // N = Start/-Step
10876 // First compute the unsigned distance from zero in the direction of Step.
10877 bool CountDown = isKnownNegative(StepWLG);
10878 if (!CountDown && !isKnownNonNegative(StepWLG))
10879 return getCouldNotCompute();
10880
10881 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10882 // Handle unitary steps, which cannot wraparound.
10883 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10884 // N = Distance (as unsigned)
10885
10886 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10887 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10888 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10889
10890 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10891 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10892 // case, and see if we can improve the bound.
10893 //
10894 // Explicitly handling this here is necessary because getUnsignedRange
10895 // isn't context-sensitive; it doesn't know that we only care about the
10896 // range inside the loop.
10897 const SCEV *Zero = getZero(Distance->getType());
10898 const SCEV *One = getOne(Distance->getType());
10899 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10900 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10901 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10902 // as "unsigned_max(Distance + 1) - 1".
10903 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10904 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10905 }
10906 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10907 Predicates);
10908 }
10909
10910 // If the condition controls loop exit (the loop exits only if the expression
10911 // is true) and the addition is no-wrap we can use unsigned divide to
10912 // compute the backedge count. In this case, the step may not divide the
10913 // distance, but we don't care because if the condition is "missed" the loop
10914 // will have undefined behavior due to wrapping.
10915 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10916 loopHasNoAbnormalExits(AddRec->getLoop())) {
10917
10918 // If the stride is zero and the start is non-zero, the loop must be
10919 // infinite. In C++, most loops are finite by assumption, in which case the
10920 // step being zero implies UB must execute if the loop is entered.
10921 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10922 !isKnownNonZero(StepWLG))
10923 return getCouldNotCompute();
10924
10925 const SCEV *Exact =
10926 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10927 const SCEV *ConstantMax = getCouldNotCompute();
10928 if (Exact != getCouldNotCompute()) {
10929 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10930 ConstantMax =
10932 }
10933 const SCEV *SymbolicMax =
10934 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10935 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10936 }
10937
10938 // Solve the general equation.
10939 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10940 if (!StepC || StepC->getValue()->isZero())
10941 return getCouldNotCompute();
10942 const SCEV *E = SolveLinEquationWithOverflow(
10943 StepC->getAPInt(), getNegativeSCEV(Start),
10944 AllowPredicates ? &Predicates : nullptr, *this, L);
10945
10946 const SCEV *M = E;
10947 if (E != getCouldNotCompute()) {
10948 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10949 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10950 }
10951 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10952 return ExitLimit(E, M, S, false, Predicates);
10953}
10954
10955ScalarEvolution::ExitLimit
10956ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10957 // Loops that look like: while (X == 0) are very strange indeed. We don't
10958 // handle them yet except for the trivial case. This could be expanded in the
10959 // future as needed.
10960
10961 // If the value is a constant, check to see if it is known to be non-zero
10962 // already. If so, the backedge will execute zero times.
10963 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10964 if (!C->getValue()->isZero())
10965 return getZero(C->getType());
10966 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10967 }
10968
10969 // We could implement others, but I really doubt anyone writes loops like
10970 // this, and if they did, they would already be constant folded.
10971 return getCouldNotCompute();
10972}
10973
10974std::pair<const BasicBlock *, const BasicBlock *>
10975ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10976 const {
10977 // If the block has a unique predecessor, then there is no path from the
10978 // predecessor to the block that does not go through the direct edge
10979 // from the predecessor to the block.
10980 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10981 return {Pred, BB};
10982
10983 // A loop's header is defined to be a block that dominates the loop.
10984 // If the header has a unique predecessor outside the loop, it must be
10985 // a block that has exactly one successor that can reach the loop.
10986 if (const Loop *L = LI.getLoopFor(BB))
10987 return {L->getLoopPredecessor(), L->getHeader()};
10988
10989 return {nullptr, BB};
10990}
10991
10992/// SCEV structural equivalence is usually sufficient for testing whether two
10993/// expressions are equal, however for the purposes of looking for a condition
10994/// guarding a loop, it can be useful to be a little more general, since a
10995/// front-end may have replicated the controlling expression.
10996static bool HasSameValue(const SCEV *A, const SCEV *B) {
10997 // Quick check to see if they are the same SCEV.
10998 if (A == B) return true;
10999
11000 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11001 // Not all instructions that are "identical" compute the same value. For
11002 // instance, two distinct alloca instructions allocating the same type are
11003 // identical and do not read memory; but compute distinct values.
11004 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11005 };
11006
11007 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11008 // two different instructions with the same value. Check for this case.
11009 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11010 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11011 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11012 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11013 if (ComputesEqualValues(AI, BI))
11014 return true;
11015
11016 // Otherwise assume they may have a different value.
11017 return false;
11018}
11019
11020static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11021 const SCEV *Op0, *Op1;
11022 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11023 return false;
11024 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11025 LHS = Op1;
11026 return true;
11027 }
11028 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11029 LHS = Op0;
11030 return true;
11031 }
11032 return false;
11033}
11034
11036 SCEVUse &RHS, unsigned Depth) {
11037 bool Changed = false;
11038 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11039 // '0 != 0'.
11040 auto TrivialCase = [&](bool TriviallyTrue) {
11042 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11043 return true;
11044 };
11045 // If we hit the max recursion limit bail out.
11046 if (Depth >= 3)
11047 return false;
11048
11049 const SCEV *NewLHS, *NewRHS;
11050 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11051 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11052 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11053 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11054
11055 // (X * vscale) pred (Y * vscale) ==> X pred Y
11056 // when both multiples are NSW.
11057 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11058 // when both multiples are NUW.
11059 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11060 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11061 !ICmpInst::isSigned(Pred))) {
11062 LHS = NewLHS;
11063 RHS = NewRHS;
11064 Changed = true;
11065 }
11066 }
11067
11068 // Canonicalize a constant to the right side.
11069 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11070 // Check for both operands constant.
11071 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11072 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11073 return TrivialCase(false);
11074 return TrivialCase(true);
11075 }
11076 // Otherwise swap the operands to put the constant on the right.
11077 std::swap(LHS, RHS);
11079 Changed = true;
11080 }
11081
11082 // If we're comparing an addrec with a value which is loop-invariant in the
11083 // addrec's loop, put the addrec on the left. Also make a dominance check,
11084 // as both operands could be addrecs loop-invariant in each other's loop.
11085 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11086 const Loop *L = AR->getLoop();
11087 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11088 std::swap(LHS, RHS);
11090 Changed = true;
11091 }
11092 }
11093
11094 // If there's a constant operand, canonicalize comparisons with boundary
11095 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11096 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11097 const APInt &RA = RC->getAPInt();
11098
11099 bool SimplifiedByConstantRange = false;
11100
11101 if (!ICmpInst::isEquality(Pred)) {
11103 if (ExactCR.isFullSet())
11104 return TrivialCase(true);
11105 if (ExactCR.isEmptySet())
11106 return TrivialCase(false);
11107
11108 APInt NewRHS;
11109 CmpInst::Predicate NewPred;
11110 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11111 ICmpInst::isEquality(NewPred)) {
11112 // We were able to convert an inequality to an equality.
11113 Pred = NewPred;
11114 RHS = getConstant(NewRHS);
11115 Changed = SimplifiedByConstantRange = true;
11116 }
11117 }
11118
11119 if (!SimplifiedByConstantRange) {
11120 switch (Pred) {
11121 default:
11122 break;
11123 case ICmpInst::ICMP_EQ:
11124 case ICmpInst::ICMP_NE:
11125 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11126 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11127 Changed = true;
11128 break;
11129
11130 // The "Should have been caught earlier!" messages refer to the fact
11131 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11132 // should have fired on the corresponding cases, and canonicalized the
11133 // check to trivial case.
11134
11135 case ICmpInst::ICMP_UGE:
11136 assert(!RA.isMinValue() && "Should have been caught earlier!");
11137 Pred = ICmpInst::ICMP_UGT;
11138 RHS = getConstant(RA - 1);
11139 Changed = true;
11140 break;
11141 case ICmpInst::ICMP_ULE:
11142 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11143 Pred = ICmpInst::ICMP_ULT;
11144 RHS = getConstant(RA + 1);
11145 Changed = true;
11146 break;
11147 case ICmpInst::ICMP_SGE:
11148 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11149 Pred = ICmpInst::ICMP_SGT;
11150 RHS = getConstant(RA - 1);
11151 Changed = true;
11152 break;
11153 case ICmpInst::ICMP_SLE:
11154 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11155 Pred = ICmpInst::ICMP_SLT;
11156 RHS = getConstant(RA + 1);
11157 Changed = true;
11158 break;
11159 }
11160 }
11161 }
11162
11163 // Check for obvious equality.
11164 if (HasSameValue(LHS, RHS)) {
11165 if (ICmpInst::isTrueWhenEqual(Pred))
11166 return TrivialCase(true);
11168 return TrivialCase(false);
11169 }
11170
11171 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11172 // adding or subtracting 1 from one of the operands.
11173 switch (Pred) {
11174 case ICmpInst::ICMP_SLE:
11175 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11176 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11178 Pred = ICmpInst::ICMP_SLT;
11179 Changed = true;
11180 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11181 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11183 Pred = ICmpInst::ICMP_SLT;
11184 Changed = true;
11185 }
11186 break;
11187 case ICmpInst::ICMP_SGE:
11188 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11189 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11191 Pred = ICmpInst::ICMP_SGT;
11192 Changed = true;
11193 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11194 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11196 Pred = ICmpInst::ICMP_SGT;
11197 Changed = true;
11198 }
11199 break;
11200 case ICmpInst::ICMP_ULE:
11201 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11202 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11204 Pred = ICmpInst::ICMP_ULT;
11205 Changed = true;
11206 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11207 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11208 Pred = ICmpInst::ICMP_ULT;
11209 Changed = true;
11210 }
11211 break;
11212 case ICmpInst::ICMP_UGE:
11213 // If RHS is an op we can fold the -1, try that first.
11214 // Otherwise prefer LHS to preserve the nuw flag.
11215 if ((isa<SCEVConstant>(RHS) ||
11217 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11218 !getUnsignedRangeMin(RHS).isMinValue()) {
11219 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11220 Pred = ICmpInst::ICMP_UGT;
11221 Changed = true;
11222 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11223 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11225 Pred = ICmpInst::ICMP_UGT;
11226 Changed = true;
11227 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11228 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11229 Pred = ICmpInst::ICMP_UGT;
11230 Changed = true;
11231 }
11232 break;
11233 default:
11234 break;
11235 }
11236
11237 // TODO: More simplifications are possible here.
11238
11239 // Recursively simplify until we either hit a recursion limit or nothing
11240 // changes.
11241 if (Changed)
11242 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11243
11244 return Changed;
11245}
11246
11248 return getSignedRangeMax(S).isNegative();
11249}
11250
11254
11256 return !getSignedRangeMin(S).isNegative();
11257}
11258
11262
11264 // Query push down for cases where the unsigned range is
11265 // less than sufficient.
11266 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11267 return isKnownNonZero(SExt->getOperand(0));
11268 return getUnsignedRangeMin(S) != 0;
11269}
11270
11272 bool OrNegative) {
11273 auto NonRecursive = [OrNegative](const SCEV *S) {
11274 if (auto *C = dyn_cast<SCEVConstant>(S))
11275 return C->getAPInt().isPowerOf2() ||
11276 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11277
11278 // vscale is a power-of-two.
11279 return isa<SCEVVScale>(S);
11280 };
11281
11282 if (NonRecursive(S))
11283 return true;
11284
11285 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11286 if (!Mul)
11287 return false;
11288 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11289}
11290
11292 const SCEV *S, uint64_t M,
11294 if (M == 0)
11295 return false;
11296 if (M == 1)
11297 return true;
11298
11299 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11300 // starts with a multiple of M and at every iteration step S only adds
11301 // multiples of M.
11302 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11303 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11304 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11305
11306 // For a constant, check that "S % M == 0".
11307 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11308 APInt C = Cst->getAPInt();
11309 return C.urem(M) == 0;
11310 }
11311
11312 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11313
11314 // Basic tests have failed.
11315 // Check "S % M == 0" at compile time and record runtime Assumptions.
11316 auto *STy = dyn_cast<IntegerType>(S->getType());
11317 const SCEV *SmodM =
11318 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11319 const SCEV *Zero = getZero(STy);
11320
11321 // Check whether "S % M == 0" is known at compile time.
11322 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11323 return true;
11324
11325 // Check whether "S % M != 0" is known at compile time.
11326 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11327 return false;
11328
11330
11331 // Detect redundant predicates.
11332 for (auto *A : Assumptions)
11333 if (A->implies(P, *this))
11334 return true;
11335
11336 // Only record non-redundant predicates.
11337 Assumptions.push_back(P);
11338 return true;
11339}
11340
11342 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11344}
11345
11346std::pair<const SCEV *, const SCEV *>
11348 // Compute SCEV on entry of loop L.
11349 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11350 if (Start == getCouldNotCompute())
11351 return { Start, Start };
11352 // Compute post increment SCEV for loop L.
11353 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11354 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11355 return { Start, PostInc };
11356}
11357
11359 SCEVUse RHS) {
11360 // First collect all loops.
11362 getUsedLoops(LHS, LoopsUsed);
11363 getUsedLoops(RHS, LoopsUsed);
11364
11365 if (LoopsUsed.empty())
11366 return false;
11367
11368 // Domination relationship must be a linear order on collected loops.
11369#ifndef NDEBUG
11370 for (const auto *L1 : LoopsUsed)
11371 for (const auto *L2 : LoopsUsed)
11372 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11373 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11374 "Domination relationship is not a linear order");
11375#endif
11376
11377 const Loop *MDL =
11378 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11379 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11380 });
11381
11382 // Get init and post increment value for LHS.
11383 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11384 // if LHS contains unknown non-invariant SCEV then bail out.
11385 if (SplitLHS.first == getCouldNotCompute())
11386 return false;
11387 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11388 // Get init and post increment value for RHS.
11389 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11390 // if RHS contains unknown non-invariant SCEV then bail out.
11391 if (SplitRHS.first == getCouldNotCompute())
11392 return false;
11393 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11394 // It is possible that init SCEV contains an invariant load but it does
11395 // not dominate MDL and is not available at MDL loop entry, so we should
11396 // check it here.
11397 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11398 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11399 return false;
11400
11401 // It seems backedge guard check is faster than entry one so in some cases
11402 // it can speed up whole estimation by short circuit
11403 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11404 SplitRHS.second) &&
11405 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11406}
11407
11409 SCEVUse RHS) {
11410 // Canonicalize the inputs first.
11411 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11412
11413 if (isKnownViaInduction(Pred, LHS, RHS))
11414 return true;
11415
11416 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11417 return true;
11418
11419 // Otherwise see what can be done with some simple reasoning.
11420 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11421}
11422
11424 const SCEV *LHS,
11425 const SCEV *RHS) {
11426 if (isKnownPredicate(Pred, LHS, RHS))
11427 return true;
11429 return false;
11430 return std::nullopt;
11431}
11432
11434 const SCEV *RHS,
11435 const Instruction *CtxI) {
11436 // TODO: Analyze guards and assumes from Context's block.
11437 return isKnownPredicate(Pred, LHS, RHS) ||
11438 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11439}
11440
11441std::optional<bool>
11443 const SCEV *RHS, const Instruction *CtxI) {
11444 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11445 if (KnownWithoutContext)
11446 return KnownWithoutContext;
11447
11448 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11449 return true;
11451 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11452 return false;
11453 return std::nullopt;
11454}
11455
11457 const SCEVAddRecExpr *LHS,
11458 const SCEV *RHS) {
11459 const Loop *L = LHS->getLoop();
11460 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11461 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11462}
11463
11464std::optional<ScalarEvolution::MonotonicPredicateType>
11466 ICmpInst::Predicate Pred) {
11467 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11468
11469#ifndef NDEBUG
11470 // Verify an invariant: inverting the predicate should turn a monotonically
11471 // increasing change to a monotonically decreasing one, and vice versa.
11472 if (Result) {
11473 auto ResultSwapped =
11474 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11475
11476 assert(*ResultSwapped != *Result &&
11477 "monotonicity should flip as we flip the predicate");
11478 }
11479#endif
11480
11481 return Result;
11482}
11483
11484std::optional<ScalarEvolution::MonotonicPredicateType>
11485ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11486 ICmpInst::Predicate Pred) {
11487 // A zero step value for LHS means the induction variable is essentially a
11488 // loop invariant value. We don't really depend on the predicate actually
11489 // flipping from false to true (for increasing predicates, and the other way
11490 // around for decreasing predicates), all we care about is that *if* the
11491 // predicate changes then it only changes from false to true.
11492 //
11493 // A zero step value in itself is not very useful, but there may be places
11494 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11495 // as general as possible.
11496
11497 // Only handle LE/LT/GE/GT predicates.
11498 if (!ICmpInst::isRelational(Pred))
11499 return std::nullopt;
11500
11501 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11502 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11503 "Should be greater or less!");
11504
11505 // Check that AR does not wrap.
11506 if (ICmpInst::isUnsigned(Pred)) {
11507 if (!LHS->hasNoUnsignedWrap())
11508 return std::nullopt;
11510 }
11511 assert(ICmpInst::isSigned(Pred) &&
11512 "Relational predicate is either signed or unsigned!");
11513 if (!LHS->hasNoSignedWrap())
11514 return std::nullopt;
11515
11516 const SCEV *Step = LHS->getStepRecurrence(*this);
11517
11518 if (isKnownNonNegative(Step))
11520
11521 if (isKnownNonPositive(Step))
11523
11524 return std::nullopt;
11525}
11526
11527std::optional<ScalarEvolution::LoopInvariantPredicate>
11529 const SCEV *RHS, const Loop *L,
11530 const Instruction *CtxI) {
11531 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11532 if (!isLoopInvariant(RHS, L)) {
11533 if (!isLoopInvariant(LHS, L))
11534 return std::nullopt;
11535
11536 std::swap(LHS, RHS);
11538 }
11539
11540 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11541 if (!ArLHS || ArLHS->getLoop() != L)
11542 return std::nullopt;
11543
11544 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11545 if (!MonotonicType)
11546 return std::nullopt;
11547 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11548 // true as the loop iterates, and the backedge is control dependent on
11549 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11550 //
11551 // * if the predicate was false in the first iteration then the predicate
11552 // is never evaluated again, since the loop exits without taking the
11553 // backedge.
11554 // * if the predicate was true in the first iteration then it will
11555 // continue to be true for all future iterations since it is
11556 // monotonically increasing.
11557 //
11558 // For both the above possibilities, we can replace the loop varying
11559 // predicate with its value on the first iteration of the loop (which is
11560 // loop invariant).
11561 //
11562 // A similar reasoning applies for a monotonically decreasing predicate, by
11563 // replacing true with false and false with true in the above two bullets.
11565 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11566
11567 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11569 RHS);
11570
11571 if (!CtxI)
11572 return std::nullopt;
11573 // Try to prove via context.
11574 // TODO: Support other cases.
11575 switch (Pred) {
11576 default:
11577 break;
11578 case ICmpInst::ICMP_ULE:
11579 case ICmpInst::ICMP_ULT: {
11580 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11581 // Given preconditions
11582 // (1) ArLHS does not cross the border of positive and negative parts of
11583 // range because of:
11584 // - Positive step; (TODO: lift this limitation)
11585 // - nuw - does not cross zero boundary;
11586 // - nsw - does not cross SINT_MAX boundary;
11587 // (2) ArLHS <s RHS
11588 // (3) RHS >=s 0
11589 // we can replace the loop variant ArLHS <u RHS condition with loop
11590 // invariant Start(ArLHS) <u RHS.
11591 //
11592 // Because of (1) there are two options:
11593 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11594 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11595 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11596 // Because of (2) ArLHS <u RHS is trivially true.
11597 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11598 // We can strengthen this to Start(ArLHS) <u RHS.
11599 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11600 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11601 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11602 isKnownNonNegative(RHS) &&
11603 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11605 RHS);
11606 }
11607 }
11608
11609 return std::nullopt;
11610}
11611
11612std::optional<ScalarEvolution::LoopInvariantPredicate>
11614 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11615 const Instruction *CtxI, const SCEV *MaxIter) {
11617 Pred, LHS, RHS, L, CtxI, MaxIter))
11618 return LIP;
11619 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11620 // Number of iterations expressed as UMIN isn't always great for expressing
11621 // the value on the last iteration. If the straightforward approach didn't
11622 // work, try the following trick: if the a predicate is invariant for X, it
11623 // is also invariant for umin(X, ...). So try to find something that works
11624 // among subexpressions of MaxIter expressed as umin.
11625 for (SCEVUse Op : UMin->operands())
11627 Pred, LHS, RHS, L, CtxI, Op))
11628 return LIP;
11629 return std::nullopt;
11630}
11631
11632std::optional<ScalarEvolution::LoopInvariantPredicate>
11634 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11635 const Instruction *CtxI, const SCEV *MaxIter) {
11636 // Try to prove the following set of facts:
11637 // - The predicate is monotonic in the iteration space.
11638 // - If the check does not fail on the 1st iteration:
11639 // - No overflow will happen during first MaxIter iterations;
11640 // - It will not fail on the MaxIter'th iteration.
11641 // If the check does fail on the 1st iteration, we leave the loop and no
11642 // other checks matter.
11643
11644 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11645 if (!isLoopInvariant(RHS, L)) {
11646 if (!isLoopInvariant(LHS, L))
11647 return std::nullopt;
11648
11649 std::swap(LHS, RHS);
11651 }
11652
11653 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11654 if (!AR || AR->getLoop() != L)
11655 return std::nullopt;
11656
11657 // Even if both are valid, we need to consistently chose the unsigned or the
11658 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11659 // predicate.
11660 Pred = Pred.dropSameSign();
11661
11662 // The predicate must be relational (i.e. <, <=, >=, >).
11663 if (!ICmpInst::isRelational(Pred))
11664 return std::nullopt;
11665
11666 // TODO: Support steps other than +/- 1.
11667 const SCEV *Step = AR->getStepRecurrence(*this);
11668 auto *One = getOne(Step->getType());
11669 auto *MinusOne = getNegativeSCEV(One);
11670 if (Step != One && Step != MinusOne)
11671 return std::nullopt;
11672
11673 // Type mismatch here means that MaxIter is potentially larger than max
11674 // unsigned value in start type, which mean we cannot prove no wrap for the
11675 // indvar.
11676 if (AR->getType() != MaxIter->getType())
11677 return std::nullopt;
11678
11679 // Value of IV on suggested last iteration.
11680 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11681 // Does it still meet the requirement?
11682 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11683 return std::nullopt;
11684 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11685 // not exceed max unsigned value of this type), this effectively proves
11686 // that there is no wrap during the iteration. To prove that there is no
11687 // signed/unsigned wrap, we need to check that
11688 // Start <= Last for step = 1 or Start >= Last for step = -1.
11689 ICmpInst::Predicate NoOverflowPred =
11691 if (Step == MinusOne)
11692 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11693 const SCEV *Start = AR->getStart();
11694 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11695 return std::nullopt;
11696
11697 // Everything is fine.
11698 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11699}
11700
11701bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11702 SCEVUse LHS,
11703 SCEVUse RHS) {
11704 if (HasSameValue(LHS, RHS))
11705 return ICmpInst::isTrueWhenEqual(Pred);
11706
11707 auto CheckRange = [&](bool IsSigned) {
11708 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11709 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11710 return RangeLHS.icmp(Pred, RangeRHS);
11711 };
11712
11713 // The check at the top of the function catches the case where the values are
11714 // known to be equal.
11715 if (Pred == CmpInst::ICMP_EQ)
11716 return false;
11717
11718 if (Pred == CmpInst::ICMP_NE) {
11719 if (CheckRange(true) || CheckRange(false))
11720 return true;
11721 auto *Diff = getMinusSCEV(LHS, RHS);
11722 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11723 }
11724
11725 return CheckRange(CmpInst::isSigned(Pred));
11726}
11727
11728bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11730 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11731 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11732 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11733 // OutC1 and OutC2.
11734 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11735 APInt &OutC2,
11736 SCEV::NoWrapFlags ExpectedFlags) {
11737 SCEVUse XNonConstOp, XConstOp;
11738 SCEVUse YNonConstOp, YConstOp;
11739 SCEV::NoWrapFlags XFlagsPresent;
11740 SCEV::NoWrapFlags YFlagsPresent;
11741
11742 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11743 XConstOp = getZero(X->getType());
11744 XNonConstOp = X;
11745 XFlagsPresent = ExpectedFlags;
11746 }
11747 if (!isa<SCEVConstant>(XConstOp))
11748 return false;
11749
11750 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11751 YConstOp = getZero(Y->getType());
11752 YNonConstOp = Y;
11753 YFlagsPresent = ExpectedFlags;
11754 }
11755
11756 if (YNonConstOp != XNonConstOp)
11757 return false;
11758
11759 if (!isa<SCEVConstant>(YConstOp))
11760 return false;
11761
11762 // When matching ADDs with NUW flags (and unsigned predicates), only the
11763 // second ADD (with the larger constant) requires NUW.
11764 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11765 return false;
11766 if (ExpectedFlags != SCEV::FlagNUW &&
11767 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11768 return false;
11769 }
11770
11771 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11772 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11773
11774 return true;
11775 };
11776
11777 APInt C1;
11778 APInt C2;
11779
11780 switch (Pred) {
11781 default:
11782 break;
11783
11784 case ICmpInst::ICMP_SGE:
11785 std::swap(LHS, RHS);
11786 [[fallthrough]];
11787 case ICmpInst::ICMP_SLE:
11788 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11789 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11790 return true;
11791
11792 break;
11793
11794 case ICmpInst::ICMP_SGT:
11795 std::swap(LHS, RHS);
11796 [[fallthrough]];
11797 case ICmpInst::ICMP_SLT:
11798 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11799 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11800 return true;
11801
11802 break;
11803
11804 case ICmpInst::ICMP_UGE:
11805 std::swap(LHS, RHS);
11806 [[fallthrough]];
11807 case ICmpInst::ICMP_ULE:
11808 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11809 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11810 return true;
11811
11812 break;
11813
11814 case ICmpInst::ICMP_UGT:
11815 std::swap(LHS, RHS);
11816 [[fallthrough]];
11817 case ICmpInst::ICMP_ULT:
11818 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11819 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11820 return true;
11821 break;
11822 }
11823
11824 return false;
11825}
11826
11827bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11829 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11830 return false;
11831
11832 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11833 // the stack can result in exponential time complexity.
11834 SaveAndRestore Restore(ProvingSplitPredicate, true);
11835
11836 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11837 //
11838 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11839 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11840 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11841 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11842 // use isKnownPredicate later if needed.
11843 return isKnownNonNegative(RHS) &&
11846}
11847
11848bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11849 const SCEV *LHS, const SCEV *RHS) {
11850 // No need to even try if we know the module has no guards.
11851 if (!HasGuards)
11852 return false;
11853
11854 return any_of(*BB, [&](const Instruction &I) {
11855 using namespace llvm::PatternMatch;
11856
11857 Value *Condition;
11859 m_Value(Condition))) &&
11860 isImpliedCond(Pred, LHS, RHS, Condition, false);
11861 });
11862}
11863
11864/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11865/// protected by a conditional between LHS and RHS. This is used to
11866/// to eliminate casts.
11868 CmpPredicate Pred,
11869 const SCEV *LHS,
11870 const SCEV *RHS) {
11871 // Interpret a null as meaning no loop, where there is obviously no guard
11872 // (interprocedural conditions notwithstanding). Do not bother about
11873 // unreachable loops.
11874 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11875 return true;
11876
11877 if (VerifyIR)
11878 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11879 "This cannot be done on broken IR!");
11880
11881
11882 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11883 return true;
11884
11885 BasicBlock *Latch = L->getLoopLatch();
11886 if (!Latch)
11887 return false;
11888
11889 CondBrInst *LoopContinuePredicate =
11891 if (LoopContinuePredicate &&
11892 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11893 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11894 return true;
11895
11896 // We don't want more than one activation of the following loops on the stack
11897 // -- that can lead to O(n!) time complexity.
11898 if (WalkingBEDominatingConds)
11899 return false;
11900
11901 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11902
11903 // See if we can exploit a trip count to prove the predicate.
11904 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11905 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11906 if (LatchBECount != getCouldNotCompute()) {
11907 // We know that Latch branches back to the loop header exactly
11908 // LatchBECount times. This means the backdege condition at Latch is
11909 // equivalent to "{0,+,1} u< LatchBECount".
11910 Type *Ty = LatchBECount->getType();
11911 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11912 const SCEV *LoopCounter =
11913 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11914 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11915 LatchBECount))
11916 return true;
11917 }
11918
11919 // Check conditions due to any @llvm.assume intrinsics.
11920 for (auto &AssumeVH : AC.assumptions()) {
11921 if (!AssumeVH)
11922 continue;
11923 auto *CI = cast<CallInst>(AssumeVH);
11924 if (!DT.dominates(CI, Latch->getTerminator()))
11925 continue;
11926
11927 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11928 return true;
11929 }
11930
11931 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11932 return true;
11933
11934 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11935 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11936 assert(DTN && "should reach the loop header before reaching the root!");
11937
11938 BasicBlock *BB = DTN->getBlock();
11939 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11940 return true;
11941
11942 BasicBlock *PBB = BB->getSinglePredecessor();
11943 if (!PBB)
11944 continue;
11945
11947 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11948 continue;
11949
11950 // If we have an edge `E` within the loop body that dominates the only
11951 // latch, the condition guarding `E` also guards the backedge. This
11952 // reasoning works only for loops with a single latch.
11953 // We're constructively (and conservatively) enumerating edges within the
11954 // loop body that dominate the latch. The dominator tree better agree
11955 // with us on this:
11956 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11957 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11958 BB != ContBr->getSuccessor(0)))
11959 return true;
11960 }
11961
11962 return false;
11963}
11964
11966 CmpPredicate Pred,
11967 const SCEV *LHS,
11968 const SCEV *RHS) {
11969 // Do not bother proving facts for unreachable code.
11970 if (!DT.isReachableFromEntry(BB))
11971 return true;
11972 if (VerifyIR)
11973 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11974 "This cannot be done on broken IR!");
11975
11976 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11977 // the facts (a >= b && a != b) separately. A typical situation is when the
11978 // non-strict comparison is known from ranges and non-equality is known from
11979 // dominating predicates. If we are proving strict comparison, we always try
11980 // to prove non-equality and non-strict comparison separately.
11981 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11982 const bool ProvingStrictComparison =
11983 Pred != NonStrictPredicate.dropSameSign();
11984 bool ProvedNonStrictComparison = false;
11985 bool ProvedNonEquality = false;
11986
11987 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11988 if (!ProvedNonStrictComparison)
11989 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11990 if (!ProvedNonEquality)
11991 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11992 if (ProvedNonStrictComparison && ProvedNonEquality)
11993 return true;
11994 return false;
11995 };
11996
11997 if (ProvingStrictComparison) {
11998 auto ProofFn = [&](CmpPredicate P) {
11999 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12000 };
12001 if (SplitAndProve(ProofFn))
12002 return true;
12003 }
12004
12005 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12006 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12007 const Instruction *CtxI = &BB->front();
12008 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12009 return true;
12010 if (ProvingStrictComparison) {
12011 auto ProofFn = [&](CmpPredicate P) {
12012 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12013 };
12014 if (SplitAndProve(ProofFn))
12015 return true;
12016 }
12017 return false;
12018 };
12019
12020 // Starting at the block's predecessor, climb up the predecessor chain, as long
12021 // as there are predecessors that can be found that have unique successors
12022 // leading to the original block.
12023 const Loop *ContainingLoop = LI.getLoopFor(BB);
12024 const BasicBlock *PredBB;
12025 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12026 PredBB = ContainingLoop->getLoopPredecessor();
12027 else
12028 PredBB = BB->getSinglePredecessor();
12029 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12030 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12031 const CondBrInst *BlockEntryPredicate =
12032 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12033 if (!BlockEntryPredicate)
12034 continue;
12035
12036 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12037 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12038 return true;
12039 }
12040
12041 // Check conditions due to any @llvm.assume intrinsics.
12042 for (auto &AssumeVH : AC.assumptions()) {
12043 if (!AssumeVH)
12044 continue;
12045 auto *CI = cast<CallInst>(AssumeVH);
12046 if (!DT.dominates(CI, BB))
12047 continue;
12048
12049 if (ProveViaCond(CI->getArgOperand(0), false))
12050 return true;
12051 }
12052
12053 // Check conditions due to any @llvm.experimental.guard intrinsics.
12054 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12055 F.getParent(), Intrinsic::experimental_guard);
12056 if (GuardDecl)
12057 for (const auto *GU : GuardDecl->users())
12058 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12059 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12060 if (ProveViaCond(Guard->getArgOperand(0), false))
12061 return true;
12062 return false;
12063}
12064
12066 const SCEV *LHS,
12067 const SCEV *RHS) {
12068 // Interpret a null as meaning no loop, where there is obviously no guard
12069 // (interprocedural conditions notwithstanding).
12070 if (!L)
12071 return false;
12072
12073 // Both LHS and RHS must be available at loop entry.
12075 "LHS is not available at Loop Entry");
12077 "RHS is not available at Loop Entry");
12078
12079 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12080 return true;
12081
12082 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12083}
12084
12085bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12086 const SCEV *RHS,
12087 const Value *FoundCondValue, bool Inverse,
12088 const Instruction *CtxI) {
12089 // False conditions implies anything. Do not bother analyzing it further.
12090 if (FoundCondValue ==
12091 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12092 return true;
12093
12094 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12095 return false;
12096
12097 llvm::scope_exit ClearOnExit(
12098 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12099
12100 // Recursively handle And and Or conditions.
12101 const Value *Op0, *Op1;
12102 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12103 if (!Inverse)
12104 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12105 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12106 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12107 if (Inverse)
12108 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12109 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12110 }
12111
12112 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12113 if (!ICI) return false;
12114
12115 // Now that we found a conditional branch that dominates the loop or controls
12116 // the loop latch. Check to see if it is the comparison we are looking for.
12117 CmpPredicate FoundPred;
12118 if (Inverse)
12119 FoundPred = ICI->getInverseCmpPredicate();
12120 else
12121 FoundPred = ICI->getCmpPredicate();
12122
12123 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12124 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12125
12126 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12127}
12128
12129bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12130 const SCEV *RHS, CmpPredicate FoundPred,
12131 const SCEV *FoundLHS, const SCEV *FoundRHS,
12132 const Instruction *CtxI) {
12133 // Balance the types.
12134 if (getTypeSizeInBits(LHS->getType()) <
12135 getTypeSizeInBits(FoundLHS->getType())) {
12136 // For unsigned and equality predicates, try to prove that both found
12137 // operands fit into narrow unsigned range. If so, try to prove facts in
12138 // narrow types.
12139 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12140 !FoundRHS->getType()->isPointerTy()) {
12141 auto *NarrowType = LHS->getType();
12142 auto *WideType = FoundLHS->getType();
12143 auto BitWidth = getTypeSizeInBits(NarrowType);
12144 const SCEV *MaxValue = getZeroExtendExpr(
12146 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12147 MaxValue) &&
12148 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12149 MaxValue)) {
12150 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12151 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12152 // We cannot preserve samesign after truncation.
12153 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12154 TruncFoundLHS, TruncFoundRHS, CtxI))
12155 return true;
12156 }
12157 }
12158
12159 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12160 return false;
12161 if (CmpInst::isSigned(Pred)) {
12162 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12163 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12164 } else {
12165 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12166 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12167 }
12168 } else if (getTypeSizeInBits(LHS->getType()) >
12169 getTypeSizeInBits(FoundLHS->getType())) {
12170 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12171 return false;
12172 if (CmpInst::isSigned(FoundPred)) {
12173 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12174 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12175 } else {
12176 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12177 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12178 }
12179 }
12180 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12181 FoundRHS, CtxI);
12182}
12183
12184bool ScalarEvolution::isImpliedCondBalancedTypes(
12185 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12186 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12188 getTypeSizeInBits(FoundLHS->getType()) &&
12189 "Types should be balanced!");
12190 // Canonicalize the query to match the way instcombine will have
12191 // canonicalized the comparison.
12192 if (SimplifyICmpOperands(Pred, LHS, RHS))
12193 if (LHS == RHS)
12194 return CmpInst::isTrueWhenEqual(Pred);
12195 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12196 if (FoundLHS == FoundRHS)
12197 return CmpInst::isFalseWhenEqual(FoundPred);
12198
12199 // Check to see if we can make the LHS or RHS match.
12200 if (LHS == FoundRHS || RHS == FoundLHS) {
12201 if (isa<SCEVConstant>(RHS)) {
12202 std::swap(FoundLHS, FoundRHS);
12203 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12204 } else {
12205 std::swap(LHS, RHS);
12207 }
12208 }
12209
12210 // Check whether the found predicate is the same as the desired predicate.
12211 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12212 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12213
12214 // Check whether swapping the found predicate makes it the same as the
12215 // desired predicate.
12216 if (auto P = CmpPredicate::getMatching(
12217 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12218 // We can write the implication
12219 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12220 // using one of the following ways:
12221 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12222 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12223 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12224 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12225 // Forms 1. and 2. require swapping the operands of one condition. Don't
12226 // do this if it would break canonical constant/addrec ordering.
12228 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12229 LHS, FoundLHS, FoundRHS, CtxI);
12230 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12231 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12232
12233 // There's no clear preference between forms 3. and 4., try both. Avoid
12234 // forming getNotSCEV of pointer values as the resulting subtract is
12235 // not legal.
12236 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12237 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12238 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12239 FoundRHS, CtxI))
12240 return true;
12241
12242 if (!FoundLHS->getType()->isPointerTy() &&
12243 !FoundRHS->getType()->isPointerTy() &&
12244 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12245 getNotSCEV(FoundRHS), CtxI))
12246 return true;
12247
12248 return false;
12249 }
12250
12251 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12253 assert(P1 != P2 && "Handled earlier!");
12254 return CmpInst::isRelational(P2) &&
12256 };
12257 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12258 // Unsigned comparison is the same as signed comparison when both the
12259 // operands are non-negative or negative.
12260 if (haveSameSign(FoundLHS, FoundRHS))
12261 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12262 // Create local copies that we can freely swap and canonicalize our
12263 // conditions to "le/lt".
12264 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12265 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12266 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12267 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12268 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12269 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12270 std::swap(CanonicalLHS, CanonicalRHS);
12271 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12272 }
12273 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12274 "Must be!");
12275 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12276 ICmpInst::isLE(CanonicalFoundPred)) &&
12277 "Must be!");
12278 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12279 // Use implication:
12280 // x <u y && y >=s 0 --> x <s y.
12281 // If we can prove the left part, the right part is also proven.
12282 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12283 CanonicalRHS, CanonicalFoundLHS,
12284 CanonicalFoundRHS);
12285 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12286 // Use implication:
12287 // x <s y && y <s 0 --> x <u y.
12288 // If we can prove the left part, the right part is also proven.
12289 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12290 CanonicalRHS, CanonicalFoundLHS,
12291 CanonicalFoundRHS);
12292 }
12293
12294 // Check if we can make progress by sharpening ranges.
12295 if (FoundPred == ICmpInst::ICMP_NE &&
12296 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12297
12298 const SCEVConstant *C = nullptr;
12299 const SCEV *V = nullptr;
12300
12301 if (isa<SCEVConstant>(FoundLHS)) {
12302 C = cast<SCEVConstant>(FoundLHS);
12303 V = FoundRHS;
12304 } else {
12305 C = cast<SCEVConstant>(FoundRHS);
12306 V = FoundLHS;
12307 }
12308
12309 // The guarding predicate tells us that C != V. If the known range
12310 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12311 // range we consider has to correspond to same signedness as the
12312 // predicate we're interested in folding.
12313
12314 APInt Min = ICmpInst::isSigned(Pred) ?
12316
12317 if (Min == C->getAPInt()) {
12318 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12319 // This is true even if (Min + 1) wraps around -- in case of
12320 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12321
12322 APInt SharperMin = Min + 1;
12323
12324 switch (Pred) {
12325 case ICmpInst::ICMP_SGE:
12326 case ICmpInst::ICMP_UGE:
12327 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12328 // RHS, we're done.
12329 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12330 CtxI))
12331 return true;
12332 [[fallthrough]];
12333
12334 case ICmpInst::ICMP_SGT:
12335 case ICmpInst::ICMP_UGT:
12336 // We know from the range information that (V `Pred` Min ||
12337 // V == Min). We know from the guarding condition that !(V
12338 // == Min). This gives us
12339 //
12340 // V `Pred` Min || V == Min && !(V == Min)
12341 // => V `Pred` Min
12342 //
12343 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12344
12345 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12346 return true;
12347 break;
12348
12349 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12350 case ICmpInst::ICMP_SLE:
12351 case ICmpInst::ICMP_ULE:
12352 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12353 LHS, V, getConstant(SharperMin), CtxI))
12354 return true;
12355 [[fallthrough]];
12356
12357 case ICmpInst::ICMP_SLT:
12358 case ICmpInst::ICMP_ULT:
12359 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12360 LHS, V, getConstant(Min), CtxI))
12361 return true;
12362 break;
12363
12364 default:
12365 // No change
12366 break;
12367 }
12368 }
12369 }
12370
12371 // Check whether the actual condition is beyond sufficient.
12372 if (FoundPred == ICmpInst::ICMP_EQ)
12373 if (ICmpInst::isTrueWhenEqual(Pred))
12374 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12375 return true;
12376 if (Pred == ICmpInst::ICMP_NE)
12377 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12378 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12379 return true;
12380
12381 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12382 return true;
12383
12384 // Otherwise assume the worst.
12385 return false;
12386}
12387
12388bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12389 SCEV::NoWrapFlags &Flags) {
12390 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12391 return false;
12392
12393 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12394 return true;
12395}
12396
/// Try to compute More - Less as a constant in the bit width of More's type
/// (arithmetic wraps modulo 2^BW). The two expressions are repeatedly
/// simplified for at most 8 rounds: equal-step affine addrecs on the same loop
/// are replaced by their starts, a common constant multiplier is stripped, and
/// the addends of add expressions are cancelled against each other. Returns
/// std::nullopt when no constant difference is found.
///
/// NOTE(review): source lines 12398, 12412 and 12453 are missing from this
/// listing (gaps in the embedded line numbers). Line 12398 is presumably the
/// declarator ScalarEvolution::computeConstantDifference(const SCEV *More,
/// const SCEV *Less), judging by the uses of More/Less below. Confirm.
12397std::optional<APInt>
12399 // We avoid subtracting expressions here because this function is usually
12400 // fairly deep in the call stack (i.e. is called many times).
12401
12402 unsigned BW = getTypeSizeInBits(More->getType());
// Diff accumulates the constant part of More - Less found so far. DiffMul is
// the product of the common constant multipliers stripped so far; constants
// discovered later are scaled by it.
12403 APInt Diff(BW, 0);
12404 APInt DiffMul(BW, 1);
12405 // Try various simplifications to reduce the difference to a constant. Limit
12406 // the number of allowed simplifications to keep compile-time low.
12407 for (unsigned I = 0; I < 8; ++I) {
12408 if (More == Less)
12409 return Diff;
12410
12411 // Reduce addrecs with identical steps to their start value.
// NOTE(review): line 12412 (the condition guarding this case) is missing from
// this listing. Given the cast<> calls below, it presumably checks that both
// Less and More are SCEVAddRecExprs. Confirm.
12413 const auto *LAR = cast<SCEVAddRecExpr>(Less);
12414 const auto *MAR = cast<SCEVAddRecExpr>(More);
12415
12416 if (LAR->getLoop() != MAR->getLoop())
12417 return std::nullopt;
12418
12419 // We look at affine expressions only; not for correctness but to keep
12420 // getStepRecurrence cheap.
12421 if (!LAR->isAffine() || !MAR->isAffine())
12422 return std::nullopt;
12423
12424 if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
12425 return std::nullopt;
12426
12427 Less = LAR->getStart();
12428 More = MAR->getStart();
12429 continue;
12430 }
12431
12432 // Try to match a common constant multiply.
// C * X - C * Y == C * (X - Y): strip C from both sides and remember it in
// DiffMul.
12433 auto MatchConstMul =
12434 [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
12435 const APInt *C;
12436 const SCEV *Op;
12437 if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
12438 return {{Op, *C}};
12439 return std::nullopt;
12440 };
12441 if (auto MatchedMore = MatchConstMul(More)) {
12442 if (auto MatchedLess = MatchConstMul(Less)) {
12443 if (MatchedMore->second == MatchedLess->second) {
12444 More = MatchedMore->first;
12445 Less = MatchedLess->first;
12446 DiffMul *= MatchedMore->second;
12447 continue;
12448 }
12449 }
12450 }
12451
12452 // Try to cancel out common addends in two add expressions.
// NOTE(review): line 12453, which should declare the Multiplicity map used by
// Add below, is missing from this listing.
// Add: constants are folded into Diff (scaled by DiffMul). Any other operand
// has its multiplicity adjusted by Mul (+1 for terms of More, -1 for terms of
// Less).
12454 auto Add = [&](const SCEV *S, int Mul) {
12455 if (auto *C = dyn_cast<SCEVConstant>(S)) {
12456 if (Mul == 1) {
12457 Diff += C->getAPInt() * DiffMul;
12458 } else {
12459 assert(Mul == -1);
12460 Diff -= C->getAPInt() * DiffMul;
12461 }
12462 } else
12463 Multiplicity[S] += Mul;
12464 };
// Decompose: flattens one level of SCEVAddExpr so that matching addends
// cancel.
12465 auto Decompose = [&](const SCEV *S, int Mul) {
12466 if (isa<SCEVAddExpr>(S)) {
12467 for (const SCEV *Op : S->operands())
12468 Add(Op, Mul);
12469 } else
12470 Add(S, Mul);
12471 };
12472 Decompose(More, 1);
12473 Decompose(Less, -1);
12474
12475 // Check whether all the non-constants cancel out, or reduce to new
12476 // More/Less values.
// At most one leftover term may remain on each side, each with multiplicity
// exactly +1 (More) or -1 (Less). Anything else gives up.
12477 const SCEV *NewMore = nullptr, *NewLess = nullptr;
12478 for (const auto &[S, Mul] : Multiplicity) {
12479 if (Mul == 0)
12480 continue;
12481 if (Mul == 1) {
12482 if (NewMore)
12483 return std::nullopt;
12484 NewMore = S;
12485 } else if (Mul == -1) {
12486 if (NewLess)
12487 return std::nullopt;
12488 NewLess = S;
12489 } else
12490 return std::nullopt;
12491 }
12492
12493 // Values stayed the same, no point in trying further.
12494 if (NewMore == More || NewLess == Less)
12495 return std::nullopt;
12496
12497 More = NewMore;
12498 Less = NewLess;
12499
12500 // Reduced to constant.
12501 if (!More && !Less)
12502 return Diff;
12503
12504 // Left with variable on only one side, bail out.
12505 if (!More || !Less)
12506 return std::nullopt;
12507 }
12508
12509 // Did not reduce to constant.
12510 return std::nullopt;
12511}
12512
12513bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12514 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12515 const SCEV *FoundRHS, const Instruction *CtxI) {
12516 // Try to recognize the following pattern:
12517 //
12518 // FoundRHS = ...
12519 // ...
12520 // loop:
12521 // FoundLHS = {Start,+,W}
12522 // context_bb: // Basic block from the same loop
12523 // known(Pred, FoundLHS, FoundRHS)
12524 //
12525 // If some predicate is known in the context of a loop, it is also known on
12526 // each iteration of this loop, including the first iteration. Therefore, in
12527 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12528 // prove the original pred using this fact.
12529 if (!CtxI)
12530 return false;
12531 const BasicBlock *ContextBB = CtxI->getParent();
12532 // Make sure AR varies in the context block.
12533 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12534 const Loop *L = AR->getLoop();
12535 const auto *Latch = L->getLoopLatch();
12536 // Make sure that context belongs to the loop and executes on 1st iteration
12537 // (if it ever executes at all).
12538 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12539 return false;
12540 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12541 return false;
12542 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12543 }
12544
12545 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12546 const Loop *L = AR->getLoop();
12547 const auto *Latch = L->getLoopLatch();
12548 // Make sure that context belongs to the loop and executes on 1st iteration
12549 // (if it ever executes at all).
12550 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12551 return false;
12552 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12553 return false;
12554 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12555 }
12556
12557 return false;
12558}
12559
12560bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12561 const SCEV *LHS,
12562 const SCEV *RHS,
12563 const SCEV *FoundLHS,
12564 const SCEV *FoundRHS) {
12565 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12566 return false;
12567
12568 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12569 if (!AddRecLHS)
12570 return false;
12571
12572 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12573 if (!AddRecFoundLHS)
12574 return false;
12575
12576 // We'd like to let SCEV reason about control dependencies, so we constrain
12577 // both the inequalities to be about add recurrences on the same loop. This
12578 // way we can use isLoopEntryGuardedByCond later.
12579
12580 const Loop *L = AddRecFoundLHS->getLoop();
12581 if (L != AddRecLHS->getLoop())
12582 return false;
12583
12584 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12585 //
12586 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12587 // ... (2)
12588 //
12589 // Informal proof for (2), assuming (1) [*]:
12590 //
12591 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12592 //
12593 // Then
12594 //
12595 // FoundLHS s< FoundRHS s< INT_MIN - C
12596 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12597 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12598 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12599 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12600 // <=> FoundLHS + C s< FoundRHS + C
12601 //
12602 // [*]: (1) can be proved by ruling out overflow.
12603 //
12604 // [**]: This can be proved by analyzing all the four possibilities:
12605 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12606 // (A s>= 0, B s>= 0).
12607 //
12608 // Note:
12609 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12610 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12611 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12612 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12613 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12614 // C)".
12615
12616 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12617 if (!LDiff)
12618 return false;
12619 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12620 if (!RDiff || *LDiff != *RDiff)
12621 return false;
12622
12623 if (LDiff->isMinValue())
12624 return true;
12625
12626 APInt FoundRHSLimit;
12627
12628 if (Pred == CmpInst::ICMP_ULT) {
12629 FoundRHSLimit = -(*RDiff);
12630 } else {
12631 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12632 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12633 }
12634
12635 // Try to prove (1) or (2), as needed.
12636 return isAvailableAtLoopEntry(FoundRHS, L) &&
12637 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12638 getConstant(FoundRHSLimit));
12639}
12640
/// Try to prove "LHS Pred RHS" from "FoundLHS Pred FoundRHS" when LHS and/or
/// RHS is a SCEVUnknown wrapping a PHI node. Instead of the PHIs themselves,
/// the predicate is proved for their incoming values (see the three cases
/// below). Each PHI under consideration is inserted into PendingMerges for
/// the duration of the call and erased on every exit path by ClearOnExit. If a
/// PHI is already pending, for example because PHIs feed each other
/// cyclically, the function conservatively returns false.
///
/// NOTE(review): source line 12690 is missing from this listing (gap in the
/// embedded line numbers).
12641bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12642 const SCEV *RHS, const SCEV *FoundLHS,
12643 const SCEV *FoundRHS, unsigned Depth) {
12644 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12645
// Erase whichever PHIs this call inserted into PendingMerges, however we exit.
12646 llvm::scope_exit ClearOnExit([&]() {
12647 if (LPhi) {
12648 bool Erased = PendingMerges.erase(LPhi);
12649 assert(Erased && "Failed to erase LPhi!");
12650 (void)Erased;
12651 }
12652 if (RPhi) {
12653 bool Erased = PendingMerges.erase(RPhi);
12654 assert(Erased && "Failed to erase RPhi!");
12655 (void)Erased;
12656 }
12657 });
12658
12659 // Find respective Phis and check that they are not already pending.
12660 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12661 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12662 if (!PendingMerges.insert(Phi).second)
12663 return false;
12664 LPhi = Phi;
12665 }
12666 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12667 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12668 // If we detect a loop of Phi nodes being processed by this method, for
12669 // example:
12670 //
12671 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12672 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12673 //
12674 // we don't want to deal with a case that complex, so return conservative
12675 // answer false.
12676 if (!PendingMerges.insert(Phi).second)
12677 return false;
12678 RPhi = Phi;
12679 }
12680
12681 // If none of LHS, RHS is a Phi, nothing to do here.
12682 if (!LPhi && !RPhi)
12683 return false;
12684
12685 // If there is a SCEVUnknown Phi we are interested in, make it left.
12686 if (!LPhi) {
12687 std::swap(LHS, RHS);
12688 std::swap(FoundLHS, FoundRHS);
12689 std::swap(LPhi, RPhi);
// NOTE(review): line 12690 is missing from this listing. Swapping the operands
// of both comparisons is only sound if Pred is swapped as well (e.g.
// Pred = ICmpInst::getSwappedCmpPredicate(Pred)). Confirm that the real source
// does this here.
12691 }
12692
12693 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12694 const BasicBlock *LBB = LPhi->getParent();
12695 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12696
// Cheap proof of "S1 Pred S2": non-recursive reasoning, ranges, or operation
// structure. The known fact is taken as "FoundLHS Pred FoundRHS".
12697 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12698 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12699 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12700 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12701 };
12702
12703 if (RPhi && RPhi->getParent() == LBB) {
12704 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12705 // If we compare two Phis from the same block, and for each entry block
12706 // the predicate is true for incoming values from this block, then the
12707 // predicate is also true for the Phis.
12708 for (const BasicBlock *IncBB : predecessors(LBB)) {
12709 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12710 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12711 if (!ProvedEasily(L, R))
12712 return false;
12713 }
12714 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12715 // Case two: RHS is also a Phi from the same basic block, and it is an
12716 // AddRec. It means that there is a loop which has both AddRec and Unknown
12717 // PHIs, for it we can compare incoming values of AddRec from above the loop
12718 // and latch with their respective incoming values of LPhi.
12719 // TODO: Generalize to handle loops with many inputs in a header.
12720 if (LPhi->getNumIncomingValues() != 2) return false;
12721
12722 auto *RLoop = RAR->getLoop();
12723 auto *Predecessor = RLoop->getLoopPredecessor();
12724 assert(Predecessor && "Loop with AddRec with no predecessor?");
12725 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12726 if (!ProvedEasily(L1, RAR->getStart()))
12727 return false;
12728 auto *Latch = RLoop->getLoopLatch();
12729 assert(Latch && "Loop with AddRec with no latch?");
12730 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12731 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12732 return false;
12733 } else {
12734 // In all other cases go over inputs of LHS and compare each of them to RHS,
12735 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12736 // At this point RHS is either a non-Phi, or it is a Phi from some block
12737 // different from LBB.
12738 for (const BasicBlock *IncBB : predecessors(LBB)) {
12739 // Check that RHS is available in this block.
12740 if (!dominates(RHS, IncBB))
12741 return false;
12742 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12743 // Make sure L does not refer to a value from a potentially previous
12744 // iteration of a loop.
12745 if (!properlyDominates(L, LBB))
12746 return false;
12747 // Addrecs are considered to properly dominate their loop, so are missed
12748 // by the previous check. Discard any values that have computable
12749 // evolution in this loop.
12750 if (auto *Loop = LI.getLoopFor(LBB))
12751 if (hasComputableLoopEvolution(L, Loop))
12752 return false;
12753 if (!ProvedEasily(L, RHS))
12754 return false;
12755 }
12756 }
12757 return true;
12758}
12759
/// Try to prove "LHS Pred RHS" from "FoundLHS Pred FoundRHS" when both share
/// the same left operand and FoundRHS is an IR `lshr` of some value Shiftee.
/// Because (Shiftee >> ShiftValue) u<= Shiftee, it is enough to know that
/// Shiftee u<= RHS. Signed predicates additionally require Shiftee s>= 0. If
/// the right operands are the shared ones, the operands are swapped first.
///
/// NOTE(review): source line 12770 is missing from this listing (gap in the
/// embedded line numbers).
12760bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12761 const SCEV *LHS,
12762 const SCEV *RHS,
12763 const SCEV *FoundLHS,
12764 const SCEV *FoundRHS) {
12765 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12766 // sure that we are dealing with same LHS.
12767 if (RHS == FoundRHS) {
12768 std::swap(LHS, RHS);
12769 std::swap(FoundLHS, FoundRHS);
// NOTE(review): line 12770 is missing from this listing. After swapping the
// operands, Pred presumably must be swapped too (getSwappedCmpPredicate).
// Confirm.
12771 }
12772 if (LHS != FoundLHS)
12773 return false;
12774
12775 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12776 if (!SUFoundRHS)
12777 return false;
12778
12779 Value *Shiftee, *ShiftValue;
12780
12781 using namespace PatternMatch;
12782 if (match(SUFoundRHS->getValue(),
12783 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12784 auto *ShifteeS = getSCEV(Shiftee);
12785 // Prove one of the following:
12786 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12787 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12788 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12789 // ---> LHS <s RHS
12790 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12791 // ---> LHS <=s RHS
12792 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12793 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12794 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12795 if (isKnownNonNegative(ShifteeS))
12796 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12797 }
12798
12799 return false;
12800}
12801
12802bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12803 const SCEV *RHS,
12804 const SCEV *FoundLHS,
12805 const SCEV *FoundRHS,
12806 const Instruction *CtxI) {
12807 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12808 FoundRHS) ||
12809 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12810 FoundRHS) ||
12811 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12812 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12813 CtxI) ||
12814 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12815}
12816
12817/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12818template <typename MinMaxExprType>
12819static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12820 const SCEV *Candidate) {
12821 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12822 if (!MinMaxExpr)
12823 return false;
12824
12825 return is_contained(MinMaxExpr->operands(), Candidate);
12826}
12827
12829 CmpPredicate Pred, const SCEV *LHS,
12830 const SCEV *RHS) {
// NOTE(review): the declarator (line 12828) and lines 12842 and 12845-12848
// are missing from this listing (gaps in the embedded line numbers), so the
// code below is incomplete as shown. Judging by the call
// IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) and the use of `SE`,
// this is presumably a static helper taking (ScalarEvolution &SE, CmpPredicate
// Pred, const SCEV *LHS, const SCEV *RHS). Line 12842 presumably matches RHS
// as an affine addrec {RStart,+,Step} on loop L. Lines 12845-12848 presumably
// obtain LAR, RAR and a no-wrap flag NW matching the signedness of Pred.
// Confirm.
12831 // If both sides are affine addrecs for the same loop, with equal
12832 // steps, and we know the recurrences don't wrap, then we only
12833 // need to check the predicate on the starting values.
12834
// Only ordering predicates are handled; EQ/NE are rejected.
12835 if (!ICmpInst::isRelational(Pred))
12836 return false;
12837
12838 const SCEV *LStart, *RStart, *Step;
12839 const Loop *L;
12840 if (!match(LHS,
12841 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12843 m_SpecificLoop(L))))
12844 return false;
12849 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12850 return false;
12851
// With a shared step and no wrapping, the relation between the starts holds
// on every iteration.
12852 return SE.isKnownPredicate(Pred, LStart, RStart);
12853}
12854
12855/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
12856/// expression? Only the non-strict orderings SLE/SGE/ULE/UGE are handled;
/// every other predicate yields false.
///
/// NOTE(review): source lines 12857, 12869, 12871, 12880 and 12882 are missing
/// from this listing (gaps in the embedded line numbers), so the code below is
/// incomplete as shown. Line 12857 is the declarator; the call
/// IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) suggests it takes
/// (ScalarEvolution &SE, CmpPredicate Pred, ...). The other four lines are the
/// operands of the two return expressions, presumably IsMinMaxConsistingOf<...>
/// checks. Confirm.
12858 const SCEV *LHS, const SCEV *RHS) {
12859 switch (Pred) {
12860 default:
12861 return false;
12862
12863 case ICmpInst::ICMP_SGE:
// Handle X s>= Y as Y s<= X.
12864 std::swap(LHS, RHS);
12865 [[fallthrough]];
12866 case ICmpInst::ICMP_SLE:
12867 return
12868 // min(A, ...) <= A
12870 // A <= max(A, ...)
12872
12873 case ICmpInst::ICMP_UGE:
// Handle X u>= Y as Y u<= X.
12874 std::swap(LHS, RHS);
12875 [[fallthrough]];
12876 case ICmpInst::ICMP_ULE:
12877 return
12878 // min(A, ...) <= A
12879 // FIXME: what about umin_seq?
12881 // A <= max(A, ...)
12883 }
12884
12885 llvm_unreachable("covered switch fell through?!");
12886}
12887
/// Try to prove "LHS Pred RHS" from "FoundLHS Pred FoundRHS" by looking at how
/// LHS is built. First, LT predicates are normalized to GT by swapping
/// operands, and UGT is presumably narrowed to SGT when the operands are known
/// non-negative. Sign extensions are then peeled off LHS and FoundLHS. Two
/// shapes are handled:
///   - an nsw add LL + LR, proved greater than RHS when one addend is >= 0 and
///     the other is > RHS;
///   - an IR sdiv by a positive constant whose numerator matches FoundLHS.
/// As a last resort the function falls back to isImpliedViaMerge on the
/// original (un-peeled) LHS and FoundLHS. Recursive calls pass Depth + 1.
///
/// NOTE(review): source lines 12893-12894, 12900, 12905, 12910, 12926 and
/// 12959 are missing from this listing (gaps in the embedded line numbers), so
/// the code below is incomplete as shown. Lines 12893-12894 presumably start an
/// assert on the LHS/RHS sizes, 12900 a depth-limit check, and 12959 a
/// LHS/RHS type-size check. Confirm.
12888bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12889 const SCEV *RHS,
12890 const SCEV *FoundLHS,
12891 const SCEV *FoundRHS,
12892 unsigned Depth) {
12895 "LHS and RHS have different sizes?");
12896 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12897 getTypeSizeInBits(FoundRHS->getType()) &&
12898 "FoundLHS and FoundRHS have different sizes?");
12899 // We want to avoid hurting the compile time with analysis of too big trees.
12901 return false;
12902
12903 // We only want to work with GT comparison so far.
12904 if (ICmpInst::isLT(Pred)) {
// NOTE(review): line 12905 is missing from this listing. It presumably swaps
// Pred to match the operand swap below. Confirm.
12906 std::swap(LHS, RHS);
12907 std::swap(FoundLHS, FoundRHS);
12908 }
12909
// NOTE(review): line 12910 is missing from this listing. It presumably
// declares P, the working predicate tested below. Confirm.
12911
12912 // For unsigned, try to reduce it to corresponding signed comparison.
12913 if (P == ICmpInst::ICMP_UGT)
12914 // We can replace unsigned predicate with its signed counterpart if all
12915 // involved values are non-negative.
12916 // TODO: We could have better support for unsigned.
12917 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12918 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12919 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12920 // use this fact to prove that LHS and RHS are non-negative.
12921 const SCEV *MinusOne = getMinusOne(LHS->getType());
12922 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12923 FoundRHS) &&
12924 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12925 FoundRHS))
// NOTE(review): line 12926 (the body of the `if` above) is missing from this
// listing. Given the reasoning above and the SGT check below, it presumably
// sets P to ICMP_SGT. Confirm.
12927 }
12928
12929 if (P != ICmpInst::ICMP_SGT)
12930 return false;
12931
12932 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12933 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12934 return Ext->getOperand();
12935 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12936 // the constant in some cases.
12937 return S;
12938 };
12939
12940 // Acquire values from extensions.
12941 auto *OrigLHS = LHS;
12942 auto *OrigFoundLHS = FoundLHS;
12943 LHS = GetOpFromSExt(LHS);
12944 FoundLHS = GetOpFromSExt(FoundLHS);
12945
12946 // Can the SGT predicate be proved trivially or using the found context?
12947 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12948 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12949 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12950 FoundRHS, Depth + 1);
12951 };
12952
12953 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12954 // We want to avoid creation of any new non-constant SCEV. Since we are
12955 // going to compare the operands to RHS, we should be certain that we don't
12956 // need any size extensions for this. So let's decline all cases when the
12957 // sizes of types of LHS and RHS do not match.
12958 // TODO: Maybe try to get RHS from sext to catch more cases?
12960 return false;
12961
12962 // Should not overflow.
12963 if (!LHSAddExpr->hasNoSignedWrap())
12964 return false;
12965
12966 SCEVUse LL = LHSAddExpr->getOperand(0);
12967 SCEVUse LR = LHSAddExpr->getOperand(1);
12968 auto *MinusOne = getMinusOne(RHS->getType());
12969
12970 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12971 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12972 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12973 };
12974 // Try to prove the following rule:
12975 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12976 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12977 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12978 return true;
12979 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12980 Value *LL, *LR;
12981 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12982
12983 using namespace llvm::PatternMatch;
12984
12985 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12986 // Rules for division.
12987 // We are going to perform some comparisons with Denominator and its
12988 // derivative expressions. In general case, creating a SCEV for it may
12989 // lead to a complex analysis of the entire graph, and in particular it
12990 // can request trip count recalculation for the same loop. This would
12991 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12992 // this, we only want to create SCEVs that are constants in this section.
12993 // So we bail if Denominator is not a constant.
12994 if (!isa<ConstantInt>(LR))
12995 return false;
12996
12997 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12998
12999 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
13000 // then a SCEV for the numerator already exists and matches with FoundLHS.
13001 auto *Numerator = getExistingSCEV(LL);
13002 if (!Numerator || Numerator->getType() != FoundLHS->getType())
13003 return false;
13004
13005 // Make sure that the numerator matches with FoundLHS and the denominator
13006 // is positive.
13007 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
13008 return false;
13009
13010 auto *DTy = Denominator->getType();
13011 auto *FRHSTy = FoundRHS->getType();
13012 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13013 // One of types is a pointer and another one is not. We cannot extend
13014 // them properly to a wider type, so let us just reject this case.
13015 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13016 // to avoid this check.
13017 return false;
13018
13019 // Given that:
13020 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13021 auto *WTy = getWiderType(DTy, FRHSTy);
13022 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13023 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13024
13025 // Try to prove the following rule:
13026 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13027 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13028 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13029 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13030 if (isKnownNonPositive(RHS) &&
13031 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13032 return true;
13033
13034 // Try to prove the following rule:
13035 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13036 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13037 // If we divide it by Denominator > 2, then:
13038 // 1. If FoundLHS is negative, then the result is 0.
13039 // 2. If FoundLHS is non-negative, then the result is non-negative.
13040 // Anyways, the result is non-negative.
13041 auto *MinusOne = getMinusOne(WTy);
13042 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13043 if (isKnownNegative(RHS) &&
13044 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13045 return true;
13046 }
13047 }
13048
13049 // If our expression contained SCEVUnknown Phis, and we split it down and now
13050 // need to prove something for them, try to prove the predicate for every
13051 // possible incoming value of those Phis.
13052 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13053 return true;
13054
13055 return false;
13056}
13057
13059 const SCEV *RHS) {
// Prove "LHS Pred RHS" for the idiom that compares the zero and sign
// extensions of one operand, as summarized in the comment below. Predicates
// other than SLE/SGE/ULE/UGE yield false.
//
// NOTE(review): source lines 13058 (the declarator), 13069 and 13077 are
// missing from this listing (gaps in the embedded line numbers), so the code
// below is incomplete as shown. Lines 13069 and 13077 are the second operands
// of the two `&&` return expressions; they presumably match RHS against the
// opposite extension of the same Op. Confirm.
13060 // zext x u<= sext x, sext x s<= zext x
13061 const SCEV *Op;
13062 switch (Pred) {
13063 case ICmpInst::ICMP_SGE:
// Handle X s>= Y as Y s<= X.
13064 std::swap(LHS, RHS);
13065 [[fallthrough]];
13066 case ICmpInst::ICMP_SLE: {
13067 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13068 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13070 }
13071 case ICmpInst::ICMP_UGE:
// Handle X u>= Y as Y u<= X.
13072 std::swap(LHS, RHS);
13073 [[fallthrough]];
13074 case ICmpInst::ICMP_ULE: {
13075 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then ZExt <u SExt.
13076 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13078 }
13079 default:
13080 return false;
13081 };
13082 llvm_unreachable("unhandled case");
13083}
13084
13085bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13086 SCEVUse LHS,
13087 SCEVUse RHS) {
13088 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13089 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13090 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13091 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13092 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13093}
13094
13095bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13096 const SCEV *LHS,
13097 const SCEV *RHS,
13098 const SCEV *FoundLHS,
13099 const SCEV *FoundRHS) {
13100 switch (Pred) {
13101 default:
13102 llvm_unreachable("Unexpected CmpPredicate value!");
13103 case ICmpInst::ICMP_EQ:
13104 case ICmpInst::ICMP_NE:
13105 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13106 return true;
13107 break;
13108 case ICmpInst::ICMP_SLT:
13109 case ICmpInst::ICMP_SLE:
13110 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13111 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13112 return true;
13113 break;
13114 case ICmpInst::ICMP_SGT:
13115 case ICmpInst::ICMP_SGE:
13116 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13117 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13118 return true;
13119 break;
13120 case ICmpInst::ICMP_ULT:
13121 case ICmpInst::ICMP_ULE:
13122 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13123 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13124 return true;
13125 break;
13126 case ICmpInst::ICMP_UGT:
13127 case ICmpInst::ICMP_UGE:
13128 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13129 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13130 return true;
13131 break;
13132 }
13133
13134 // Maybe it can be proved via operations?
13135 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13136 return true;
13137
13138 return false;
13139}
13140
13141bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13142 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13143 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13144 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13145 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13146 // reduce the compile time impact of this optimization.
13147 return false;
13148
13149 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13150 if (!Addend)
13151 return false;
13152
13153 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13154
13155 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13156 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13157 ConstantRange FoundLHSRange =
13158 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13159
13160 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13161 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13162
13163 // We can also compute the range of values for `LHS` that satisfy the
13164 // consequent, "`LHS` `Pred` `RHS`":
13165 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13166 // The antecedent implies the consequent if every value of `LHS` that
13167 // satisfies the antecedent also satisfies the consequent.
13168 return LHSRange.icmp(Pred, ConstRHS);
13169}
13170
13171bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13172 bool IsSigned) {
13173 assert(isKnownPositive(Stride) && "Positive stride expected!");
13174
13175 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13176 const SCEV *One = getOne(Stride->getType());
13177
13178 if (IsSigned) {
13179 APInt MaxRHS = getSignedRangeMax(RHS);
13180 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13181 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13182
13183 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13184 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13185 }
13186
13187 APInt MaxRHS = getUnsignedRangeMax(RHS);
13188 APInt MaxValue = APInt::getMaxValue(BitWidth);
13189 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13190
13191 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13192 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13193}
13194
13195bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13196 bool IsSigned) {
13197
13198 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13199 const SCEV *One = getOne(Stride->getType());
13200
13201 if (IsSigned) {
13202 APInt MinRHS = getSignedRangeMin(RHS);
13203 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13204 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13205
13206 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13207 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13208 }
13209
13210 APInt MinRHS = getUnsignedRangeMin(RHS);
13211 APInt MinValue = APInt::getMinValue(BitWidth);
13212 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13213
13214 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13215 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13216}
13217
13219 // umin(N, 1) + floor((N - umin(N, 1)) / D)
13220 // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
13221 // expression fixes the case of N=0.
13222 const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
13223 const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
13224 return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
13225}
13226
13227const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13228 const SCEV *Stride,
13229 const SCEV *End,
13230 unsigned BitWidth,
13231 bool IsSigned) {
13232 // The logic in this function assumes we can represent a positive stride.
13233 // If we can't, the backedge-taken count must be zero.
13234 if (IsSigned && BitWidth == 1)
13235 return getZero(Stride->getType());
13236
13237 // This code below only been closely audited for negative strides in the
13238 // unsigned comparison case, it may be correct for signed comparison, but
13239 // that needs to be established.
13240 if (IsSigned && isKnownNegative(Stride))
13241 return getCouldNotCompute();
13242
13243 // Calculate the maximum backedge count based on the range of values
13244 // permitted by Start, End, and Stride.
13245 APInt MinStart =
13246 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13247
13248 APInt MinStride =
13249 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13250
13251 // We assume either the stride is positive, or the backedge-taken count
13252 // is zero. So force StrideForMaxBECount to be at least one.
13253 APInt One(BitWidth, 1);
13254 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13255 : APIntOps::umax(One, MinStride);
13256
13257 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13258 : APInt::getMaxValue(BitWidth);
13259 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13260
13261 // Although End can be a MAX expression we estimate MaxEnd considering only
13262 // the case End = RHS of the loop termination condition. This is safe because
13263 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13264 // taken count.
13265 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13266 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13267
13268 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13269 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13270 : APIntOps::umax(MaxEnd, MinStart);
13271
13272 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13273 getConstant(StrideForMaxBECount) /* Step */);
13274}
13275
13277ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
13278 const Loop *L, bool IsSigned,
13279 bool ControlsOnlyExit, bool AllowPredicates) {
13281
13283 bool PredicatedIV = false;
13284 if (!IV) {
13285 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
13286 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
13287 if (AR && AR->getLoop() == L && AR->isAffine()) {
13288 auto canProveNUW = [&]() {
13289 // We can use the comparison to infer no-wrap flags only if it fully
13290 // controls the loop exit.
13291 if (!ControlsOnlyExit)
13292 return false;
13293
13294 if (!isLoopInvariant(RHS, L))
13295 return false;
13296
13297 if (!isKnownNonZero(AR->getStepRecurrence(*this)))
13298 // We need the sequence defined by AR to strictly increase in the
13299 // unsigned integer domain for the logic below to hold.
13300 return false;
13301
13302 const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
13303 const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
13304 // If RHS <=u Limit, then there must exist a value V in the sequence
13305 // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
13306 // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
13307 // overflow occurs. This limit also implies that a signed comparison
13308 // (in the wide bitwidth) is equivalent to an unsigned comparison as
13309 // the high bits on both sides must be zero.
13310 APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
13311 APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
13312 Limit = Limit.zext(OuterBitWidth);
13313 return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
13314 };
13315 auto Flags = AR->getNoWrapFlags();
13316 if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
13317 Flags = setFlags(Flags, SCEV::FlagNUW);
13318
13319 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
13320 if (AR->hasNoUnsignedWrap()) {
13321 // Emulate what getZeroExtendExpr would have done during construction
13322 // if we'd been able to infer the fact just above at that time.
13323 const SCEV *Step = AR->getStepRecurrence(*this);
13324 Type *Ty = ZExt->getType();
13325 auto *S = getAddRecExpr(
13327 getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
13329 }
13330 }
13331 }
13332 }
13333
13334
13335 if (!IV && AllowPredicates) {
13336 // Try to make this an AddRec using runtime tests, in the first X
13337 // iterations of this loop, where X is the SCEV expression found by the
13338 // algorithm below.
13339 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13340 PredicatedIV = true;
13341 }
13342
13343 // Avoid weird loops
13344 if (!IV || IV->getLoop() != L || !IV->isAffine())
13345 return getCouldNotCompute();
13346
13347 // A precondition of this method is that the condition being analyzed
13348 // reaches an exiting branch which dominates the latch. Given that, we can
13349 // assume that an increment which violates the nowrap specification and
13350 // produces poison must cause undefined behavior when the resulting poison
13351 // value is branched upon and thus we can conclude that the backedge is
13352 // taken no more often than would be required to produce that poison value.
13353 // Note that a well defined loop can exit on the iteration which violates
13354 // the nowrap specification if there is another exit (either explicit or
13355 // implicit/exceptional) which causes the loop to execute before the
13356 // exiting instruction we're analyzing would trigger UB.
13357 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13358 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13360
13361 const SCEV *Stride = IV->getStepRecurrence(*this);
13362
13363 bool PositiveStride = isKnownPositive(Stride);
13364
13365 // Avoid negative or zero stride values.
13366 if (!PositiveStride) {
13367 // We can compute the correct backedge taken count for loops with unknown
13368 // strides if we can prove that the loop is not an infinite loop with side
13369 // effects. Here's the loop structure we are trying to handle -
13370 //
13371 // i = start
13372 // do {
13373 // A[i] = i;
13374 // i += s;
13375 // } while (i < end);
13376 //
13377 // The backedge taken count for such loops is evaluated as -
13378 // (max(end, start + stride) - start - 1) /u stride
13379 //
13380 // The additional preconditions that we need to check to prove correctness
13381 // of the above formula is as follows -
13382 //
13383 // a) IV is either nuw or nsw depending upon signedness (indicated by the
13384 // NoWrap flag).
13385 // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
13386 // no side effects within the loop)
13387 // c) loop has a single static exit (with no abnormal exits)
13388 //
13389 // Precondition a) implies that if the stride is negative, this is a single
13390 // trip loop. The backedge taken count formula reduces to zero in this case.
13391 //
13392 // Precondition b) and c) combine to imply that if rhs is invariant in L,
13393 // then a zero stride means the backedge can't be taken without executing
13394 // undefined behavior.
13395 //
13396 // The positive stride case is the same as isKnownPositive(Stride) returning
13397 // true (original behavior of the function).
13398 //
13399 if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
13401 return getCouldNotCompute();
13402
13403 if (!isKnownNonZero(Stride)) {
13404 // If we have a step of zero, and RHS isn't invariant in L, we don't know
13405 // if it might eventually be greater than start and if so, on which
13406 // iteration. We can't even produce a useful upper bound.
13407 if (!isLoopInvariant(RHS, L))
13408 return getCouldNotCompute();
13409
13410 // We allow a potentially zero stride, but we need to divide by stride
13411 // below. Since the loop can't be infinite and this check must control
13412 // the sole exit, we can infer the exit must be taken on the first
13413 // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
13414 // we know the numerator in the divides below must be zero, so we can
13415 // pick an arbitrary non-zero value for the denominator (e.g. stride)
13416 // and produce the right result.
13417 // FIXME: Handle the case where Stride is poison?
13418 auto wouldZeroStrideBeUB = [&]() {
13419 // Proof by contradiction. Suppose the stride were zero. If we can
13420 // prove that the backedge *is* taken on the first iteration, then since
13421 // we know this condition controls the sole exit, we must have an
13422 // infinite loop. We can't have a (well defined) infinite loop per
13423 // check just above.
13424 // Note: The (Start - Stride) term is used to get the start' term from
13425 // (start' + stride,+,stride). Remember that we only care about the
13426 // result of this expression when stride == 0 at runtime.
13427 auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
13428 return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
13429 };
13430 if (!wouldZeroStrideBeUB()) {
13431 Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
13432 }
13433 }
13434 } else if (!NoWrap) {
13435 // Avoid proven overflow cases: this will ensure that the backedge taken
13436 // count will not generate any unsigned overflow.
13437 if (canIVOverflowOnLT(RHS, Stride, IsSigned))
13438 return getCouldNotCompute();
13439 }
13440
13441 // On all paths just preceeding, we established the following invariant:
13442 // IV can be assumed not to overflow up to and including the exiting
13443 // iteration. We proved this in one of two ways:
13444 // 1) We can show overflow doesn't occur before the exiting iteration
13445 // 1a) canIVOverflowOnLT, and b) step of one
13446 // 2) We can show that if overflow occurs, the loop must execute UB
13447 // before any possible exit.
13448 // Note that we have not yet proved RHS invariant (in general).
13449
13450 const SCEV *Start = IV->getStart();
13451
13452 // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
13453 // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
13454 // Use integer-typed versions for actual computation; we can't subtract
13455 // pointers in general.
13456 const SCEV *OrigStart = Start;
13457 const SCEV *OrigRHS = RHS;
13458 if (Start->getType()->isPointerTy()) {
13460 if (isa<SCEVCouldNotCompute>(Start))
13461 return Start;
13462 }
13463 if (RHS->getType()->isPointerTy()) {
13466 return RHS;
13467 }
13468
13469 const SCEV *End = nullptr, *BECount = nullptr,
13470 *BECountIfBackedgeTaken = nullptr;
13471 if (!isLoopInvariant(RHS, L)) {
13472 const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
13473 if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
13474 any(RHSAddRec->getNoWrapFlags())) {
13475 // The structure of loop we are trying to calculate backedge count of:
13476 //
13477 // left = left_start
13478 // right = right_start
13479 //
13480 // while(left < right){
13481 // ... do something here ...
13482 // left += s1; // stride of left is s1 (s1 > 0)
13483 // right += s2; // stride of right is s2 (s2 < 0)
13484 // }
13485 //
13486
13487 const SCEV *RHSStart = RHSAddRec->getStart();
13488 const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);
13489
13490 // If Stride - RHSStride is positive and does not overflow, we can write
13491 // backedge count as ->
13492 // ceil((End - Start) /u (Stride - RHSStride))
13493 // Where, End = max(RHSStart, Start)
13494
13495 // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
13496 if (isKnownNegative(RHSStride) &&
13497 willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
13498 RHSStride)) {
13499
13500 const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
13501 if (isKnownPositive(Denominator)) {
13502 End = IsSigned ? getSMaxExpr(RHSStart, Start)
13503 : getUMaxExpr(RHSStart, Start);
13504
13505 // We can do this because End >= Start, as End = max(RHSStart, Start)
13506 const SCEV *Delta = getMinusSCEV(End, Start);
13507
13508 BECount = getUDivCeilSCEV(Delta, Denominator);
13509 BECountIfBackedgeTaken =
13510 getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
13511 }
13512 }
13513 }
13514 if (BECount == nullptr) {
13515 // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
13516 // given the start, stride and max value for the end bound of the
13517 // loop (RHS), and the fact that IV does not overflow (which is
13518 // checked above).
13519 const SCEV *MaxBECount = computeMaxBECountForLT(
13520 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13521 return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
13522 MaxBECount, false /*MaxOrZero*/, Predicates);
13523 }
13524 } else {
13525 // We use the expression (max(End,Start)-Start)/Stride to describe the
13526 // backedge count, as if the backedge is taken at least once
13527 // max(End,Start) is End and so the result is as above, and if not
13528 // max(End,Start) is Start so we get a backedge count of zero.
13529 auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
13530 assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
13531 assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
13532 assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
13533 // Can we prove (max(RHS,Start) > Start - Stride?
13534 if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
13535 isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
13536 // In this case, we can use a refined formula for computing backedge
13537 // taken count. The general formula remains:
13538 // "End-Start /uceiling Stride" where "End = max(RHS,Start)"
13539 // We want to use the alternate formula:
13540 // "((End - 1) - (Start - Stride)) /u Stride"
13541 // Let's do a quick case analysis to show these are equivalent under
13542 // our precondition that max(RHS,Start) > Start - Stride.
13543 // * For RHS <= Start, the backedge-taken count must be zero.
13544 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13545 // "((Start - 1) - (Start - Stride)) /u Stride" which simplies to
13546 // "Stride - 1 /u Stride" which is indeed zero for all non-zero values
13547 // of Stride. For 0 stride, we've use umin(1,Stride) above,
13548 // reducing this to the stride of 1 case.
13549 // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
13550 // Stride".
13551 // "((End - 1) - (Start - Stride)) /u Stride" reduces to
13552 // "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
13553 // "((RHS - (Start - Stride) - 1) /u Stride".
13554 // Our preconditions trivially imply no overflow in that form.
13555 const SCEV *MinusOne = getMinusOne(Stride->getType());
13556 const SCEV *Numerator =
13557 getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
13558 BECount = getUDivExpr(Numerator, Stride);
13559 }
13560
13561 if (!BECount) {
13562 auto canProveRHSGreaterThanEqualStart = [&]() {
13563 auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
13564 const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
13565 const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);
13566
13567 if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
13568 isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
13569 return true;
13570
13571 // (RHS > Start - 1) implies RHS >= Start.
13572 // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
13573 // "Start - 1" doesn't overflow.
13574 // * For signed comparison, if Start - 1 does overflow, it's equal
13575 // to INT_MAX, and "RHS >s INT_MAX" is trivially false.
13576 // * For unsigned comparison, if Start - 1 does overflow, it's equal
13577 // to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
13578 //
13579 // FIXME: Should isLoopEntryGuardedByCond do this for us?
13580 auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
13581 auto *StartMinusOne =
13582 getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
13583 return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
13584 };
13585
13586 // If we know that RHS >= Start in the context of loop, then we know
13587 // that max(RHS, Start) = RHS at this point.
13588 if (canProveRHSGreaterThanEqualStart()) {
13589 End = RHS;
13590 } else {
13591 // If RHS < Start, the backedge will be taken zero times. So in
13592 // general, we can write the backedge-taken count as:
13593 //
13594 // RHS >= Start ? ceil(RHS - Start) / Stride : 0
13595 //
13596 // We convert it to the following to make it more convenient for SCEV:
13597 //
13598 // ceil(max(RHS, Start) - Start) / Stride
13599 End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
13600
13601 // See what would happen if we assume the backedge is taken. This is
13602 // used to compute MaxBECount.
13603 BECountIfBackedgeTaken =
13604 getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
13605 }
13606
13607 // At this point, we know:
13608 //
13609 // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
13610 // 2. The index variable doesn't overflow.
13611 //
13612 // Therefore, we know N exists such that
13613 // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
13614 // doesn't overflow.
13615 //
13616 // Using this information, try to prove whether the addition in
13617 // "(Start - End) + (Stride - 1)" has unsigned overflow.
13618 const SCEV *One = getOne(Stride->getType());
13619 bool MayAddOverflow = [&] {
13620 if (isKnownToBeAPowerOfTwo(Stride)) {
13621 // Suppose Stride is a power of two, and Start/End are unsigned
13622 // integers. Let UMAX be the largest representable unsigned
13623 // integer.
13624 //
13625 // By the preconditions of this function, we know
13626 // "(Start + Stride * N) >= End", and this doesn't overflow.
13627 // As a formula:
13628 //
13629 // End <= (Start + Stride * N) <= UMAX
13630 //
13631 // Subtracting Start from all the terms:
13632 //
13633 // End - Start <= Stride * N <= UMAX - Start
13634 //
13635 // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
13636 //
13637 // End - Start <= Stride * N <= UMAX
13638 //
13639 // Stride * N is a multiple of Stride. Therefore,
13640 //
13641 // End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
13642 //
13643 // Since Stride is a power of two, UMAX + 1 is divisible by
13644 // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
13645 // write:
13646 //
13647 // End - Start <= Stride * N <= UMAX - Stride - 1
13648 //
13649 // Dropping the middle term:
13650 //
13651 // End - Start <= UMAX - Stride - 1
13652 //
13653 // Adding Stride - 1 to both sides:
13654 //
13655 // (End - Start) + (Stride - 1) <= UMAX
13656 //
13657 // In other words, the addition doesn't have unsigned overflow.
13658 //
13659 // A similar proof works if we treat Start/End as signed values.
13660 // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
13661 // to use signed max instead of unsigned max. Note that we're
13662 // trying to prove a lack of unsigned overflow in either case.
13663 return false;
13664 }
13665 if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
13666 // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
13667 // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
13668 // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
13669 // 1 <s End.
13670 //
13671 // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
13672 // End.
13673 return false;
13674 }
13675 return true;
13676 }();
13677
13678 const SCEV *Delta = getMinusSCEV(End, Start);
13679 if (!MayAddOverflow) {
13680 // floor((D + (S - 1)) / S)
13681 // We prefer this formulation if it's legal because it's fewer
13682 // operations.
13683 BECount =
13684 getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
13685 } else {
13686 BECount = getUDivCeilSCEV(Delta, Stride);
13687 }
13688 }
13689 }
13690
13691 const SCEV *ConstantMaxBECount;
13692 bool MaxOrZero = false;
13693 if (isa<SCEVConstant>(BECount)) {
13694 ConstantMaxBECount = BECount;
13695 } else if (BECountIfBackedgeTaken &&
13696 isa<SCEVConstant>(BECountIfBackedgeTaken)) {
13697 // If we know exactly how many times the backedge will be taken if it's
13698 // taken at least once, then the backedge count will either be that or
13699 // zero.
13700 ConstantMaxBECount = BECountIfBackedgeTaken;
13701 MaxOrZero = true;
13702 } else {
13703 ConstantMaxBECount = computeMaxBECountForLT(
13704 Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
13705 }
13706
13707 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
13708 !isa<SCEVCouldNotCompute>(BECount))
13709 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
13710
13711 const SCEV *SymbolicMaxBECount =
13712 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13713 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
13714 Predicates);
13715}
13716
13717ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
13718 const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
13719 bool ControlsOnlyExit, bool AllowPredicates) {
13721 // We handle only IV > Invariant
13722 if (!isLoopInvariant(RHS, L))
13723 return getCouldNotCompute();
13724
13725 const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
13726 if (!IV && AllowPredicates)
13727 // Try to make this an AddRec using runtime tests, in the first X
13728 // iterations of this loop, where X is the SCEV expression found by the
13729 // algorithm below.
13730 IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
13731
13732 // Avoid weird loops
13733 if (!IV || IV->getLoop() != L || !IV->isAffine())
13734 return getCouldNotCompute();
13735
13736 auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
13737 bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));
13739
13740 const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));
13741
13742 // Avoid negative or zero stride values
13743 if (!isKnownPositive(Stride))
13744 return getCouldNotCompute();
13745
13746 // Avoid proven overflow cases: this will ensure that the backedge taken count
13747 // will not generate any unsigned overflow. Relaxed no-overflow conditions
13748 // exploit NoWrapFlags, allowing to optimize in presence of undefined
13749 // behaviors like the case of C language.
13750 if (!Stride->isOne() && !NoWrap)
13751 if (canIVOverflowOnGT(RHS, Stride, IsSigned))
13752 return getCouldNotCompute();
13753
13754 const SCEV *Start = IV->getStart();
13755 const SCEV *End = RHS;
13756 if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
13757 // If we know that Start >= RHS in the context of loop, then we know that
13758 // min(RHS, Start) = RHS at this point.
13760 L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
13761 End = RHS;
13762 else
13763 End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
13764 }
13765
13766 if (Start->getType()->isPointerTy()) {
13768 if (isa<SCEVCouldNotCompute>(Start))
13769 return Start;
13770 }
13771 if (End->getType()->isPointerTy()) {
13772 End = getLosslessPtrToIntExpr(End);
13773 if (isa<SCEVCouldNotCompute>(End))
13774 return End;
13775 }
13776
13777 // Compute ((Start - End) + (Stride - 1)) / Stride.
13778 // FIXME: This can overflow. Holding off on fixing this for now;
13779 // howManyGreaterThans will hopefully be gone soon.
13780 const SCEV *One = getOne(Stride->getType());
13781 const SCEV *BECount = getUDivExpr(
13782 getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);
13783
13784 APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
13786
13787 APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
13788 : getUnsignedRangeMin(Stride);
13789
13790 unsigned BitWidth = getTypeSizeInBits(LHS->getType());
13791 APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
13792 : APInt::getMinValue(BitWidth) + (MinStride - 1);
13793
13794 // Although End can be a MIN expression we estimate MinEnd considering only
13795 // the case End = RHS. This is safe because in the other case (Start - End)
13796 // is zero, leading to a zero maximum backedge taken count.
13797 APInt MinEnd =
13798 IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
13799 : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);
13800
13801 const SCEV *ConstantMaxBECount =
13802 isa<SCEVConstant>(BECount)
13803 ? BECount
13804 : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
13805 getConstant(MinStride));
13806
13807 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
13808 ConstantMaxBECount = BECount;
13809 const SCEV *SymbolicMaxBECount =
13810 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
13811
13812 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
13813 Predicates);
13814}
13815
13817 ScalarEvolution &SE) const {
13818 if (Range.isFullSet()) // Infinite loop.
13819 return SE.getCouldNotCompute();
13820
13821 // If the start is a non-zero constant, shift the range to simplify things.
13822 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
13823 if (!SC->getValue()->isZero()) {
13825 Operands[0] = SE.getZero(SC->getType());
13826 const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
13828 if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
13829 return ShiftedAddRec->getNumIterationsInRange(
13830 Range.subtract(SC->getAPInt()), SE);
13831 // This is strange and shouldn't happen.
13832 return SE.getCouldNotCompute();
13833 }
13834
13835 // The only time we can solve this is when we have all constant indices.
13836 // Otherwise, we cannot determine the overflow conditions.
13837 if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
13838 return SE.getCouldNotCompute();
13839
13840 // Okay at this point we know that all elements of the chrec are constants and
13841 // that the start element is zero.
13842
13843 // First check to see if the range contains zero. If not, the first
13844 // iteration exits.
13845 unsigned BitWidth = SE.getTypeSizeInBits(getType());
13846 if (!Range.contains(APInt(BitWidth, 0)))
13847 return SE.getZero(getType());
13848
13849 if (isAffine()) {
13850 // If this is an affine expression then we have this situation:
13851 // Solve {0,+,A} in Range === Ax in Range
13852
13853 // We know that zero is in the range. If A is positive then we know that
13854 // the upper value of the range must be the first possible exit value.
13855 // If A is negative then the lower of the range is the last possible loop
13856 // value. Also note that we already checked for a full range.
13857 APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
13858 APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();
13859
13860 // The exit value should be (End+A)/A.
13861 APInt ExitVal = (End + A).udiv(A);
13862 ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
13863
13864 // Evaluate at the exit value. If we really did fall out of the valid
13865 // range, then we computed our trip count, otherwise wrap around or other
13866 // things must have happened.
13867 ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
13868 if (Range.contains(Val->getValue()))
13869 return SE.getCouldNotCompute(); // Something strange happened
13870
13871 // Ensure that the previous value is in the range.
13872 assert(Range.contains(
13874 ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
13875 "Linear scev computation is off in a bad way!");
13876 return SE.getConstant(ExitValue);
13877 }
13878
13879 if (isQuadratic()) {
13880 if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
13881 return SE.getConstant(*S);
13882 }
13883
13884 return SE.getCouldNotCompute();
13885}
13886
13887const SCEVAddRecExpr *
13889 assert(getNumOperands() > 1 && "AddRec with zero step?");
13890 // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
13891 // but in this case we cannot guarantee that the value returned will be an
13892 // AddRec because SCEV does not have a fixed point where it stops
13893 // simplification: it is legal to return ({rec1} + {rec2}). For example, it
13894 // may happen if we reach arithmetic depth limit while simplifying. So we
13895 // construct the returned value explicitly.
13897 // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
13898 // (this + Step) is {A+B,+,B+C,+...,+,N}.
13899 for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
13900 Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
13901 // We know that the last operand is not a constant zero (otherwise it would
13902 // have been popped out earlier). This guarantees us that if the result has
13903 // the same last operand, then it will also not be popped out, meaning that
13904 // the returned value will be an AddRec.
13905 const SCEV *Last = getOperand(getNumOperands() - 1);
13906 assert(!Last->isZero() && "Recurrency with zero step?");
13907 Ops.push_back(Last);
13910}
13911
13912// Return true when S contains at least an undef value.
13914 return SCEVExprContains(
13915 S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
13916}
13917
13918// Return true when S contains a value that is a nullptr.
13920 return SCEVExprContains(S, [](const SCEV *S) {
13921 if (const auto *SU = dyn_cast<SCEVUnknown>(S))
13922 return SU->getValue() == nullptr;
13923 return false;
13924 });
13925}
13926
13927/// Return the size of an element read or written by Inst.
13929 Type *Ty;
13930 if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
13931 Ty = Store->getValueOperand()->getType();
13932 else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
13933 Ty = Load->getType();
13934 else
13935 return nullptr;
13936
13938 return getSizeOfExpr(ETy, Ty);
13939}
13940
13941//===----------------------------------------------------------------------===//
13942// SCEVCallbackVH Class Implementation
13943//===----------------------------------------------------------------------===//
13944
13946 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13947 if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
13948 SE->ConstantEvolutionLoopExitValue.erase(PN);
13949 SE->eraseValueFromMap(getValPtr());
13950 // this now dangles!
13951}
13952
13953void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13954 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13955
13956 // Forget all the expressions associated with users of the old value,
13957 // so that future queries will recompute the expressions using the new
13958 // value.
13959 SE->forgetValue(getValPtr());
13960 // this now dangles!
13961}
13962
13963ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
13964 : CallbackVH(V), SE(se) {}
13965
13966//===----------------------------------------------------------------------===//
13967// ScalarEvolution Class Implementation
13968//===----------------------------------------------------------------------===//
13969
13972 LoopInfo &LI)
13973 : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
13974 CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
13975 LoopDispositions(64), BlockDispositions(64) {
13976 // To use guards for proving predicates, we need to scan every instruction in
13977 // relevant basic blocks, and not just terminators. Doing this is a waste of
13978 // time if the IR does not actually contain any calls to
13979 // @llvm.experimental.guard, so do a quick check and remember this beforehand.
13980 //
13981 // This pessimizes the case where a pass that preserves ScalarEvolution wants
13982 // to _add_ guards to the module when there weren't any before, and wants
13983 // ScalarEvolution to optimize based on those guards. For now we prefer to be
13984 // efficient in lieu of being smart in that rather obscure case.
13985
13986 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
13987 F.getParent(), Intrinsic::experimental_guard);
13988 HasGuards = GuardDecl && !GuardDecl->use_empty();
13989}
13990
13992 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13993 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13994 ValueExprMap(std::move(Arg.ValueExprMap)),
13995 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13996 PendingMerges(std::move(Arg.PendingMerges)),
13997 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13998 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13999 PredicatedBackedgeTakenCounts(
14000 std::move(Arg.PredicatedBackedgeTakenCounts)),
14001 BECountUsers(std::move(Arg.BECountUsers)),
14002 ConstantEvolutionLoopExitValue(
14003 std::move(Arg.ConstantEvolutionLoopExitValue)),
14004 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14005 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14006 LoopDispositions(std::move(Arg.LoopDispositions)),
14007 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14008 BlockDispositions(std::move(Arg.BlockDispositions)),
14009 SCEVUsers(std::move(Arg.SCEVUsers)),
14010 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14011 SignedRanges(std::move(Arg.SignedRanges)),
14012 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14013 UniquePreds(std::move(Arg.UniquePreds)),
14014 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14015 LoopUsers(std::move(Arg.LoopUsers)),
14016 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14017 FirstUnknown(Arg.FirstUnknown) {
14018 Arg.FirstUnknown = nullptr;
14019}
14020
14022 // Iterate through all the SCEVUnknown instances and call their
14023 // destructors, so that they release their references to their values.
14024 for (SCEVUnknown *U = FirstUnknown; U;) {
14025 SCEVUnknown *Tmp = U;
14026 U = U->Next;
14027 Tmp->~SCEVUnknown();
14028 }
14029 FirstUnknown = nullptr;
14030
14031 ExprValueMap.clear();
14032 ValueExprMap.clear();
14033 HasRecMap.clear();
14034 BackedgeTakenCounts.clear();
14035 PredicatedBackedgeTakenCounts.clear();
14036
14037 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14038 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14039 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14040 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14041}
14042
14046
14047/// When printing a top-level SCEV for trip counts, it's helpful to include
14048/// a type for constants which are otherwise hard to disambiguate.
14049static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14050 if (isa<SCEVConstant>(S))
14051 OS << *S->getType() << " ";
14052 OS << *S;
14053}
14054
14056 const Loop *L) {
14057 // Print all inner loops first
14058 for (Loop *I : *L)
14059 PrintLoopInfo(OS, SE, I);
14060
14061 OS << "Loop ";
14062 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14063 OS << ": ";
14064
14065 SmallVector<BasicBlock *, 8> ExitingBlocks;
14066 L->getExitingBlocks(ExitingBlocks);
14067 if (ExitingBlocks.size() != 1)
14068 OS << "<multiple exits> ";
14069
14070 auto *BTC = SE->getBackedgeTakenCount(L);
14071 if (!isa<SCEVCouldNotCompute>(BTC)) {
14072 OS << "backedge-taken count is ";
14073 PrintSCEVWithTypeHint(OS, BTC);
14074 } else
14075 OS << "Unpredictable backedge-taken count.";
14076 OS << "\n";
14077
14078 if (ExitingBlocks.size() > 1)
14079 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14080 OS << " exit count for " << ExitingBlock->getName() << ": ";
14081 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14082 PrintSCEVWithTypeHint(OS, EC);
14083 if (isa<SCEVCouldNotCompute>(EC)) {
14084 // Retry with predicates.
14086 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14087 if (!isa<SCEVCouldNotCompute>(EC)) {
14088 OS << "\n predicated exit count for " << ExitingBlock->getName()
14089 << ": ";
14090 PrintSCEVWithTypeHint(OS, EC);
14091 OS << "\n Predicates:\n";
14092 for (const auto *P : Predicates)
14093 P->print(OS, 4);
14094 }
14095 }
14096 OS << "\n";
14097 }
14098
14099 OS << "Loop ";
14100 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14101 OS << ": ";
14102
14103 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14104 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14105 OS << "constant max backedge-taken count is ";
14106 PrintSCEVWithTypeHint(OS, ConstantBTC);
14108 OS << ", actual taken count either this or zero.";
14109 } else {
14110 OS << "Unpredictable constant max backedge-taken count. ";
14111 }
14112
14113 OS << "\n"
14114 "Loop ";
14115 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14116 OS << ": ";
14117
14118 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14119 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14120 OS << "symbolic max backedge-taken count is ";
14121 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14123 OS << ", actual taken count either this or zero.";
14124 } else {
14125 OS << "Unpredictable symbolic max backedge-taken count. ";
14126 }
14127 OS << "\n";
14128
14129 if (ExitingBlocks.size() > 1)
14130 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14131 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14132 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14134 PrintSCEVWithTypeHint(OS, ExitBTC);
14135 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14136 // Retry with predicates.
14138 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14140 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14141 OS << "\n predicated symbolic max exit count for "
14142 << ExitingBlock->getName() << ": ";
14143 PrintSCEVWithTypeHint(OS, ExitBTC);
14144 OS << "\n Predicates:\n";
14145 for (const auto *P : Predicates)
14146 P->print(OS, 4);
14147 }
14148 }
14149 OS << "\n";
14150 }
14151
14153 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14154 if (PBT != BTC) {
14155 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14156 OS << "Loop ";
14157 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14158 OS << ": ";
14159 if (!isa<SCEVCouldNotCompute>(PBT)) {
14160 OS << "Predicated backedge-taken count is ";
14161 PrintSCEVWithTypeHint(OS, PBT);
14162 } else
14163 OS << "Unpredictable predicated backedge-taken count.";
14164 OS << "\n";
14165 OS << " Predicates:\n";
14166 for (const auto *P : Preds)
14167 P->print(OS, 4);
14168 }
14169 Preds.clear();
14170
14171 auto *PredConstantMax =
14173 if (PredConstantMax != ConstantBTC) {
14174 assert(!Preds.empty() &&
14175 "different predicated constant max BTC but no predicates");
14176 OS << "Loop ";
14177 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14178 OS << ": ";
14179 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14180 OS << "Predicated constant max backedge-taken count is ";
14181 PrintSCEVWithTypeHint(OS, PredConstantMax);
14182 } else
14183 OS << "Unpredictable predicated constant max backedge-taken count.";
14184 OS << "\n";
14185 OS << " Predicates:\n";
14186 for (const auto *P : Preds)
14187 P->print(OS, 4);
14188 }
14189 Preds.clear();
14190
14191 auto *PredSymbolicMax =
14193 if (SymbolicBTC != PredSymbolicMax) {
14194 assert(!Preds.empty() &&
14195 "Different predicated symbolic max BTC, but no predicates");
14196 OS << "Loop ";
14197 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14198 OS << ": ";
14199 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14200 OS << "Predicated symbolic max backedge-taken count is ";
14201 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14202 } else
14203 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14204 OS << "\n";
14205 OS << " Predicates:\n";
14206 for (const auto *P : Preds)
14207 P->print(OS, 4);
14208 }
14209
14211 OS << "Loop ";
14212 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14213 OS << ": ";
14214 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14215 }
14216}
14217
14218namespace llvm {
14219// Note: these overloaded operators need to be in the llvm namespace for them
14220// to be resolved correctly. If we put them outside the llvm namespace, the
14221//
14222// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14223//
14224// code below "breaks" and start printing raw enum values as opposed to the
14225// string values.
14228 switch (LD) {
14230 OS << "Variant";
14231 break;
14233 OS << "Invariant";
14234 break;
14236 OS << "Uniform";
14237 break;
14239 OS << "Computable";
14240 break;
14241 }
14242 return OS;
14243}
14244
14247 switch (BD) {
14249 OS << "DoesNotDominate";
14250 break;
14252 OS << "Dominates";
14253 break;
14255 OS << "ProperlyDominates";
14256 break;
14257 }
14258 return OS;
14259}
14260} // namespace llvm
14261
14263 // ScalarEvolution's implementation of the print method is to print
14264 // out SCEV values of all instructions that are interesting. Doing
14265 // this potentially causes it to create new SCEV objects though,
14266 // which technically conflicts with the const qualifier. This isn't
14267 // observable from outside the class though, so casting away the
14268 // const isn't dangerous.
14269 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14270
14271 if (ClassifyExpressions) {
14272 OS << "Classifying expressions for: ";
14273 F.printAsOperand(OS, /*PrintType=*/false);
14274 OS << "\n";
14275 for (Instruction &I : instructions(F))
14276 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14277 OS << I << '\n';
14278 OS << " --> ";
14279 const SCEV *SV = SE.getSCEV(&I);
14280 SV->print(OS);
14281 if (!isa<SCEVCouldNotCompute>(SV)) {
14282 OS << " U: ";
14283 SE.getUnsignedRange(SV).print(OS);
14284 OS << " S: ";
14285 SE.getSignedRange(SV).print(OS);
14286 }
14287
14288 const Loop *L = LI.getLoopFor(I.getParent());
14289
14290 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14291 if (AtUse != SV) {
14292 OS << " --> ";
14293 AtUse->print(OS);
14294 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14295 OS << " U: ";
14296 SE.getUnsignedRange(AtUse).print(OS);
14297 OS << " S: ";
14298 SE.getSignedRange(AtUse).print(OS);
14299 }
14300 }
14301
14302 if (L) {
14303 OS << "\t\t" "Exits: ";
14304 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14305 if (!SE.isLoopInvariant(ExitValue, L)) {
14306 OS << "<<Unknown>>";
14307 } else {
14308 OS << *ExitValue;
14309 }
14310
14311 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14312 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14313 OS << LS;
14314 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14315 OS << ": " << SE.getLoopDisposition(SV, Iter);
14316 }
14317
14318 for (const auto *InnerL : depth_first(L)) {
14319 if (InnerL == L)
14320 continue;
14321 OS << LS;
14322 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14323 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14324 }
14325
14326 OS << " }";
14327 }
14328
14329 OS << "\n";
14330 }
14331 }
14332
14333 OS << "Determining loop execution counts for: ";
14334 F.printAsOperand(OS, /*PrintType=*/false);
14335 OS << "\n";
14336 for (Loop *I : LI)
14337 PrintLoopInfo(OS, &SE, I);
14338}
14339
14342 auto &Values = LoopDispositions[S];
14343 for (auto &V : Values) {
14344 if (V.getPointer() == L)
14345 return V.getInt();
14346 }
14347 Values.emplace_back(L, LoopVariant);
14348 LoopDisposition D = computeLoopDisposition(S, L);
14349 auto &Values2 = LoopDispositions[S];
14350 for (auto &V : llvm::reverse(Values2)) {
14351 if (V.getPointer() == L) {
14352 V.setInt(D);
14353 break;
14354 }
14355 }
14356 return D;
14357}
14358
14360ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14361 switch (S->getSCEVType()) {
14362 case scConstant:
14363 case scVScale:
14364 return LoopInvariant;
14365 case scAddRecExpr: {
14366 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14367
14368 // If L is the addrec's loop, it's computable.
14369 if (AR->getLoop() == L)
14370 return LoopComputable;
14371
14372 // Add recurrences are never invariant in the function-body (null loop).
14373 if (!L)
14374 return LoopVariant;
14375
14376 // Everything that is not defined at loop entry is variant.
14377 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader())) {
14378 if (L->contains(AR->getLoop()) &&
14379 llvm::all_of(AR->operands(),
14380 [&](const SCEV *Op) { return isLoopUniform(Op, L); }))
14381 return LoopUniform;
14382
14383 return LoopVariant;
14384 }
14385 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14386 " dominate the contained loop's header?");
14387
14388 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14389 if (AR->getLoop()->contains(L))
14390 return LoopInvariant;
14391
14392 // This recurrence is variant w.r.t. L if any of its operands
14393 // are variant.
14394 for (SCEVUse Op : AR->operands())
14395 if (!isLoopInvariant(Op, L))
14396 return LoopVariant;
14397
14398 // Otherwise it's loop-invariant.
14399 return LoopInvariant;
14400 }
14401 case scTruncate:
14402 case scZeroExtend:
14403 case scSignExtend:
14404 case scPtrToAddr:
14405 case scPtrToInt:
14406 case scAddExpr:
14407 case scMulExpr:
14408 case scUDivExpr:
14409 case scUMaxExpr:
14410 case scSMaxExpr:
14411 case scUMinExpr:
14412 case scSMinExpr:
14413 case scSequentialUMinExpr: {
14414 bool HasVarying = false;
14415 bool HasUniform = false;
14416 for (SCEVUse Op : S->operands()) {
14418 if (D == LoopVariant)
14419 return LoopVariant;
14420 if (D == LoopComputable)
14421 HasVarying = true;
14422 if (D == LoopUniform)
14423 HasUniform = true;
14424 }
14425 return HasVarying ? (HasUniform ? LoopVariant : LoopComputable)
14426 : (HasUniform ? LoopUniform : LoopInvariant);
14427 }
14428 case scUnknown:
14429 // All non-instruction values are loop invariant. All instructions are loop
14430 // invariant if they are not contained in the specified loop.
14431 // Instructions are never considered invariant in the function body
14432 // (null loop) because they are defined within the "loop".
14434 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14435 return LoopInvariant;
14436 case scCouldNotCompute:
14437 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14438 }
14439 llvm_unreachable("Unknown SCEV kind!");
14440}
14441
14442bool ScalarEvolution::isLoopUniform(const SCEV *S, const Loop *L) {
14444 return D == LoopUniform || D == LoopInvariant;
14445}
14446
14448 return getLoopDisposition(S, L) == LoopInvariant;
14449}
14450
14452 return getLoopDisposition(S, L) == LoopComputable;
14453}
14454
14457 auto &Values = BlockDispositions[S];
14458 for (auto &V : Values) {
14459 if (V.getPointer() == BB)
14460 return V.getInt();
14461 }
14462 Values.emplace_back(BB, DoesNotDominateBlock);
14463 BlockDisposition D = computeBlockDisposition(S, BB);
14464 auto &Values2 = BlockDispositions[S];
14465 for (auto &V : llvm::reverse(Values2)) {
14466 if (V.getPointer() == BB) {
14467 V.setInt(D);
14468 break;
14469 }
14470 }
14471 return D;
14472}
14473
14475ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14476 switch (S->getSCEVType()) {
14477 case scConstant:
14478 case scVScale:
14480 case scAddRecExpr: {
14481 // This uses a "dominates" query instead of "properly dominates" query
14482 // to test for proper dominance too, because the instruction which
14483 // produces the addrec's value is a PHI, and a PHI effectively properly
14484 // dominates its entire containing block.
14485 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14486 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14487 return DoesNotDominateBlock;
14488
14489 // Fall through into SCEVNAryExpr handling.
14490 [[fallthrough]];
14491 }
14492 case scTruncate:
14493 case scZeroExtend:
14494 case scSignExtend:
14495 case scPtrToAddr:
14496 case scPtrToInt:
14497 case scAddExpr:
14498 case scMulExpr:
14499 case scUDivExpr:
14500 case scUMaxExpr:
14501 case scSMaxExpr:
14502 case scUMinExpr:
14503 case scSMinExpr:
14504 case scSequentialUMinExpr: {
14505 bool Proper = true;
14506 for (const SCEV *NAryOp : S->operands()) {
14508 if (D == DoesNotDominateBlock)
14509 return DoesNotDominateBlock;
14510 if (D == DominatesBlock)
14511 Proper = false;
14512 }
14513 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14514 }
14515 case scUnknown:
14516 if (Instruction *I =
14518 if (I->getParent() == BB)
14519 return DominatesBlock;
14520 if (DT.properlyDominates(I->getParent(), BB))
14522 return DoesNotDominateBlock;
14523 }
14525 case scCouldNotCompute:
14526 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14527 }
14528 llvm_unreachable("Unknown SCEV kind!");
14529}
14530
14531bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14532 return getBlockDisposition(S, BB) >= DominatesBlock;
14533}
14534
14537}
14538
14539bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14540 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14541}
14542
14543void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14544 bool Predicated) {
14545 auto &BECounts =
14546 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14547 auto It = BECounts.find(L);
14548 if (It != BECounts.end()) {
14549 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14550 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14551 if (!isa<SCEVConstant>(S)) {
14552 auto UserIt = BECountUsers.find(S);
14553 assert(UserIt != BECountUsers.end());
14554 UserIt->second.erase({L, Predicated});
14555 }
14556 }
14557 }
14558 BECounts.erase(It);
14559 }
14560}
14561
14562void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14563 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14564 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14565
14566 while (!Worklist.empty()) {
14567 const SCEV *Curr = Worklist.pop_back_val();
14568 auto Users = SCEVUsers.find(Curr);
14569 if (Users != SCEVUsers.end())
14570 for (const auto *User : Users->second)
14571 if (ToForget.insert(User).second)
14572 Worklist.push_back(User);
14573 }
14574
14575 for (const auto *S : ToForget)
14576 forgetMemoizedResultsImpl(S);
14577
14578 PredicatedSCEVRewrites.remove_if(
14579 [&](const auto &Entry) { return ToForget.count(Entry.first.first); });
14580}
14581
14582void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14583 LoopDispositions.erase(S);
14584 BlockDispositions.erase(S);
14585 UnsignedRanges.erase(S);
14586 SignedRanges.erase(S);
14587 HasRecMap.erase(S);
14588 ConstantMultipleCache.erase(S);
14589
14590 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14591 UnsignedWrapViaInductionTried.erase(AR);
14592 SignedWrapViaInductionTried.erase(AR);
14593 }
14594
14595 auto ExprIt = ExprValueMap.find(S);
14596 if (ExprIt != ExprValueMap.end()) {
14597 for (Value *V : ExprIt->second) {
14598 auto ValueIt = ValueExprMap.find_as(V);
14599 if (ValueIt != ValueExprMap.end())
14600 ValueExprMap.erase(ValueIt);
14601 }
14602 ExprValueMap.erase(ExprIt);
14603 }
14604
14605 auto ScopeIt = ValuesAtScopes.find(S);
14606 if (ScopeIt != ValuesAtScopes.end()) {
14607 for (const auto &Pair : ScopeIt->second)
14608 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14609 llvm::erase(ValuesAtScopesUsers[Pair.second],
14610 std::make_pair(Pair.first, S));
14611 ValuesAtScopes.erase(ScopeIt);
14612 }
14613
14614 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14615 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14616 for (const auto &Pair : ScopeUserIt->second)
14617 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14618 ValuesAtScopesUsers.erase(ScopeUserIt);
14619 }
14620
14621 auto BEUsersIt = BECountUsers.find(S);
14622 if (BEUsersIt != BECountUsers.end()) {
14623 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14624 auto Copy = BEUsersIt->second;
14625 for (const auto &Pair : Copy)
14626 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14627 BECountUsers.erase(BEUsersIt);
14628 }
14629
14630 auto FoldUser = FoldCacheUser.find(S);
14631 if (FoldUser != FoldCacheUser.end())
14632 for (auto &KV : FoldUser->second)
14633 FoldCache.erase(KV);
14634 FoldCacheUser.erase(S);
14635}
14636
14637void
14638ScalarEvolution::getUsedLoops(const SCEV *S,
14639 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14640 struct FindUsedLoops {
14641 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14642 : LoopsUsed(LoopsUsed) {}
14643 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14644 bool follow(const SCEV *S) {
14645 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14646 LoopsUsed.insert(AR->getLoop());
14647 return true;
14648 }
14649
14650 bool isDone() const { return false; }
14651 };
14652
14653 FindUsedLoops F(LoopsUsed);
14654 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14655}
14656
14657void ScalarEvolution::getReachableBlocks(
14660 Worklist.push_back(&F.getEntryBlock());
14661 while (!Worklist.empty()) {
14662 BasicBlock *BB = Worklist.pop_back_val();
14663 if (!Reachable.insert(BB).second)
14664 continue;
14665
14666 Value *Cond;
14667 BasicBlock *TrueBB, *FalseBB;
14668 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14669 m_BasicBlock(FalseBB)))) {
14670 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14671 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14672 continue;
14673 }
14674
14675 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14676 const SCEV *L = getSCEV(Cmp->getOperand(0));
14677 const SCEV *R = getSCEV(Cmp->getOperand(1));
14678 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14679 Worklist.push_back(TrueBB);
14680 continue;
14681 }
14682 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14683 R)) {
14684 Worklist.push_back(FalseBB);
14685 continue;
14686 }
14687 }
14688 }
14689
14690 append_range(Worklist, successors(BB));
14691 }
14692}
14693
14695 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14696 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14697
14698 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14699
14700 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14701 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14702 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14703
14704 const SCEV *visitConstant(const SCEVConstant *Constant) {
14705 return SE.getConstant(Constant->getAPInt());
14706 }
14707
14708 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14709 return SE.getUnknown(Expr->getValue());
14710 }
14711
14712 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14713 return SE.getCouldNotCompute();
14714 }
14715 };
14716
14717 SCEVMapper SCM(SE2);
14718 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14719 SE2.getReachableBlocks(ReachableBlocks, F);
14720
14721 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14722 if (containsUndefs(Old) || containsUndefs(New)) {
14723 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14724 // not propagate undef aggressively). This means we can (and do) fail
14725 // verification in cases where a transform makes a value go from "undef"
14726 // to "undef+1" (say). The transform is fine, since in both cases the
14727 // result is "undef", but SCEV thinks the value increased by 1.
14728 return nullptr;
14729 }
14730
14731 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14732 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14733 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14734 return nullptr;
14735
14736 return Delta;
14737 };
14738
14739 while (!LoopStack.empty()) {
14740 auto *L = LoopStack.pop_back_val();
14741 llvm::append_range(LoopStack, *L);
14742
14743 // Only verify BECounts in reachable loops. For an unreachable loop,
14744 // any BECount is legal.
14745 if (!ReachableBlocks.contains(L->getHeader()))
14746 continue;
14747
14748 // Only verify cached BECounts. Computing new BECounts may change the
14749 // results of subsequent SCEV uses.
14750 auto It = BackedgeTakenCounts.find(L);
14751 if (It == BackedgeTakenCounts.end())
14752 continue;
14753
14754 auto *CurBECount =
14755 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14756 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14757
14758 if (CurBECount == SE2.getCouldNotCompute() ||
14759 NewBECount == SE2.getCouldNotCompute()) {
14760 // NB! This situation is legal, but is very suspicious -- whatever pass
14761 // change the loop to make a trip count go from could not compute to
14762 // computable or vice-versa *should have* invalidated SCEV. However, we
14763 // choose not to assert here (for now) since we don't want false
14764 // positives.
14765 continue;
14766 }
14767
14768 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14769 SE.getTypeSizeInBits(NewBECount->getType()))
14770 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14771 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14772 SE.getTypeSizeInBits(NewBECount->getType()))
14773 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14774
14775 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14776 if (Delta && !Delta->isZero()) {
14777 dbgs() << "Trip Count for " << *L << " Changed!\n";
14778 dbgs() << "Old: " << *CurBECount << "\n";
14779 dbgs() << "New: " << *NewBECount << "\n";
14780 dbgs() << "Delta: " << *Delta << "\n";
14781 std::abort();
14782 }
14783 }
14784
14785 // Collect all valid loops currently in LoopInfo.
14786 SmallPtrSet<Loop *, 32> ValidLoops;
14787 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14788 while (!Worklist.empty()) {
14789 Loop *L = Worklist.pop_back_val();
14790 if (ValidLoops.insert(L).second)
14791 Worklist.append(L->begin(), L->end());
14792 }
14793 for (const auto &KV : ValueExprMap) {
14794#ifndef NDEBUG
14795 // Check for SCEV expressions referencing invalid/deleted loops.
14796 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14797 assert(ValidLoops.contains(AR->getLoop()) &&
14798 "AddRec references invalid loop");
14799 }
14800#endif
14801
14802 // Check that the value is also part of the reverse map.
14803 auto It = ExprValueMap.find(KV.second);
14804 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14805 dbgs() << "Value " << *KV.first
14806 << " is in ValueExprMap but not in ExprValueMap\n";
14807 std::abort();
14808 }
14809
14810 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14811 if (!ReachableBlocks.contains(I->getParent()))
14812 continue;
14813 const SCEV *OldSCEV = SCM.visit(KV.second);
14814 const SCEV *NewSCEV = SE2.getSCEV(I);
14815 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14816 if (Delta && !Delta->isZero()) {
14817 dbgs() << "SCEV for value " << *I << " changed!\n"
14818 << "Old: " << *OldSCEV << "\n"
14819 << "New: " << *NewSCEV << "\n"
14820 << "Delta: " << *Delta << "\n";
14821 std::abort();
14822 }
14823 }
14824 }
14825
14826 for (const auto &KV : ExprValueMap) {
14827 for (Value *V : KV.second) {
14828 const SCEV *S = ValueExprMap.lookup(V);
14829 if (!S) {
14830 dbgs() << "Value " << *V
14831 << " is in ExprValueMap but not in ValueExprMap\n";
14832 std::abort();
14833 }
14834 if (S != KV.first) {
14835 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14836 << *KV.first << "\n";
14837 std::abort();
14838 }
14839 }
14840 }
14841
14842 // Verify integrity of SCEV users.
14843 for (const auto &S : UniqueSCEVs) {
14844 for (SCEVUse Op : S.operands()) {
14845 // We do not store dependencies of constants.
14846 if (isa<SCEVConstant>(Op))
14847 continue;
14848 auto It = SCEVUsers.find(Op);
14849 if (It != SCEVUsers.end() && It->second.count(&S))
14850 continue;
14851 dbgs() << "Use of operand " << *Op << " by user " << S
14852 << " is not being tracked!\n";
14853 std::abort();
14854 }
14855 }
14856
14857 // Verify integrity of ValuesAtScopes users.
14858 for (const auto &ValueAndVec : ValuesAtScopes) {
14859 const SCEV *Value = ValueAndVec.first;
14860 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14861 const Loop *L = LoopAndValueAtScope.first;
14862 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14863 if (!isa<SCEVConstant>(ValueAtScope)) {
14864 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14865 if (It != ValuesAtScopesUsers.end() &&
14866 is_contained(It->second, std::make_pair(L, Value)))
14867 continue;
14868 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14869 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14870 std::abort();
14871 }
14872 }
14873 }
14874
14875 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14876 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14877 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14878 const Loop *L = LoopAndValue.first;
14879 const SCEV *Value = LoopAndValue.second;
14881 auto It = ValuesAtScopes.find(Value);
14882 if (It != ValuesAtScopes.end() &&
14883 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14884 continue;
14885 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14886 << *ValueAtScope << " missing in ValuesAtScopes\n";
14887 std::abort();
14888 }
14889 }
14890
14891 // Verify integrity of BECountUsers.
14892 auto VerifyBECountUsers = [&](bool Predicated) {
14893 auto &BECounts =
14894 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14895 for (const auto &LoopAndBEInfo : BECounts) {
14896 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14897 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14898 if (!isa<SCEVConstant>(S)) {
14899 auto UserIt = BECountUsers.find(S);
14900 if (UserIt != BECountUsers.end() &&
14901 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14902 continue;
14903 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14904 << " missing from BECountUsers\n";
14905 std::abort();
14906 }
14907 }
14908 }
14909 }
14910 };
14911 VerifyBECountUsers(/* Predicated */ false);
14912 VerifyBECountUsers(/* Predicated */ true);
14913
14914 // Verify intergity of loop disposition cache.
14915 for (auto &[S, Values] : LoopDispositions) {
14916 for (auto [Loop, CachedDisposition] : Values) {
14917 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14918 if (CachedDisposition != RecomputedDisposition) {
14919 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14920 << " is incorrect: cached " << CachedDisposition << ", actual "
14921 << RecomputedDisposition << "\n";
14922 std::abort();
14923 }
14924 }
14925 }
14926
14927 // Verify integrity of the block disposition cache.
14928 for (auto &[S, Values] : BlockDispositions) {
14929 for (auto [BB, CachedDisposition] : Values) {
14930 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14931 if (CachedDisposition != RecomputedDisposition) {
14932 dbgs() << "Cached disposition of " << *S << " for block %"
14933 << BB->getName() << " is incorrect: cached " << CachedDisposition
14934 << ", actual " << RecomputedDisposition << "\n";
14935 std::abort();
14936 }
14937 }
14938 }
14939
14940 // Verify FoldCache/FoldCacheUser caches.
14941 for (auto [FoldID, Expr] : FoldCache) {
14942 auto I = FoldCacheUser.find(Expr);
14943 if (I == FoldCacheUser.end()) {
14944 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14945 << "!\n";
14946 std::abort();
14947 }
14948 if (!is_contained(I->second, FoldID)) {
14949 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14950 std::abort();
14951 }
14952 }
14953 for (auto [Expr, IDs] : FoldCacheUser) {
14954 for (auto &FoldID : IDs) {
14955 const SCEV *S = FoldCache.lookup(FoldID);
14956 if (!S) {
14957 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14958 << "!\n";
14959 std::abort();
14960 }
14961 if (S != Expr) {
14962 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14963 << " != " << *Expr << "!\n";
14964 std::abort();
14965 }
14966 }
14967 }
14968
14969 // Verify that ConstantMultipleCache computations are correct. We check that
14970 // cached multiples and recomputed multiples are multiples of each other to
14971 // verify correctness. It is possible that a recomputed multiple is different
14972 // from the cached multiple due to strengthened no wrap flags or changes in
14973 // KnownBits computations.
14974 for (auto [S, Multiple] : ConstantMultipleCache) {
14975 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14976 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14977 Multiple.urem(RecomputedMultiple) != 0 &&
14978 RecomputedMultiple.urem(Multiple) != 0)) {
14979 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14980 << *S << " : Computed " << RecomputedMultiple
14981 << " but cache contains " << Multiple << "!\n";
14982 std::abort();
14983 }
14984 }
14985}
14986
14988 Function &F, const PreservedAnalyses &PA,
14989 FunctionAnalysisManager::Invalidator &Inv) {
14990 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14991 // of its dependencies is invalidated.
14992 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14993 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14994 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14995 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14996 Inv.invalidate<LoopAnalysis>(F, PA);
14997}
14998
// Out-of-line definition of the analysis key for ScalarEvolutionAnalysis, the
// new-pass-manager analysis whose results are requested via
// AM.getResult<...>. NOTE(review): presumably only this object's address is
// used to identify the analysis; confirm against AnalysisInfoMixin docs.
AnalysisKey ScalarEvolutionAnalysis::Key;
15000
15003 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
15004 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15005 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15006 auto &LI = AM.getResult<LoopAnalysis>(F);
15007 return ScalarEvolution(F, TLI, AC, DT, LI);
15008}
15009
15015
15018 // For compatibility with opt's -analyze feature under legacy pass manager
15019 // which was not ported to NPM. This keeps tests using
15020 // update_analyze_test_checks.py working.
15021 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15022 << F.getName() << "':\n";
15024 return PreservedAnalyses::all();
15025}
15026
15028 "Scalar Evolution Analysis", false, true)
15034 "Scalar Evolution Analysis", false, true)
15035
15037
15039
15041 SE.reset(new ScalarEvolution(
15043 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15045 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15046 return false;
15047}
15048
15050
15052 SE->print(OS);
15053}
15054
15056 if (!VerifySCEV)
15057 return;
15058
15059 SE->verify();
15060}
15061
15069
15071 const SCEV *RHS) {
15072 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15073}
15074
15075const SCEVPredicate *
15077 const SCEV *LHS, const SCEV *RHS) {
15079 assert(LHS->getType() == RHS->getType() &&
15080 "Type mismatch between LHS and RHS");
15081 // Unique this node based on the arguments
15082 ID.AddInteger(SCEVPredicate::P_Compare);
15083 ID.AddInteger(Pred);
15084 ID.AddPointer(LHS);
15085 ID.AddPointer(RHS);
15086 void *IP = nullptr;
15087 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15088 return S;
15089 SCEVComparePredicate *Eq = new (SCEVAllocator)
15090 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15091 UniquePreds.InsertNode(Eq, IP);
15092 return Eq;
15093}
15094
15096 const SCEVAddRecExpr *AR,
15099 // Unique this node based on the arguments
15100 ID.AddInteger(SCEVPredicate::P_Wrap);
15101 ID.AddPointer(AR);
15102 ID.AddInteger(AddedFlags);
15103 void *IP = nullptr;
15104 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15105 return S;
15106 auto *OF = new (SCEVAllocator)
15107 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15108 UniquePreds.InsertNode(OF, IP);
15109 return OF;
15110}
15111
15112namespace {
15113
15114class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15115public:
15116
15117 /// Rewrites \p S in the context of a loop L and the SCEV predication
15118 /// infrastructure.
15119 ///
15120 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15121 /// equivalences present in \p Pred.
15122 ///
15123 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15124 /// \p NewPreds such that the result will be an AddRecExpr.
15125 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15127 const SCEVPredicate *Pred) {
15128 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15129 return Rewriter.visit(S);
15130 }
15131
15132 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15133 if (Pred) {
15134 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15135 for (const auto *Pred : U->getPredicates())
15136 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15137 if (IPred->getLHS() == Expr &&
15138 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15139 return IPred->getRHS();
15140 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15141 if (IPred->getLHS() == Expr &&
15142 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15143 return IPred->getRHS();
15144 }
15145 }
15146 return convertToAddRecWithPreds(Expr);
15147 }
15148
15149 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15150 const SCEV *Operand = visit(Expr->getOperand());
15151 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15152 if (AR && AR->getLoop() == L && AR->isAffine()) {
15153 // This couldn't be folded because the operand didn't have the nuw
15154 // flag. Add the nusw flag as an assumption that we could make.
15155 const SCEV *Step = AR->getStepRecurrence(SE);
15156 Type *Ty = Expr->getType();
15157 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15158 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15159 SE.getSignExtendExpr(Step, Ty), L,
15160 AR->getNoWrapFlags());
15161 }
15162 return SE.getZeroExtendExpr(Operand, Expr->getType());
15163 }
15164
15165 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15166 const SCEV *Operand = visit(Expr->getOperand());
15167 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15168 if (AR && AR->getLoop() == L && AR->isAffine()) {
15169 // This couldn't be folded because the operand didn't have the nsw
15170 // flag. Add the nssw flag as an assumption that we could make.
15171 const SCEV *Step = AR->getStepRecurrence(SE);
15172 Type *Ty = Expr->getType();
15173 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15174 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15175 SE.getSignExtendExpr(Step, Ty), L,
15176 AR->getNoWrapFlags());
15177 }
15178 return SE.getSignExtendExpr(Operand, Expr->getType());
15179 }
15180
15181private:
15182 explicit SCEVPredicateRewriter(
15183 const Loop *L, ScalarEvolution &SE,
15184 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15185 const SCEVPredicate *Pred)
15186 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15187
15188 bool addOverflowAssumption(const SCEVPredicate *P) {
15189 if (!NewPreds) {
15190 // Check if we've already made this assumption.
15191 return Pred && Pred->implies(P, SE);
15192 }
15193 NewPreds->push_back(P);
15194 return true;
15195 }
15196
15197 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15199 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15200 return addOverflowAssumption(A);
15201 }
15202
15203 // If \p Expr represents a PHINode, we try to see if it can be represented
15204 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15205 // to add this predicate as a runtime overflow check, we return the AddRec.
15206 // If \p Expr does not meet these conditions (is not a PHI node, or we
15207 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15208 // return \p Expr.
15209 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15210 if (!isa<PHINode>(Expr->getValue()))
15211 return Expr;
15212 std::optional<
15213 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15214 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15215 if (!PredicatedRewrite)
15216 return Expr;
15217 for (const auto *P : PredicatedRewrite->second){
15218 // Wrap predicates from outer loops are not supported.
15219 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15220 if (L != WP->getExpr()->getLoop())
15221 return Expr;
15222 }
15223 if (!addOverflowAssumption(P))
15224 return Expr;
15225 }
15226 return PredicatedRewrite->first;
15227 }
15228
15229 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15230 const SCEVPredicate *Pred;
15231 const Loop *L;
15232};
15233
15234} // end anonymous namespace
15235
15236const SCEV *
15238 const SCEVPredicate &Preds) {
15239 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15240}
15241
15243 const SCEV *S, const Loop *L,
15246 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15247 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15248
15249 if (!AddRec)
15250 return nullptr;
15251
15252 // Check if any of the transformed predicates is known to be false. In that
15253 // case, it doesn't make sense to convert to a predicated AddRec, as the
15254 // versioned loop will never execute.
15255 for (const SCEVPredicate *Pred : TransformPreds) {
15256 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15257 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15258 continue;
15259
15260 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15261 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15262 if (isa<SCEVCouldNotCompute>(ExitCount))
15263 continue;
15264
15265 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15266 if (!Step->isOne())
15267 continue;
15268
15269 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15270 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15271 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15272 return nullptr;
15273 }
15274
15275 // Since the transformation was successful, we can now transfer the SCEV
15276 // predicates.
15277 Preds.append(TransformPreds.begin(), TransformPreds.end());
15278
15279 return AddRec;
15280}
15281
15282/// SCEV predicates
15286
15288 const ICmpInst::Predicate Pred,
15289 const SCEV *LHS, const SCEV *RHS)
15290 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15291 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15292 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15293}
15294
15296 ScalarEvolution &SE) const {
15297 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15298
15299 if (!Op)
15300 return false;
15301
15302 if (Pred != ICmpInst::ICMP_EQ)
15303 return false;
15304
15305 return Op->LHS == LHS && Op->RHS == RHS;
15306}
15307
15308bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15309
15311 if (Pred == ICmpInst::ICMP_EQ)
15312 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15313 else
15314 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15315 << *RHS << "\n";
15316
15317}
15318
15320 const SCEVAddRecExpr *AR,
15321 IncrementWrapFlags Flags)
15322 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15323
15324const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15325
15327 ScalarEvolution &SE) const {
15328 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15329 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15330 return false;
15331
15332 if (Op->AR == AR)
15333 return true;
15334
15335 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15337 return false;
15338
15339 const SCEV *Start = AR->getStart();
15340 const SCEV *OpStart = Op->AR->getStart();
15341 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15342 return false;
15343
15344 // Reject pointers to different address spaces.
15345 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15346 return false;
15347
15348 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15349 // narrower-type AddRec.
15350 if (SE.getTypeSizeInBits(AR->getType()) >
15351 SE.getTypeSizeInBits(Op->AR->getType()))
15352 return false;
15353
15354 const SCEV *Step = AR->getStepRecurrence(SE);
15355 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15356 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15357 return false;
15358
15359 // If both steps are positive, this implies N, if N's start and step are
15360 // ULE/SLE (for NSUW/NSSW) than this'.
15361 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15362 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15363 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15364
15365 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15366 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15367 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15368 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15369 : SE.getNoopOrSignExtend(Start, WiderTy);
15371 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15372 SE.isKnownPredicate(Pred, OpStart, Start);
15373}
15374
15376 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15377 IncrementWrapFlags IFlags = Flags;
15378
15379 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15380 IFlags = clearFlags(IFlags, IncrementNSSW);
15381
15382 return IFlags == IncrementAnyWrap;
15383}
15384
15385void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15386 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15388 OS << "<nusw>";
15390 OS << "<nssw>";
15391 OS << "\n";
15392}
15393
15396 ScalarEvolution &SE) {
15397 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15398 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15399
15400 // We can safely transfer the NSW flag as NSSW.
15401 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15402 ImpliedFlags = IncrementNSSW;
15403
15404 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15405 // If the increment is positive, the SCEV NUW flag will also imply the
15406 // WrapPredicate NUSW flag.
15407 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15408 if (Step->getValue()->getValue().isNonNegative())
15409 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15410 }
15411
15412 return ImpliedFlags;
15413}
15414
15415/// Union predicates don't get cached so create a dummy set ID for it.
15417 ScalarEvolution &SE)
15418 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15419 for (const auto *P : Preds)
15420 add(P, SE);
15421}
15422
15424 return all_of(Preds,
15425 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15426}
15427
15429 ScalarEvolution &SE) const {
15430 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15431 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15432 return this->implies(I, SE);
15433 });
15434
15435 return any_of(Preds,
15436 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15437}
15438
15440 for (const auto *Pred : Preds)
15441 Pred->print(OS, Depth);
15442}
15443
15444void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15445 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15446 for (const auto *Pred : Set->Preds)
15447 add(Pred, SE);
15448 return;
15449 }
15450
15451 // Implication checks are quadratic in the number of predicates. Stop doing
15452 // them if there are many predicates, as they should be too expensive to use
15453 // anyway at that point.
15454 bool CheckImplies = Preds.size() < 16;
15455
15456 // Only add predicate if it is not already implied by this union predicate.
15457 if (CheckImplies && implies(N, SE))
15458 return;
15459
15460 // Build a new vector containing the current predicates, except the ones that
15461 // are implied by the new predicate N.
15463 for (auto *P : Preds) {
15464 if (CheckImplies && N->implies(P, SE))
15465 continue;
15466 PrunedPreds.push_back(P);
15467 }
15468 Preds = std::move(PrunedPreds);
15469 Preds.push_back(N);
15470}
15471
15473 Loop &L)
15474 : SE(SE), L(L) {
15476 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15477}
15478
15481 for (const auto *Op : Ops)
15482 // We do not expect that forgetting cached data for SCEVConstants will ever
15483 // open any prospects for sharpening or introduce any correctness issues,
15484 // so we don't bother storing their dependencies.
15485 if (!isa<SCEVConstant>(Op))
15486 SCEVUsers[Op].insert(User);
15487}
15488
15490 for (const SCEV *Op : Ops)
15491 // We do not expect that forgetting cached data for SCEVConstants will ever
15492 // open any prospects for sharpening or introduce any correctness issues,
15493 // so we don't bother storing their dependencies.
15494 if (!isa<SCEVConstant>(Op))
15495 SCEVUsers[Op].insert(User);
15496}
15497
15499 const SCEV *Expr = SE.getSCEV(V);
15500 return getPredicatedSCEV(Expr);
15501}
15502
15504 RewriteEntry &Entry = RewriteMap[Expr];
15505
15506 // If we already have an entry and the version matches, return it.
15507 if (Entry.second && Generation == Entry.first)
15508 return Entry.second;
15509
15510 // We found an entry but it's stale. Rewrite the stale entry
15511 // according to the current predicate.
15512 if (Entry.second)
15513 Expr = Entry.second;
15514
15515 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15516 Entry = {Generation, NewSCEV};
15517
15518 return NewSCEV;
15519}
15520
15522 if (!BackedgeCount) {
15524 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15525 for (const auto *P : Preds)
15526 addPredicate(*P);
15527 }
15528 return BackedgeCount;
15529}
15530
15532 if (!SymbolicMaxBackedgeCount) {
15534 SymbolicMaxBackedgeCount =
15535 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15536 for (const auto *P : Preds)
15537 addPredicate(*P);
15538 }
15539 return SymbolicMaxBackedgeCount;
15540}
15541
15543 if (!SmallConstantMaxTripCount) {
15545 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15546 for (const auto *P : Preds)
15547 addPredicate(*P);
15548 }
15549 return *SmallConstantMaxTripCount;
15550}
15551
15553 if (Preds->implies(&Pred, SE))
15554 return;
15555
15556 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15557 NewPreds.push_back(&Pred);
15558 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15559 updateGeneration();
15560}
15561
15563 return *Preds;
15564}
15565
15566void PredicatedScalarEvolution::updateGeneration() {
15567 // If the generation number wrapped recompute everything.
15568 if (++Generation == 0) {
15569 for (auto &II : RewriteMap) {
15570 const SCEV *Rewritten = II.second.second;
15571 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15572 }
15573 }
15574}
15575
15578 const SCEV *Expr = getSCEV(V);
15579 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15580
15581 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15582
15583 // Clear the statically implied flags.
15584 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15585 addPredicate(*SE.getWrapPredicate(AR, Flags));
15586
15587 auto II = FlagsMap.insert({V, Flags});
15588 if (!II.second)
15589 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15590}
15591
15594 const SCEV *Expr = getSCEV(V);
15595 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15596
15598 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15599
15600 auto II = FlagsMap.find(V);
15601
15602 if (II != FlagsMap.end())
15603 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15604
15606}
15607
15609 const SCEV *Expr = this->getSCEV(V);
15611 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15612
15613 if (!New)
15614 return nullptr;
15615
15616 for (const auto *P : NewPreds)
15617 addPredicate(*P);
15618
15619 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15620 return New;
15621}
15622
15625 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15626 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15627 SE)),
15628 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15629 for (auto I : Init.FlagsMap)
15630 FlagsMap.insert(I);
15631}
15632
15634 // For each block.
15635 for (auto *BB : L.getBlocks())
15636 for (auto &I : *BB) {
15637 if (!SE.isSCEVable(I.getType()))
15638 continue;
15639
15640 auto *Expr = SE.getSCEV(&I);
15641 auto II = RewriteMap.find(Expr);
15642
15643 if (II == RewriteMap.end())
15644 continue;
15645
15646 // Don't print things that are not interesting.
15647 if (II->second.second == Expr)
15648 continue;
15649
15650 OS.indent(Depth) << "[PSE]" << I << ":\n";
15651 OS.indent(Depth + 2) << *Expr << "\n";
15652 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15653 }
15654}
15655
15658 BasicBlock *Header = L->getHeader();
15659 BasicBlock *Pred = L->getLoopPredecessor();
15660 LoopGuards Guards(SE);
15661 if (!Pred)
15662 return Guards;
15664 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15665 return Guards;
15666}
15667
15668void ScalarEvolution::LoopGuards::collectFromPHI(
15672 unsigned Depth) {
15673 if (!SE.isSCEVable(Phi.getType()))
15674 return;
15675
15676 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15677 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15678 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15679 if (!VisitedBlocks.insert(InBlock).second)
15680 return {nullptr, scCouldNotCompute};
15681
15682 // Avoid analyzing unreachable blocks so that we don't get trapped
15683 // traversing cycles with ill-formed dominance or infinite cycles
15684 if (!SE.DT.isReachableFromEntry(InBlock))
15685 return {nullptr, scCouldNotCompute};
15686
15687 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15688 if (Inserted)
15689 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15690 Depth + 1);
15691 auto &RewriteMap = G->second.RewriteMap;
15692 if (RewriteMap.empty())
15693 return {nullptr, scCouldNotCompute};
15694 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15695 if (S == RewriteMap.end())
15696 return {nullptr, scCouldNotCompute};
15697 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15698 if (!SM)
15699 return {nullptr, scCouldNotCompute};
15700 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15701 return {C0, SM->getSCEVType()};
15702 return {nullptr, scCouldNotCompute};
15703 };
15704 auto MergeMinMaxConst = [](MinMaxPattern P1,
15705 MinMaxPattern P2) -> MinMaxPattern {
15706 auto [C1, T1] = P1;
15707 auto [C2, T2] = P2;
15708 if (!C1 || !C2 || T1 != T2)
15709 return {nullptr, scCouldNotCompute};
15710 switch (T1) {
15711 case scUMaxExpr:
15712 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15713 case scSMaxExpr:
15714 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15715 case scUMinExpr:
15716 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15717 case scSMinExpr:
15718 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15719 default:
15720 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15721 }
15722 };
15723 auto P = GetMinMaxConst(0);
15724 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15725 if (!P.first)
15726 break;
15727 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15728 }
15729 if (P.first) {
15730 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15731 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15732 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15733 Guards.RewriteMap.insert({LHS, RHS});
15734 }
15735}
15736
15737// Return a new SCEV that rounds \p Expr down to the closest number divisible
15738// by \p DivisorVal that is less than or equal to Expr. For now, only constant
15739// Expr is handled.
15741 const APInt &DivisorVal,
15742 ScalarEvolution &SE) {
15743 const APInt *ExprVal;
15744 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15745 DivisorVal.isNonPositive())
15746 return Expr;
15747 APInt Rem = ExprVal->urem(DivisorVal);
15748 // return the SCEV: Expr - Expr % Divisor
15749 return SE.getConstant(*ExprVal - Rem);
15750}
15751
15752// Return a new SCEV that modifies \p Expr to the closest number divides by
15753// \p Divisor and greater or equal than Expr. For now, only handle constant
15754// Expr.
15755static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15756 const APInt &DivisorVal,
15757 ScalarEvolution &SE) {
15758 const APInt *ExprVal;
15759 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15760 DivisorVal.isNonPositive())
15761 return Expr;
15762 APInt Rem = ExprVal->urem(DivisorVal);
15763 if (Rem.isZero())
15764 return Expr;
15765 // return the SCEV: Expr + Divisor - Expr % Divisor
15766 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15767}
15768
15770 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15773 // If we have LHS == 0, check if LHS is computing a property of some unknown
15774 // SCEV %v which we can rewrite %v to express explicitly.
15776 return false;
15777 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15778 // explicitly express that.
15779 const SCEVUnknown *URemLHS = nullptr;
15780 const SCEV *URemRHS = nullptr;
15781 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15782 return false;
15783
15784 const SCEV *Multiple =
15785 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15786 DivInfo[URemLHS] = Multiple;
15787 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15788 Multiples[URemLHS] = C->getAPInt();
15789 return true;
15790}
15791
15792// Check if the condition is a divisibility guard (A % B == 0).
15793static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15794 ScalarEvolution &SE) {
15795 const SCEV *X, *Y;
15796 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15797}
15798
15799// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15800// recursively. This is done by aligning up/down the constant value to the
15801// Divisor.
15802static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15803 APInt Divisor,
15804 ScalarEvolution &SE) {
15805 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15806 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15807 // the non-constant operand and in \p LHS the constant operand.
15808 auto IsMinMaxSCEVWithNonNegativeConstant =
15809 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15810 const SCEV *&RHS) {
15811 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15812 if (MinMax->getNumOperands() != 2)
15813 return false;
15814 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15815 if (C->getAPInt().isNegative())
15816 return false;
15817 SCTy = MinMax->getSCEVType();
15818 LHS = MinMax->getOperand(0);
15819 RHS = MinMax->getOperand(1);
15820 return true;
15821 }
15822 }
15823 return false;
15824 };
15825
15826 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15827 SCEVTypes SCTy;
15828 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15829 MinMaxRHS))
15830 return MinMaxExpr;
15831 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15832 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15833 auto *DivisibleExpr =
15834 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15835 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15837 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15838 return SE.getMinMaxExpr(SCTy, Ops);
15839}
15840
15841void ScalarEvolution::LoopGuards::collectFromBlock(
15842 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15843 const BasicBlock *Block, const BasicBlock *Pred,
15844 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15845
15847
15848 SmallVector<SCEVUse> ExprsToRewrite;
15849 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15850 const SCEV *RHS,
15851 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15852 const LoopGuards &DivGuards) {
15853 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15854 // replacement SCEV which isn't directly implied by the structure of that
15855 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15856 // legal. See the scoping rules for flags in the header to understand why.
15857
15858 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15859 // create this form when combining two checks of the form (X u< C2 + C1) and
15860 // (X >=u C1).
15861 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15862 &ExprsToRewrite]() {
15863 const SCEVConstant *C1;
15864 const SCEVUnknown *LHSUnknown;
15865 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15866 if (!match(LHS,
15867 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15868 !C2)
15869 return false;
15870
15871 auto ExactRegion =
15872 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15873 .sub(C1->getAPInt());
15874
15875 // Bail out, unless we have a non-wrapping, monotonic range.
15876 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15877 return false;
15878 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15879 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15880 I->second = SE.getUMaxExpr(
15881 SE.getConstant(ExactRegion.getUnsignedMin()),
15882 SE.getUMinExpr(RewrittenLHS,
15883 SE.getConstant(ExactRegion.getUnsignedMax())));
15884 ExprsToRewrite.push_back(LHSUnknown);
15885 return true;
15886 };
15887 if (MatchRangeCheckIdiom())
15888 return;
15889
15890 // Do not apply information for constants or if RHS contains an AddRec.
15892 return;
15893
15894 // If RHS is SCEVUnknown, make sure the information is applied to it.
15896 std::swap(LHS, RHS);
15898 }
15899
15900 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15901 // and \p FromRewritten are the same (i.e. there has been no rewrite
15902 // registered for \p From), then puts this value in the list of rewritten
15903 // expressions.
15904 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15905 const SCEV *To) {
15906 if (From == FromRewritten)
15907 ExprsToRewrite.push_back(From);
15908 RewriteMap[From] = To;
15909 };
15910
15911 // Checks whether \p S has already been rewritten. In that case returns the
15912 // existing rewrite because we want to chain further rewrites onto the
15913 // already rewritten value. Otherwise returns \p S.
15914 auto GetMaybeRewritten = [&](const SCEV *S) {
15915 return RewriteMap.lookup_or(S, S);
15916 };
15917
15918 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15919 // Apply divisibility information when computing the constant multiple.
15920 const APInt &DividesBy =
15921 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15922
15923 // Collect rewrites for LHS and its transitive operands based on the
15924 // condition.
15925 // For min/max expressions, also apply the guard to its operands:
15926 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15927 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15928 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15929 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15930
15931 // We cannot express strict predicates in SCEV, so instead we replace them
15932 // with non-strict ones against plus or minus one of RHS depending on the
15933 // predicate.
15934 const SCEV *One = SE.getOne(RHS->getType());
15935 switch (Predicate) {
15936 case CmpInst::ICMP_ULT:
15937 if (RHS->getType()->isPointerTy())
15938 return;
15939 RHS = SE.getUMaxExpr(RHS, One);
15940 [[fallthrough]];
15941 case CmpInst::ICMP_SLT: {
15942 RHS = SE.getMinusSCEV(RHS, One);
15943 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15944 break;
15945 }
15946 case CmpInst::ICMP_UGT:
15947 case CmpInst::ICMP_SGT:
15948 RHS = SE.getAddExpr(RHS, One);
15949 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15950 break;
15951 case CmpInst::ICMP_ULE:
15952 case CmpInst::ICMP_SLE:
15953 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15954 break;
15955 case CmpInst::ICMP_UGE:
15956 case CmpInst::ICMP_SGE:
15957 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15958 break;
15959 default:
15960 break;
15961 }
15962
15963 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15964 SmallPtrSet<const SCEV *, 16> Visited;
15965
15966 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15967 append_range(Worklist, S->operands());
15968 };
15969
15970 while (!Worklist.empty()) {
15971 const SCEV *From = Worklist.pop_back_val();
15972 if (isa<SCEVConstant>(From))
15973 continue;
15974 if (!Visited.insert(From).second)
15975 continue;
15976 const SCEV *FromRewritten = GetMaybeRewritten(From);
15977 const SCEV *To = nullptr;
15978
15979 switch (Predicate) {
15980 case CmpInst::ICMP_ULT:
15981 case CmpInst::ICMP_ULE:
15982 To = SE.getUMinExpr(FromRewritten, RHS);
15983 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15984 EnqueueOperands(UMax);
15985 break;
15986 case CmpInst::ICMP_SLT:
15987 case CmpInst::ICMP_SLE:
15988 To = SE.getSMinExpr(FromRewritten, RHS);
15989 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15990 EnqueueOperands(SMax);
15991 break;
15992 case CmpInst::ICMP_UGT:
15993 case CmpInst::ICMP_UGE:
15994 To = SE.getUMaxExpr(FromRewritten, RHS);
15995 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15996 EnqueueOperands(UMin);
15997 break;
15998 case CmpInst::ICMP_SGT:
15999 case CmpInst::ICMP_SGE:
16000 To = SE.getSMaxExpr(FromRewritten, RHS);
16001 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
16002 EnqueueOperands(SMin);
16003 break;
16004 case CmpInst::ICMP_EQ:
16006 To = RHS;
16007 break;
16008 case CmpInst::ICMP_NE:
16009 if (match(RHS, m_scev_Zero())) {
16010 const SCEV *OneAlignedUp =
16011 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16012 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16013 } else {
16014 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16015 // but creating the subtraction eagerly is expensive. Track the
16016 // inequalities in a separate map, and materialize the rewrite lazily
16017 // when encountering a suitable subtraction while re-writing.
16018 if (LHS->getType()->isPointerTy()) {
16022 break;
16023 }
16024 const SCEVConstant *C;
16025 const SCEV *A, *B;
16028 RHS = A;
16029 LHS = B;
16030 }
16031 if (LHS > RHS)
16032 std::swap(LHS, RHS);
16033 Guards.NotEqual.insert({LHS, RHS});
16034 continue;
16035 }
16036 break;
16037 default:
16038 break;
16039 }
16040
16041 if (To)
16042 AddRewrite(From, FromRewritten, To);
16043 }
16044 };
16045
16047 // First, collect information from assumptions dominating the loop.
16048 for (auto &AssumeVH : SE.AC.assumptions()) {
16049 if (!AssumeVH)
16050 continue;
16051 auto *AssumeI = cast<CallInst>(AssumeVH);
16052 if (!SE.DT.dominates(AssumeI, Block))
16053 continue;
16054 Terms.emplace_back(AssumeI->getOperand(0), true);
16055 }
16056
16057 // Second, collect information from llvm.experimental.guards dominating the loop.
16058 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16059 SE.F.getParent(), Intrinsic::experimental_guard);
16060 if (GuardDecl)
16061 for (const auto *GU : GuardDecl->users())
16062 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16063 if (Guard->getFunction() == Block->getParent() &&
16064 SE.DT.dominates(Guard, Block))
16065 Terms.emplace_back(Guard->getArgOperand(0), true);
16066
16067 // Third, collect conditions from dominating branches. Starting at the loop
16068 // predecessor, climb up the predecessor chain, as long as there are
16069 // predecessors that can be found that have unique successors leading to the
16070 // original header.
16071 // TODO: share this logic with isLoopEntryGuardedByCond.
16072 unsigned NumCollectedConditions = 0;
16074 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16075 for (; Pair.first;
16076 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16077 VisitedBlocks.insert(Pair.second);
16078 const CondBrInst *LoopEntryPredicate =
16079 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16080 if (!LoopEntryPredicate)
16081 continue;
16082
16083 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16084 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16085 NumCollectedConditions++;
16086
16087 // If we are recursively collecting guards stop after 2
16088 // conditions to limit compile-time impact for now.
16089 if (Depth > 0 && NumCollectedConditions == 2)
16090 break;
16091 }
16092 // Finally, if we stopped climbing the predecessor chain because
16093 // there wasn't a unique one to continue, try to collect conditions
16094 // for PHINodes by recursively following all of their incoming
16095 // blocks and try to merge the found conditions to build a new one
16096 // for the Phi.
16097 if (Pair.second->hasNPredecessorsOrMore(2) &&
16099 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16100 for (auto &Phi : Pair.second->phis())
16101 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16102 }
16103
16104 // Now apply the information from the collected conditions to
16105 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16106 // earliest conditions is processed first, except guards with divisibility
16107 // information, which are moved to the back. This ensures the SCEVs with the
16108 // shortest dependency chains are constructed first.
16110 GuardsToProcess;
16111 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16112 SmallVector<Value *, 8> Worklist;
16113 SmallPtrSet<Value *, 8> Visited;
16114 Worklist.push_back(Term);
16115 while (!Worklist.empty()) {
16116 Value *Cond = Worklist.pop_back_val();
16117 if (!Visited.insert(Cond).second)
16118 continue;
16119
16120 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16121 auto Predicate =
16122 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16123 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16124 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16125 // If LHS is a constant, apply information to the other expression.
16126 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16127 // can improve results.
16128 if (isa<SCEVConstant>(LHS)) {
16129 std::swap(LHS, RHS);
16131 }
16132 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16133 continue;
16134 }
16135
16136 Value *L, *R;
16137 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16138 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16139 Worklist.push_back(L);
16140 Worklist.push_back(R);
16141 }
16142 }
16143 }
16144
16145 // Process divisibility guards in reverse order to populate DivGuards early.
16146 DenseMap<const SCEV *, APInt> Multiples;
16147 LoopGuards DivGuards(SE);
16148 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16149 if (!isDivisibilityGuard(LHS, RHS, SE))
16150 continue;
16151 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16152 Multiples, SE);
16153 }
16154
16155 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16156 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16157
16158 // Apply divisibility information last. This ensures it is applied to the
16159 // outermost expression after other rewrites for the given value.
16160 for (const auto &[K, Divisor] : Multiples) {
16161 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16162 Guards.RewriteMap[K] =
16164 Guards.rewrite(K), Divisor, SE),
16165 DivisorSCEV),
16166 DivisorSCEV);
16167 ExprsToRewrite.push_back(K);
16168 }
16169
16170 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16171 // the replacement expressions are contained in the ranges of the replaced
16172 // expressions.
16173 Guards.PreserveNUW = true;
16174 Guards.PreserveNSW = true;
16175 for (const SCEV *Expr : ExprsToRewrite) {
16176 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16177 Guards.PreserveNUW &=
16178 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16179 Guards.PreserveNSW &=
16180 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16181 }
16182
16183 // Now that all rewrite information is collect, rewrite the collected
16184 // expressions with the information in the map. This applies information to
16185 // sub-expressions.
16186 if (ExprsToRewrite.size() > 1) {
16187 for (const SCEV *Expr : ExprsToRewrite) {
16188 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16189 Guards.RewriteMap.erase(Expr);
16190 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16191 }
16192 }
16193}
16194
16196 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16197 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16198 /// replacement is loop invariant in the loop of the AddRec.
16199 class SCEVLoopGuardRewriter
16200 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16203
16205
16206 public:
16207 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16208 const ScalarEvolution::LoopGuards &Guards)
16209 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16210 NotEqual(Guards.NotEqual) {
16211 if (Guards.PreserveNUW)
16212 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16213 if (Guards.PreserveNSW)
16214 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16215 }
16216
16217 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16218
16219 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16220 return Map.lookup_or(Expr, Expr);
16221 }
16222
16223 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16224 if (const SCEV *S = Map.lookup(Expr))
16225 return S;
16226
16227 // If we didn't find the extact ZExt expr in the map, check if there's
16228 // an entry for a smaller ZExt we can use instead.
16229 Type *Ty = Expr->getType();
16230 const SCEV *Op = Expr->getOperand(0);
16231 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16232 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16233 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16234 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16235 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16236 if (const SCEV *S = Map.lookup(NarrowExt))
16237 return SE.getZeroExtendExpr(S, Ty);
16238 Bitwidth = Bitwidth / 2;
16239 }
16240
16242 Expr);
16243 }
16244
16245 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16246 if (const SCEV *S = Map.lookup(Expr))
16247 return S;
16249 Expr);
16250 }
16251
16252 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16253 if (const SCEV *S = Map.lookup(Expr))
16254 return S;
16256 }
16257
16258 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16259 if (const SCEV *S = Map.lookup(Expr))
16260 return S;
16262 }
16263
16264 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16265 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16266 // return UMax(S, 1).
16267 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16268 SCEVUse LHS, RHS;
16269 if (MatchBinarySub(S, LHS, RHS)) {
16270 if (LHS > RHS)
16271 std::swap(LHS, RHS);
16272 if (NotEqual.contains({LHS, RHS})) {
16273 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16274 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16275 return SE.getUMaxExpr(OneAlignedUp, S);
16276 }
16277 }
16278 return nullptr;
16279 };
16280
16281 // Check if Expr itself is a subtraction pattern with guard info.
16282 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16283 return Rewritten;
16284
16285 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16286 // (Const + A + B). There may be guard info for A + B, and if so, apply
16287 // it.
16288 // TODO: Could more generally apply guards to Add sub-expressions.
16289 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16290 Expr->getNumOperands() == 3) {
16291 const SCEV *Add =
16292 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16293 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16294 return SE.getAddExpr(
16295 Expr->getOperand(0), Rewritten,
16296 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16297 if (const SCEV *S = Map.lookup(Add))
16298 return SE.getAddExpr(Expr->getOperand(0), S);
16299 }
16300 SmallVector<SCEVUse, 2> Operands;
16301 bool Changed = false;
16302 for (SCEVUse Op : Expr->operands()) {
16303 Operands.push_back(
16305 Changed |= Op != Operands.back();
16306 }
16307 // We are only replacing operands with equivalent values, so transfer the
16308 // flags from the original expression.
16309 return !Changed ? Expr
16310 : SE.getAddExpr(Operands,
16312 Expr->getNoWrapFlags(), FlagMask));
16313 }
16314
16315 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16316 SmallVector<SCEVUse, 2> Operands;
16317 bool Changed = false;
16318 for (SCEVUse Op : Expr->operands()) {
16319 Operands.push_back(
16321 Changed |= Op != Operands.back();
16322 }
16323 // We are only replacing operands with equivalent values, so transfer the
16324 // flags from the original expression.
16325 return !Changed ? Expr
16326 : SE.getMulExpr(Operands,
16328 Expr->getNoWrapFlags(), FlagMask));
16329 }
16330 };
16331
16332 if (RewriteMap.empty() && NotEqual.empty())
16333 return Expr;
16334
16335 SCEVLoopGuardRewriter Rewriter(SE, *this);
16336 return Rewriter.visit(Expr);
16337}
16338
16339const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16340 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16341}
16342
16344 const LoopGuards &Guards) {
16345 return Guards.rewrite(Expr);
16346}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
static constexpr Value * getValue(Ty &ValueOrUse)
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2023
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1709
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1317
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1028
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
Definition APInt.h:1244
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <aparticular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:130
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator begin() const
Definition ArrayRef.h:129
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:409
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:740
@ ICMP_SLT
signed less than
Definition InstrTypes.h:769
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:770
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:764
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:763
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:767
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:765
@ ICMP_NE
not equal
Definition InstrTypes.h:762
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:768
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:766
bool isSigned() const
Definition InstrTypes.h:993
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:890
bool isTrueWhenEqual() const
This is just a convenience.
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:852
bool isUnsigned() const
Definition InstrTypes.h:999
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:989
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1481
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:254
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:395
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:239
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:314
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be a useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:612
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI bool isLoopUniform(const SCEV *S, const Loop *L)
Returns true if the given SCEV is loop-uniform with respect to the specified loop L.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
@ LoopUniform
The SCEV is loop-uniform.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:301
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2864
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:830
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:558
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2115
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A in B
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2110
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2199
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:371
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2087
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1916
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2018
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if an signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:874
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:876
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.