// Doxygen-generated source listing of llvm/lib/Analysis/ScalarEvolution.cpp
// (LLVM 23.0.0git). The numbers embedded at the start of the lines below are
// artifacts of the listing, not part of the program text.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
/// Print a human-readable form of this SCEV expression to \p OS.
/// NOTE(review): several lines of this listing were dropped by the doxygen
/// extraction; each gap is flagged inline below — confirm against the
/// original source before relying on this text.
void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    // Constants print as their underlying ConstantInt operand.
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scVScale:
    OS << "vscale";
    return;
  case scPtrToAddr:
  case scPtrToInt: {
    // Both pointer casts share one printing path; only the suffix differs.
    const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
    const SCEV *Op = PtrCast->getOperand();
    StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
    OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
       << *PtrCast->getType() << ")";
    return;
  }
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    // NOTE(review): the declaration of ZExt was dropped from this listing
    // (presumably a cast<SCEVZeroExtendExpr>(this)) — confirm.
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    // NOTE(review): the declaration of SExt was dropped from this listing
    // (presumably a cast<SCEVSignExtendExpr>(this)) — confirm.
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    // Recurrences print as {start,+,step,...}<flags><header>.
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        // NOTE(review): the second operand of this && was dropped from the
        // listing (likely a check that NUW/NSW were not already printed).
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    // NOTE(review): a dropped line here likely read
    // `case scSequentialUMinExpr: {` — the block below is shared n-ary
    // printing and the inner switch handles " umin_seq " too.
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    case scUMinExpr:
      OpStr = " umin ";
      break;
    case scSMinExpr:
      OpStr = " smin ";
      break;
    // NOTE(review): dropped line — likely `case scSequentialUMinExpr:`.
      OpStr = " umin_seq ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "("
       // NOTE(review): dropped line — the expression printing the operands
       // joined by OpStr.
       << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      // Only add/mul carry printable wrap flags.
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown:
    cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
    return;
  // NOTE(review): dropped line — likely `case scCouldNotCompute:`.
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
462
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
/// Return true if this expression matches the zero-constant pattern.
bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
/// Return true if this expression matches the one-constant pattern.
bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
/// Return true if this expression matches the all-ones-constant pattern.
bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
/// Construct a ptrtoaddr cast node. The operand must be pointer-typed and
/// the result type must be an integer type (enforced by the assert below).
SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
                                     const SCEV *Op, Type *ITy)
    : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}
608
/// Construct a ptrtoint cast node. The operand must be pointer-typed and
/// the result type must be an integer type (enforced by the assert below).
/// Note the operand is taken as a SCEVUse here, unlike the plain
/// `const SCEV *` taken by the ptrtoaddr constructor.
SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
/// Callback fired when every use of the wrapped IR Value has been replaced
/// with \p New. Drops all cached results that refer to this node, un-uniques
/// it, and re-points the node at the replacement value. The order matters:
/// memoized results are invalidated before the node is mutated.
void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults({this});

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Replace the value pointer in case someone is still using this SCEVUnknown.
  setValPtr(New);
}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
                                  Value *RV, unsigned Depth) {
  // NOTE(review): the condition of this recursion-depth cutoff was dropped
  // from the listing (presumably `if (Depth > MaxValueCompareDepth)`) —
  // confirm against the original source.
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    // Same ValueID (checked above), so RV must also be an Argument here.
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    // First discriminate by linkage kind.
    if (auto L = LGV->getLinkage() - RGV->getLinkage())
      return L;

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
      // NOTE(review): the second operand of this || (and the closing `);`)
      // was dropped from the listing — likely
      // `GlobalValue::isInternalLinkage(LT));` — confirm.
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count. This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    // Finally, recurse element-wise over corresponding operand pairs.
    for (unsigned Idx : seq(LNumOps)) {
      int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
                                          RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  // No ordering criterion distinguished the two values.
  return 0;
}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
static std::optional<int>
CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  // NOTE(review): the depth-cutoff condition was dropped from this listing
  // (presumably `if (Depth > MaxSCEVCompareDepth)`) — confirm against the
  // original source.
    return std::nullopt;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    // Defer to the value-level comparison for the two wrapped IR values.
    int X =
        CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
    return X;
  }

  case scConstant: {
    // NOTE(review): the declarations of LC/RC were dropped from this listing
    // (presumably cast<SCEVConstant> of LHS/RHS) — confirm.

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    // LHS != RHS (checked above), so equal-width constants differ in value.
    return LA.ult(RA) ? -1 : 1;
  }

  case scVScale: {
    // vscale nodes are ordered by the bit width of their integer type.
    const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
    const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
    return LTy->getBitWidth() - RTy->getBitWidth();
  }

  case scAddRecExpr: {
    // NOTE(review): the declarations of LA/RA were dropped from this listing
    // (presumably cast<SCEVAddRecExpr> of LHS/RHS) — confirm.

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      assert(DT.dominates(RHead, LHead) &&
             "No dominance between recurrences used by one SCEV?");
      return -1;
    }

    // Same loop: fall through and compare operands lexicographically.
    [[fallthrough]];
  }

  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
  // NOTE(review): a dropped line here likely read
  // `case scSequentialUMinExpr: {` — confirm.
    ArrayRef<SCEVUse> LOps = LHS->operands();
    ArrayRef<SCEVUse> ROps = RHS->operands();

    // Lexicographically compare n-ary-like expressions.
    unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
                                     ROps[i].getPointer(), DT, Depth + 1);
      if (X != 0)
        return X;
    }
    return 0;
  }

  // NOTE(review): a dropped line here likely read `case scCouldNotCompute:`.
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we go take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
/// Compute BC(It, K). The result has width W. Assume, K > 0.
///
/// Returns the binomial coefficient C(It, K) evaluated symbolically, with the
/// result truncated/zero-extended to \p ResultTy; may return CouldNotCompute
/// for unreasonably large K.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently: C(It, 1) == It.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value. We must be prepared for
  // overflow. Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula. However,
  // this formula can be implemented much more efficiently. The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic. To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse. Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T. The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits. This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula less than W + K bits. Also, the first formula requires
  // a division step, whereas this formula only requires multiplies and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway. We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  // T starts at 1 to account for the single factor of two contributed by
  // i == 2 (its odd part is 1, leaving OddFactorial unchanged), which is why
  // the loop below can start at i == 3. (K >= 2 here since K == 1 returned
  // early; for K == 2 the loop body never runs.)
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    // Split i into 2^TwoFactors * (odd part) and accumulate each separately.
    unsigned TwoFactors = countr_zero(i);
    T += TwoFactors;
    OddFactorial *= (i >> TwoFactors);
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt MultiplyFactor = OddFactorial.multiplicativeInverse();

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  // Multiply the K factors It, It-1, ..., It-K+1 (the loop contributes the
  // remaining K-1 factors; Dividend already holds It).
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}
1066
1067/// Return the value of this chain of recurrences at the specified iteration
1068/// number. We can evaluate this recurrence by multiplying each element in the
1069/// chain by the binomial coefficient corresponding to it. In other words, we
1070/// can evaluate {A,+,B,+,C,+,D} as:
1071///
1072/// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1073///
1074/// where BC(It, k) stands for binomial coefficient.
                                               ScalarEvolution &SE) const {
  // NOTE(review): the first line of this signature was lost in extraction;
  // this reads as the member overload taking the iteration number `It`.
  // Delegate to the static overload, evaluating this recurrence's full
  // operand list at iteration `It`.
  return evaluateAtIteration(operands(), It, SE);
}
1079
                                               const SCEV *It,
                                               ScalarEvolution &SE) {
  // NOTE(review): the leading signature line (static overload taking the
  // operand list `Operands`) was lost in extraction.
  assert(Operands.size() > 0);
  // Accumulate sum_i Operands[i] * BC(It, i), starting from the 0th operand.
  const SCEV *Result = Operands[0].getPointer();
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    // BinomialCoefficient can bail out (e.g. excessively large K); propagate
    // SCEVCouldNotCompute to the caller instead of a partial sum.
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result =
        SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
  }
  return Result;
}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
/// which computes a pointer-typed value, and rewrites the whole expression
/// tree so that *all* the computations are done on integers, and the only
/// pointer-typed operands in the expression are SCEVUnknown.
/// The CreatePtrCast callback is invoked to create the actual conversion
/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
/// NOTE(review): the class-head line and the `using Base = ...` alias line
/// were lost in extraction.
    : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
  // Callback type used to materialize the cast at a pointer-typed leaf.
  using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
  Type *TargetTy;             // Integer type all computation is rewritten to.
  ConversionFn CreatePtrCast; // Builds the cast for each SCEVUnknown leaf.

public:
  // NOTE(review): the constructor's first line (taking SE and TargetTy) was
  // lost in extraction.
                          ConversionFn CreatePtrCast)
      : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}

  /// Entry point: rewrite \p Scev so that all arithmetic is integer-typed,
  /// invoking \p CreatePtrCast at each pointer-typed SCEVUnknown leaf.
  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
                             Type *TargetTy, ConversionFn CreatePtrCast) {
    SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
    return Rewriter.visit(Scev);
  }

  const SCEV *visit(const SCEV *S) {
    Type *STy = S->getType();
    // If the expression is not pointer-typed, just keep it as-is.
    if (!STy->isPointerTy())
      return S;
    // Else, recursively sink the cast down into it.
    return Base::visit(S);
  }

  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
    // Preserve wrap flags on rewritten SCEVAddExpr, which the default
    // implementation drops.
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    // Only rebuild the node when some operand actually changed.
    return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
  }

  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    // Same shape as visitAddExpr: rewrite operands, preserve wrap flags.
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    assert(Expr->getType()->isPointerTy() &&
           "Should only reach pointer-typed SCEVUnknown's.");
    // Perform some basic constant folding. If the operand of the cast is a
    // null pointer, don't create a cast SCEV expression (that will be left
    // as-is), but produce a zero constant.
    // NOTE(review): the `if` condition guarding this return was lost in
    // extraction.
      return SE.getZero(TargetTy);
    return CreatePtrCast(Expr);
  }
};
1169
  // NOTE(review): the defining signature line was lost in extraction; this is
  // the body of a ScalarEvolution member taking a pointer-typed `Op` and
  // producing an integer SCEV of pointer width (or SCEVCouldNotCompute).
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
  // NOTE(review): the left-hand side of this width comparison was lost in
  // extraction.
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
      Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
        // Unique the scPtrToInt node over (kind, operand); the
        // FoldingSetNodeID declaration was lost in extraction.
        ID.AddInteger(scPtrToInt);
        ID.AddPointer(U);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
        UniqueSCEVs.InsertNode(S, IP);
        S->computeAndSetCanonical(*this);
        // Record U as used by S so invalidation of U invalidates S.
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1209
  // NOTE(review): the defining signature line was lost in extraction; this is
  // the body of a ScalarEvolution member producing a ptr-to-address SCEV
  // (scPtrToAddr) for a pointer-typed `Op`, presumably using a DataLayout
  // reference `DL` declared on the lost line — confirm against upstream.
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // Treat pointers with unstable representation conservatively, since the
  // address bits may change.
  if (DL.hasUnstableRepresentation(Op->getType()))
    return getCouldNotCompute();

  Type *Ty = DL.getAddressType(Op->getType());

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
  // The rewriter handles null pointer constant folding.
      Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
        // Unique the scPtrToAddr node over (kind, operand, type); the
        // FoldingSetNodeID declaration was lost in extraction.
        ID.AddInteger(scPtrToAddr);
        ID.AddPointer(U);
        ID.AddPointer(Ty);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
        UniqueSCEVs.InsertNode(S, IP);
        S->computeAndSetCanonical(*this);
        // Record U as used by S so invalidation of U invalidates S.
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1243
  // NOTE(review): the defining signature line was lost in extraction; takes a
  // pointer-typed `Op` and an integer destination type `Ty`.
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  // First sink the cast losslessly down to the SCEVUnknown leaves; bail out
  // (SCEVCouldNotCompute) if that is not possible.
  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  // Then adjust the lossless result to the requested width.
  return getTruncateOrZeroExtend(IntOp, Ty);
}
1253
                                   unsigned Depth) {
  // NOTE(review): the leading signature line (taking `Op` and `Ty`) was lost
  // in extraction; body reads as ScalarEvolution's truncate-expression
  // builder, folding where possible and uniquing an SCEVTruncateExpr node
  // otherwise.
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Unique over (scTruncate, Op, Ty); the FoldingSetNodeID declaration was
  // lost in extraction.
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  // NOTE(review): the dyn_cast guard for this return was lost in extraction.
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  // NOTE(review): guard line lost in extraction.
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  // NOTE(review): guard line lost in extraction.
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  // Recursion limit reached: give up folding and emit the explicit node.
  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
  // NOTE(review): the `if` guarding this add/mul-distribution block was lost
  // in extraction.
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<SCEVUse, 4> Operands;
    unsigned numTruncs = 0;
    // Stop early once a second genuine truncate appears (numTruncs < 2).
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      // NOTE(review): the trailing condition of this `if` was lost in
      // extraction; only truncates that do not replace other casts count.
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and different modification ID was inserted
    // into the cache. So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<SCEVUse, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  registerUser(S, Op);
  return S;
}
1349
// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
// Sets *Pred to the comparison the caller should use against the limit.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    // NOTE(review): the `return SE->getConstant(...)` lead line was lost in
    // extraction; this is its continuation using the step's max value.
           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    // NOTE(review): return lead line lost in extraction (uses the step's
    // min value).
           SE->getSignedRangeMin(Step));
  }
  // Step's sign is unknown: no usable limit.
  return nullptr;
}
1369
// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
// NOTE(review): the leading signature line (function name and Step parameter)
// was lost in extraction.
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  // Unsigned case never depends on the step's sign.
  *Pred = ICmpInst::ICMP_ULT;

  // NOTE(review): the `return SE->getConstant(...)` lead line was lost in
  // extraction; this is its continuation using the step's unsigned max.
           SE->getUnsignedRangeMax(Step));
}
1382
namespace {

struct ExtendOpTraitsBase {
  // Pointer-to-member type matching ScalarEvolution's extension builders
  // (getZeroExtendExpr / getSignExtendExpr): (const SCEV*, Type*, Depth).
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

// Signed specialization: NSW flag, signed overflow limit.
template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

// Out-of-line definition of the sign-extend member pointer.
// NOTE(review): the initializer's continuation line was lost in extraction.
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<

// Unsigned specialization: NUW flag, unsigned overflow limit.
template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

// Out-of-line definition of the zero-extend member pointer.
// NOTE(review): the initializer's continuation line was lost in extraction.
const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<

} // end anonymous namespace
1436
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
// expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)"
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list. Note, that
  // SA might have repeated ops, like %a + %a + ..., so only remove one.
  SmallVector<SCEVUse, 4> DiffOps(SA->operands());
  for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
    if (*It == Step) {
      DiffOps.erase(It);
      break;
    }

  // Step did not appear in Start's operands: no pre-start to find.
  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  // NOTE(review): the continuation line initializing PreStartFlags was lost
  // in extraction.
  auto PreStartFlags =
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  // NOTE(review): the lead line of the PreAR dyn_cast was lost in extraction.
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  // Compare extend(PreStart) + extend(Step) against extend(Start) at double
  // width: equality proves PreStart + Step does not wrap.
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  // NOTE(review): the `ICmpInst::Predicate Pred` declaration line was lost in
  // extraction.
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//   {S,+,X} == {S-T,+,X} + T
//   => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//   If ({S-T,+,X} + T) does not overflow ... (1)
//
//   RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//   If {S-T,+,X} does not overflow ... (2)
//
//   RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//       == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//   If (S-T)+T does not overflow ... (3)
//
//   RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//       == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration. Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here. It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  // Probe a few nearby starts; unsigned Delta deliberately relies on modular
  // APInt subtraction for the negative values.
  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    // Look up the sibling {PreStart,+,Step}<L> directly in the unique table.
    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      // NOTE(review): the `ICmpInst::Predicate Pred` declaration line was
      // lost in extraction.
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
        return true;
    }
  }

  return false;
}
1610
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
// NOTE(review): the leading signature line (return type, name, and the `SE`
// parameter referenced below) was lost in extraction.
                                            const SCEVConstant *ConstantTerm,
                                            const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  // No trailing zeros in the non-constant part: nothing can be split off.
  return APInt(BitWidth, 0);
}
1632
// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
// NOTE(review): the leading signature line (return type, name, and the `SE`
// parameter referenced below) was lost in extraction.
                                            const APInt &ConstantStart,
                                            const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  // D may keep as many low bits of C as the step has trailing zeros.
  const uint32_t TZ = SE.getMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  return APInt(BitWidth, 0);
}
1647
    const ScalarEvolution::FoldID &ID, const SCEV *S,
    // NOTE(review): the function name line and the FoldCache/FoldCacheUser
    // parameter-type lines were lost in extraction; this inserts (ID -> S)
    // while keeping the reverse FoldCacheUser map consistent.
                          &FoldCacheUser) {
  auto I = FoldCache.insert({ID, S});
  if (!I.second) {
    // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
    // entry.
    auto &UserIDs = FoldCacheUser[I.first->second];
    assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
    // Swap-and-pop the single matching ID (order is not significant).
    for (unsigned I = 0; I != UserIDs.size(); ++I)
      if (UserIDs[I] == ID) {
        std::swap(UserIDs[I], UserIDs.back());
        break;
      }
    UserIDs.pop_back();
    I.first->second = S;
  }
  // Record the forward edge so S's invalidation can purge ID later.
  FoldCacheUser[S].push_back(ID);
}
1669
const SCEV *
// NOTE(review): the signature line (ScalarEvolution::getZeroExtendExpr taking
// Op, Ty, Depth) was lost in extraction. This is the caching wrapper around
// getZeroExtendExprImpl.
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Consult the fold cache, keyed on (scZeroExtend, Op, Ty).
  FoldID ID(scZeroExtend, Op, Ty);
  if (const SCEV *S = FoldCache.lookup(ID))
    return S;

  const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
  // NOTE(review): the condition guarding this cache insertion was lost in
  // extraction.
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
  return S;
}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the later case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two issue possible causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR,
1876 Depth + 1);
1877 }
1878 }
1879
1880 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1881 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1882 Start =
1884 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1885 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1886 }
1887 }
1888
1889 // zext(A % B) --> zext(A) % zext(B)
1890 {
1891 const SCEV *LHS;
1892 const SCEV *RHS;
1893 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1894 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1895 getZeroExtendExpr(RHS, Ty, Depth + 1));
1896 }
1897
1898 // zext(A / B) --> zext(A) / zext(B).
1899 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1900 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1901 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1902
1903 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1904 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1905 if (SA->hasNoUnsignedWrap()) {
1906 // If the addition does not unsign overflow then we can, by definition,
1907 // commute the zero extension with the addition operation.
1909 for (SCEVUse Op : SA->operands())
1910 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1911 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1912 }
1913
1914 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1915 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1916 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1917 //
1918 // Often address arithmetics contain expressions like
1919 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1920 // This transformation is useful while proving that such expressions are
1921 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1922 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1923 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1924 if (D != 0) {
1925 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1926 const SCEV *SResidual =
1928 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1929 return getAddExpr(SZExtD, SZExtR,
1931 Depth + 1);
1932 }
1933 }
1934 }
1935
1936 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1937 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1938 if (SM->hasNoUnsignedWrap()) {
1939 // If the multiply does not unsign overflow then we can, by definition,
1940 // commute the zero extension with the multiply operation.
1942 for (SCEVUse Op : SM->operands())
1943 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1944 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1945 }
1946
1947 // zext(2^K * (trunc X to iN)) to iM ->
1948 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1949 //
1950 // Proof:
1951 //
1952 // zext(2^K * (trunc X to iN)) to iM
1953 // = zext((trunc X to iN) << K) to iM
1954 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1955 // (because shl removes the top K bits)
1956 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1957 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1958 //
1959 const APInt *C;
1960 const SCEV *TruncRHS;
1961 if (match(SM,
1962 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1963 C->isPowerOf2()) {
1964 int NewTruncBits =
1965 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1966 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1967 return getMulExpr(
1968 getZeroExtendExpr(SM->getOperand(0), Ty),
1969 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1970 SCEV::FlagNUW, Depth + 1);
1971 }
1972 }
1973
1974 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1975 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1978 SmallVector<SCEVUse, 4> Operands;
1979 for (SCEVUse Operand : MinMax->operands())
1980 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1982 return getUMinExpr(Operands);
1983 return getUMaxExpr(Operands);
1984 }
1985
1986 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1988 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1989 SmallVector<SCEVUse, 4> Operands;
1990 for (SCEVUse Operand : MinMax->operands())
1991 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1992 return getUMinExpr(Operands, /*Sequential*/ true);
1993 }
1994
1995 // The cast wasn't folded; create an explicit cast node.
1996 // Recompute the insert position, as it may have been invalidated.
1997 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1998 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1999 Op, Ty);
2000 UniqueSCEVs.InsertNode(S, IP);
2001 S->computeAndSetCanonical(*this);
2002 registerUser(S, Op);
2003 return S;
2004}
2005
/// Public entry point for sign-extending \p Op to \p Ty. Consults the
/// dedicated fold cache first and delegates the actual folding work to
/// getSignExtendExprImpl, memoizing whatever that returns.
/// NOTE(review): the signature continuation (source line 2007) is elided in
/// this view; parameters are presumably (const SCEV *Op, Type *Ty,
/// unsigned Depth) -- confirm against the header.
2006const SCEV *
2008 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2009 "This is not an extending conversion!");
2010 assert(isSCEVable(Ty) &&
2011 "This is not a conversion to a SCEVable type!");
2012 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2013 Ty = getEffectiveSCEVType(Ty);
2014
// Fast path: reuse a previously computed fold for this (Op, Ty) pair.
2015 FoldID ID(scSignExtend, Op, Ty);
2016 if (const SCEV *S = FoldCache.lookup(ID))
2017 return S;
2018
2019 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2021 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2022 return S;
2023}
2024
/// Workhorse behind getSignExtendExpr. Applies a cascade of rewrites before
/// giving up and building an explicit SCEVSignExtendExpr node: constant
/// folding, sext(sext(x)) and sext(zext(x)) collapsing, a recursion-depth
/// cutoff (MaxCastDepth), sext-of-trunc simplification via constant ranges,
/// distribution over <nsw> adds, and several overflow proofs that let the
/// extension be pushed into an affine add-recurrence's operands.
/// NOTE(review): several continuation lines are elided in this view (e.g.
/// the dyn_cast guards at 2038/2042/2113 and some assignment RHSs); the
/// surviving lines are kept byte-identical.
2026 unsigned Depth) {
2027 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2028 "This is not an extending conversion!");
2029 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2030 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2031 Ty = getEffectiveSCEVType(Ty);
2032
2033 // Fold if the operand is constant.
2034 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2035 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2036
2037 // sext(sext(x)) --> sext(x)
2039 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2040
2041 // sext(zext(x)) --> zext(x)
2043 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2044
2045 // Before doing any expensive analysis, check to see if we've already
2046 // computed a SCEV for this Op and Ty.
2048 ID.AddInteger(scSignExtend);
2049 ID.AddPointer(Op);
2050 ID.AddPointer(Ty);
2051 void *IP = nullptr;
2052 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2053 // Limit recursion depth.
2054 if (Depth > MaxCastDepth) {
// Too deep: give up on folding and intern a plain sext node.
2055 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2056 Op, Ty);
2057 UniqueSCEVs.InsertNode(S, IP);
2058 S->computeAndSetCanonical(*this);
2059 registerUser(S, Op);
2060 return S;
2061 }
2062
2063 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2065 // It's possible the bits taken off by the truncate were all sign bits. If
2066 // so, we should be able to simplify this further.
2067 const SCEV *X = ST->getOperand();
2069 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2070 unsigned NewBits = getTypeSizeInBits(Ty);
// If trunc-then-sext of X's range stays inside X's own sext'd range, the
// truncation lost only sign bits and the round trip is a no-op.
2071 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2072 CR.sextOrTrunc(NewBits)))
2073 return getTruncateOrSignExtend(X, Ty, Depth);
2074 }
2075
2076 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2077 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2078 if (SA->hasNoSignedWrap()) {
2079 // If the addition does not sign overflow then we can, by definition,
2080 // commute the sign extension with the addition operation.
2082 for (SCEVUse Op : SA->operands())
2083 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2084 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2085 }
2086
2087 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2088 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2089 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2090 //
2091 // For instance, this will bring two seemingly different expressions:
2092 // 1 + sext(5 + 20 * %x + 24 * %y) and
2093 // sext(6 + 20 * %x + 24 * %y)
2094 // to the same form:
2095 // 2 + sext(4 + 20 * %x + 24 * %y)
2096 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2097 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2098 if (D != 0) {
2099 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2100 const SCEV *SResidual =
2102 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2103 return getAddExpr(SSExtD, SSExtR,
2105 Depth + 1);
2106 }
2107 }
2108 }
2109 // If the input value is a chrec scev, and we can prove that the value
2110 // did not overflow the old, smaller, value, we can sign extend all of the
2111 // operands (often constants). This allows analysis of something like
2112 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2114 if (AR->isAffine()) {
2115 const SCEV *Start = AR->getStart();
2116 const SCEV *Step = AR->getStepRecurrence(*this);
2117 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2118 const Loop *L = AR->getLoop();
2119
2120 // If we have special knowledge that this addrec won't overflow,
2121 // we don't need to do any further analysis.
2122 if (AR->hasNoSignedWrap()) {
2123 Start =
2125 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2126 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2127 }
2128
2129 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2130 // Note that this serves two purposes: It filters out loops that are
2131 // simply not analyzable, and it covers the case where this code is
2132 // being called from within backedge-taken count analysis, such that
2133 // attempting to ask for the backedge-taken count would likely result
2134 // in infinite recursion. In the later case, the analysis code will
2135 // cope with a conservative value, and it will take care to purge
2136 // that value once it has finished.
2137 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2138 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2139 // Manually compute the final value for AR, checking for
2140 // overflow.
2141
2142 // Check whether the backedge-taken count can be losslessly casted to
2143 // the addrec's type. The count is always unsigned.
2144 const SCEV *CastedMaxBECount =
2145 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2146 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2147 CastedMaxBECount, MaxBECount->getType(), Depth);
2148 if (MaxBECount == RecastedMaxBECount) {
// Evaluate the final addrec value in a type twice as wide; if it
// matches the operand-wise extension, no overflow occurred.
2149 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2150 // Check whether Start+Step*MaxBECount has no signed overflow.
2151 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2153 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2155 Depth + 1),
2156 WideTy, Depth + 1);
2157 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2158 const SCEV *WideMaxBECount =
2159 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2160 const SCEV *OperandExtendedAdd =
2161 getAddExpr(WideStart,
2162 getMulExpr(WideMaxBECount,
2163 getSignExtendExpr(Step, WideTy, Depth + 1),
2166 if (SAdd == OperandExtendedAdd) {
2167 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2168 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2169 // Return the expression with the addrec on the outside.
2170 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2171 Depth + 1);
2172 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2173 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2174 }
2175 // Similar to above, only this time treat the step value as unsigned.
2176 // This covers loops that count up with an unsigned step.
2177 OperandExtendedAdd =
2178 getAddExpr(WideStart,
2179 getMulExpr(WideMaxBECount,
2180 getZeroExtendExpr(Step, WideTy, Depth + 1),
2183 if (SAdd == OperandExtendedAdd) {
2184 // If AR wraps around then
2185 //
2186 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2187 // => SAdd != OperandExtendedAdd
2188 //
2189 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2190 // (SAdd == OperandExtendedAdd => AR is NW)
2191
2192 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2193
2194 // Return the expression with the addrec on the outside.
2195 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2196 Depth + 1);
2197 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2198 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2199 }
2200 }
2201 }
2202
// Even without a computable trip count, induction reasoning may prove nsw.
2203 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2204 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2205 if (AR->hasNoSignedWrap()) {
2206 // Same as nsw case above - duplicated here to avoid a compile time
2207 // issue. It's not clear that the order of checks does matter, but
2208 // it's one of two issue possible causes for a change which was
2209 // reverted. Be conservative for the moment.
2210 Start =
2212 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2213 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2214 }
2215
2216 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2217 // if D + (C - D + Step * n) could be proven to not signed wrap
2218 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2219 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2220 const APInt &C = SC->getAPInt();
2221 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2222 if (D != 0) {
2223 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2224 const SCEV *SResidual =
2225 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2226 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2227 return getAddExpr(SSExtD, SSExtR,
2229 Depth + 1);
2230 }
2231 }
2232
2233 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2234 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2235 Start =
2237 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2238 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2239 }
2240 }
2241
2242 // If the input value is provably positive and we could not simplify
2243 // away the sext build a zext instead.
2245 return getZeroExtendExpr(Op, Ty, Depth + 1);
2246
2247 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2248 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2251 SmallVector<SCEVUse, 4> Operands;
2252 for (SCEVUse Operand : MinMax->operands())
2253 Operands.push_back(getSignExtendExpr(Operand, Ty));
2255 return getSMinExpr(Operands);
2256 return getSMaxExpr(Operands);
2257 }
2258
2259 // The cast wasn't folded; create an explicit cast node.
2260 // Recompute the insert position, as it may have been invalidated.
2261 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2262 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2263 Op, Ty);
2264 UniqueSCEVs.InsertNode(S, IP);
2265 S->computeAndSetCanonical(*this);
2266 registerUser(S, Op);
2267 return S;
2268}
2269
/// Build a cast of \p Op to \p Ty, dispatching on the SCEV cast kind.
/// Supports exactly scTruncate, scZeroExtend, scSignExtend, and scPtrToInt;
/// any other kind is a programming error (llvm_unreachable).
2271 Type *Ty) {
2272 switch (Kind) {
2273 case scTruncate:
2274 return getTruncateExpr(Op, Ty);
2275 case scZeroExtend:
2276 return getZeroExtendExpr(Op, Ty);
2277 case scSignExtend:
2278 return getSignExtendExpr(Op, Ty);
2279 case scPtrToInt:
2280 return getPtrToIntExpr(Op, Ty);
2281 default:
2282 llvm_unreachable("Not a SCEV cast expression!");
2283 }
2284}
2285
2286/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2287/// unspecified bits out to the given type.
/// Strategy: sign-extend negative constants; peel truncates; prefer whichever
/// of zext/sext actually folds; force the extension into addrec operands;
/// for obviously-signed expressions (smax) prefer sext; otherwise default to
/// the zext form.
2289 Type *Ty) {
2290 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2291 "This is not an extending conversion!");
2292 assert(isSCEVable(Ty) &&
2293 "This is not a conversion to a SCEVable type!");
2294 Ty = getEffectiveSCEVType(Ty);
2295
2296 // Sign-extend negative constants.
2297 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2298 if (SC->getAPInt().isNegative())
2299 return getSignExtendExpr(Op, Ty);
2300
2301 // Peel off a truncate cast.
2303 const SCEV *NewOp = T->getOperand();
2304 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2305 return getAnyExtendExpr(NewOp, Ty);
2306 return getTruncateOrNoop(NewOp, Ty);
2307 }
2308
2309 // Next try a zext cast. If the cast is folded, use it.
2310 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2311 if (!isa<SCEVZeroExtendExpr>(ZExt))
2312 return ZExt;
2313
2314 // Next try a sext cast. If the cast is folded, use it.
2315 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2316 if (!isa<SCEVSignExtendExpr>(SExt))
2317 return SExt;
2318
2319 // Force the cast to be folded into the operands of an addrec.
2320 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2322 for (const SCEV *Op : AR->operands())
2323 Ops.push_back(getAnyExtendExpr(Op, Ty));
// "Any extend" loses signedness information, so only no-self-wrap (NW)
// can be claimed for the rebuilt recurrence.
2324 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2325 }
2326
2327 // If the expression is obviously signed, use the sext cast value.
2328 if (isa<SCEVSMaxExpr>(Op))
2329 return SExt;
2330
2331 // Absent any other information, use the zext cast value.
2332 return ZExt;
2333}
2334
2335/// Process the given Ops list, which is a list of operands to be added under
2336/// the given scale, update the given map. This is a helper function for
2337/// getAddRecExpr. As an example of what it does, given a sequence of operands
2338/// that would form an add expression like this:
2339///
2340/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2341///
2342/// where A and B are constants, update the map with these values:
2343///
2344/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2345///
2346/// and add 13 + A*B*29 to AccumulatedConstant.
2347/// This will allow getAddRecExpr to produce this:
2348///
2349/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2350///
2351/// This form often exposes folding opportunities that are hidden in
2352/// the original operand list.
2353///
2354/// Return true iff it appears that any interesting folding opportunities
2355/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2356/// the common case where no interesting opportunities are present, and
2357/// is also used as a check to avoid infinite recursion.
2360 APInt &AccumulatedConstant,
2362 const APInt &Scale,
2363 ScalarEvolution &SE) {
2364 bool Interesting = false;
2365
2366 // Iterate over the add operands. They are sorted, with constants first.
2367 unsigned i = 0;
2368 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2369 ++i;
2370 // Pull a buried constant out to the outside.
2371 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2372 Interesting = true;
// APInt arithmetic wraps modulo 2^bitwidth, matching SCEV semantics.
2373 AccumulatedConstant += Scale * C->getAPInt();
2374 }
2375
2376 // Next comes everything else. We're especially interested in multiplies
2377 // here, but they're in the middle, so just visit the rest with one loop.
2378 for (; i != Ops.size(); ++i) {
2380 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2381 APInt NewScale =
2382 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2383 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2384 // A multiplication of a constant with another add; recurse.
2385 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2386 Interesting |= CollectAddOperandsWithScales(
2387 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2388 } else {
2389 // A multiplication of a constant with some other value. Update
2390 // the map.
2391 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2392 const SCEV *Key = SE.getMulExpr(MulOps);
2393 auto Pair = M.insert({Key, NewScale});
2394 if (Pair.second) {
2395 NewOps.push_back(Pair.first->first);
2396 } else {
2397 Pair.first->second += NewScale;
2398 // The map already had an entry for this value, which may indicate
2399 // a folding opportunity.
2400 Interesting = true;
2401 }
2402 }
2403 } else {
2404 // An ordinary operand. Update the map.
2405 auto Pair = M.insert({Ops[i], Scale});
2406 if (Pair.second) {
2407 NewOps.push_back(Pair.first->first);
2408 } else {
2409 Pair.first->second += Scale;
2410 // The map already had an entry for this value, which may indicate
2411 // a folding opportunity.
2412 Interesting = true;
2413 }
2414 }
2415 }
2416
2417 return Interesting;
2418}
2419
/// Decide whether (LHS BinOp RHS) provably cannot overflow, in the signed or
/// unsigned sense selected by `Signed` (parameter declaration elided in this
/// view -- confirm against the header). Two strategies:
/// 1) Compute the operation in a type of twice the bit width and compare
///    ext(LHS op RHS) against ext(LHS) op ext(RHS); equality proves no wrap.
/// 2) Given a context instruction and a constant RHS, bound LHS against
///    MIN + |C| (potential underflow) or MAX - |C| (potential overflow)
///    via isKnownPredicateAt. Mul is not supported on this path (TODO).
2421 const SCEV *LHS, const SCEV *RHS,
2422 const Instruction *CtxI) {
2424 unsigned);
2425 switch (BinOp) {
2426 default:
2427 llvm_unreachable("Unsupported binary op");
2428 case Instruction::Add:
2430 break;
2431 case Instruction::Sub:
2433 break;
2434 case Instruction::Mul:
2436 break;
2437 }
2438
2439 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2442
2443 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2444 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2445 auto *WideTy =
2446 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2447
2448 const SCEV *A = (this->*Extension)(
2449 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2450 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2451 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2452 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
// SCEVs are uniqued, so pointer equality is semantic equality here.
2453 if (A == B)
2454 return true;
2455 // Can we use context to prove the fact we need?
2456 if (!CtxI)
2457 return false;
2458 // TODO: Support mul.
2459 if (BinOp == Instruction::Mul)
2460 return false;
2461 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2462 // TODO: Lift this limitation.
2463 if (!RHSC)
2464 return false;
2465 APInt C = RHSC->getAPInt();
2466 unsigned NumBits = C.getBitWidth();
2467 bool IsSub = (BinOp == Instruction::Sub);
2468 bool IsNegativeConst = (Signed && C.isNegative());
2469 // Compute the direction and magnitude by which we need to check overflow.
2470 bool OverflowDown = IsSub ^ IsNegativeConst;
2471 APInt Magnitude = C;
2472 if (IsNegativeConst) {
2473 if (C == APInt::getSignedMinValue(NumBits))
2474 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2475 // want to deal with that.
2476 return false;
2477 Magnitude = -C;
2478 }
2479
2481 if (OverflowDown) {
2482 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2483 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2484 : APInt::getMinValue(NumBits);
2485 APInt Limit = Min + Magnitude;
2486 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2487 } else {
2488 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2489 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2490 : APInt::getMaxValue(NumBits);
2491 APInt Limit = Max - Magnitude;
2492 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2493 }
2494}
2495
/// Try to deduce nuw/nsw flags for \p OBO beyond what the IR already states,
/// using SCEV reasoning (willNotOverflow) on its operands. Returns
/// std::nullopt when both flags are already present (nothing to improve),
/// when the opcode is not Add/Sub/Mul, or when nothing new could be deduced;
/// otherwise returns the strengthened flag set.
2496std::optional<SCEV::NoWrapFlags>
2498 const OverflowingBinaryOperator *OBO) {
2499 // It cannot be done any better.
2500 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2501 return std::nullopt;
2502
2504
// Seed the result with the flags the instruction already carries.
2505 if (OBO->hasNoUnsignedWrap())
2507 if (OBO->hasNoSignedWrap())
2509
2510 bool Deduced = false;
2511
2512 if (OBO->getOpcode() != Instruction::Add &&
2513 OBO->getOpcode() != Instruction::Sub &&
2514 OBO->getOpcode() != Instruction::Mul)
2515 return std::nullopt;
2516
2517 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2518 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2519
2520 const Instruction *CtxI =
2522 if (!OBO->hasNoUnsignedWrap() &&
2524 /* Signed */ false, LHS, RHS, CtxI)) {
2526 Deduced = true;
2527 }
2528
2529 if (!OBO->hasNoSignedWrap() &&
2531 /* Signed */ true, LHS, RHS, CtxI)) {
2533 Deduced = true;
2534 }
2535
2536 if (Deduced)
2537 return Flags;
2538 return std::nullopt;
2539}
2540
2541// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2542// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2543// can't-overflow flags for the operation if possible.
// Inference steps visible below: (1) NSW + all-non-negative operands implies
// NUW as well; (2) for binary add/mul with a constant LHS, use
// makeGuaranteedNoWrapRegion-style range containment on the other operand;
// (3) <0,+,nonnegative><nw> addrecs are also nuw; (4) (udiv X, Y) * Y is
// always nuw in either operand order.
2547 SCEV::NoWrapFlags Flags) {
2548 using namespace std::placeholders;
2549
2550 using OBO = OverflowingBinaryOperator;
2551
2552 bool CanAnalyze =
2554 (void)CanAnalyze;
2555 assert(CanAnalyze && "don't call from other places!");
2556
2557 int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2558 SCEV::NoWrapFlags SignOrUnsignWrap =
2559 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2560
2561 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2562 auto IsKnownNonNegative = [&](SCEVUse U) {
2563 return SE->isKnownNonNegative(U);
2564 };
2565
2566 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2567 Flags =
2568 ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2569
// Re-derive the masked view after the potential strengthening above.
2570 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2571
2572 if (SignOrUnsignWrap != SignOrUnsignMask &&
2573 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2574 isa<SCEVConstant>(Ops[0])) {
2575
2576 auto Opcode = [&] {
2577 switch (Type) {
2578 case scAddExpr:
2579 return Instruction::Add;
2580 case scMulExpr:
2581 return Instruction::Mul;
2582 default:
2583 llvm_unreachable("Unexpected SCEV op.");
2584 }
2585 }();
2586
2587 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2588
2589 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2590 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2592 Opcode, C, OBO::NoSignedWrap);
2593 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2595 }
2596
2597 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2598 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2600 Opcode, C, OBO::NoUnsignedWrap);
2601 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2603 }
2604 }
2605
2606 // <0,+,nonnegative><nw> is also nuw
2607 // TODO: Add corresponding nsw case
2609 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2610 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2612
2613 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2615 Ops.size() == 2) {
2616 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2617 if (UDiv->getOperand(1) == Ops[1])
2619 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2620 if (UDiv->getOperand(1) == Ops[0])
2622 }
2623
2624 return Flags;
2625}
2626
// A SCEV is available at loop entry iff it is invariant in the loop and its
// definition properly dominates the loop header.
2628 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2629}
2630
2631/// Get a canonical add expression, or something simpler if possible.
2633 SCEV::NoWrapFlags OrigFlags,
2634 unsigned Depth) {
2635 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2636 "only nuw or nsw allowed");
2637 assert(!Ops.empty() && "Cannot get empty add!");
2638 if (Ops.size() == 1) return Ops[0];
2639#ifndef NDEBUG
2640 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2641 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2642 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2643 "SCEVAddExpr operand types don't match!");
2644 unsigned NumPtrs = count_if(
2645 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2646 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2647#endif
2648
2649 const SCEV *Folded = constantFoldAndGroupOps(
2650 *this, LI, DT, Ops,
2651 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2652 [](const APInt &C) { return C.isZero(); }, // identity
2653 [](const APInt &C) { return false; }); // absorber
2654 if (Folded)
2655 return Folded;
2656
2657 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2658
2659 // Delay expensive flag strengthening until necessary.
2660 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2661 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2662 };
2663
2664 // Limit recursion calls depth.
2666 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2667
2668 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2669 // Don't strengthen flags if we have no new information.
2670 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2671 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2672 Add->setNoWrapFlags(ComputeFlags(Ops));
2673 return S;
2674 }
2675
2676 // Okay, check to see if the same value occurs in the operand list more than
2677 // once. If so, merge them together into an multiply expression. Since we
2678 // sorted the list, these values are required to be adjacent.
2679 Type *Ty = Ops[0]->getType();
2680 bool FoundMatch = false;
2681 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2682 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2683 // Scan ahead to count how many equal operands there are.
2684 unsigned Count = 2;
2685 while (i+Count != e && Ops[i+Count] == Ops[i])
2686 ++Count;
2687 // Merge the values into a multiply.
2688 SCEVUse Scale = getConstant(Ty, Count);
2689 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2690 if (Ops.size() == Count)
2691 return Mul;
2692 Ops[i] = Mul;
2693 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2694 --i; e -= Count - 1;
2695 FoundMatch = true;
2696 }
2697 if (FoundMatch)
2698 return getAddExpr(Ops, OrigFlags, Depth + 1);
2699
2700 // Check for truncates. If all the operands are truncated from the same
2701 // type, see if factoring out the truncate would permit the result to be
2702 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2703 // if the contents of the resulting outer trunc fold to something simple.
2704 auto FindTruncSrcType = [&]() -> Type * {
2705 // We're ultimately looking to fold an addrec of truncs and muls of only
2706 // constants and truncs, so if we find any other types of SCEV
2707 // as operands of the addrec then we bail and return nullptr here.
2708 // Otherwise, we return the type of the operand of a trunc that we find.
2709 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2710 return T->getOperand()->getType();
2711 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2712 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2713 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2714 return T->getOperand()->getType();
2715 }
2716 return nullptr;
2717 };
2718 if (auto *SrcType = FindTruncSrcType()) {
2719 SmallVector<SCEVUse, 8> LargeOps;
2720 bool Ok = true;
2721 // Check all the operands to see if they can be represented in the
2722 // source type of the truncate.
2723 for (const SCEV *Op : Ops) {
2725 if (T->getOperand()->getType() != SrcType) {
2726 Ok = false;
2727 break;
2728 }
2729 LargeOps.push_back(T->getOperand());
2730 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2731 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2732 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2733 SmallVector<SCEVUse, 8> LargeMulOps;
2734 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2735 if (const SCEVTruncateExpr *T =
2736 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2737 if (T->getOperand()->getType() != SrcType) {
2738 Ok = false;
2739 break;
2740 }
2741 LargeMulOps.push_back(T->getOperand());
2742 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2743 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2744 } else {
2745 Ok = false;
2746 break;
2747 }
2748 }
2749 if (Ok)
2750 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2751 } else {
2752 Ok = false;
2753 break;
2754 }
2755 }
2756 if (Ok) {
2757 // Evaluate the expression in the larger type.
2758 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2759 // If it folds to something simple, use it. Otherwise, don't.
2760 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2761 return getTruncateExpr(Fold, Ty);
2762 }
2763 }
2764
2765 if (Ops.size() == 2) {
2766 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2767 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2768 // C1).
2769 const SCEV *A = Ops[0];
2770 const SCEV *B = Ops[1];
2771 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2772 auto *C = dyn_cast<SCEVConstant>(A);
2773 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2774 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2775 auto C2 = C->getAPInt();
2776 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2777
2778 APInt ConstAdd = C1 + C2;
2779 auto AddFlags = AddExpr->getNoWrapFlags();
2780 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2782 ConstAdd.ule(C1)) {
2783 PreservedFlags =
2785 }
2786
2787 // Adding a constant with the same sign and small magnitude is NSW, if the
2788 // original AddExpr was NSW.
2790 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2791 ConstAdd.abs().ule(C1.abs())) {
2792 PreservedFlags =
2794 }
2795
2796 if (PreservedFlags != SCEV::FlagAnyWrap) {
2797 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2798 NewOps[0] = getConstant(ConstAdd);
2799 return getAddExpr(NewOps, PreservedFlags);
2800 }
2801 }
2802
2803 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2804 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2805 const SCEVAddExpr *InnerAdd;
2806 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2807 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2808 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2809 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2810 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2812 SCEV::FlagNUW)) {
2813 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2814 }
2815 }
2816 }
2817
2818 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2819 const SCEV *Y;
2820 if (Ops.size() == 2 &&
2821 match(Ops[0],
2823 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2824 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2825
2826 // Skip past any other cast SCEVs.
2827 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2828 ++Idx;
2829
2830 // If there are add operands they would be next.
2831 if (Idx < Ops.size()) {
2832 bool DeletedAdd = false;
2833 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2834 // common NUW flag for expression after inlining. Other flags cannot be
2835 // preserved, because they may depend on the original order of operations.
2836 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2837 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2838 if (Ops.size() > AddOpsInlineThreshold ||
2839 Add->getNumOperands() > AddOpsInlineThreshold)
2840 break;
2841 // If we have an add, expand the add operands onto the end of the operands
2842 // list.
2843 Ops.erase(Ops.begin()+Idx);
2844 append_range(Ops, Add->operands());
2845 DeletedAdd = true;
2846 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2847 }
2848
2849 // If we deleted at least one add, we added operands to the end of the list,
2850 // and they are not necessarily sorted. Recurse to resort and resimplify
2851 // any operands we just acquired.
2852 if (DeletedAdd)
2853 return getAddExpr(Ops, CommonFlags, Depth + 1);
2854 }
2855
2856 // Skip over the add expression until we get to a multiply.
2857 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2858 ++Idx;
2859
2860 // Check to see if there are any folding opportunities present with
2861 // operands multiplied by constant values.
2862 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2866 APInt AccumulatedConstant(BitWidth, 0);
2867 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2868 Ops, APInt(BitWidth, 1), *this)) {
2869 struct APIntCompare {
2870 bool operator()(const APInt &LHS, const APInt &RHS) const {
2871 return LHS.ult(RHS);
2872 }
2873 };
2874
2875 // Some interesting folding opportunity is present, so its worthwhile to
2876 // re-generate the operands list. Group the operands by constant scale,
2877 // to avoid multiplying by the same constant scale multiple times.
2878 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2879 for (const SCEV *NewOp : NewOps)
2880 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2881 // Re-generate the operands list.
2882 Ops.clear();
2883 if (AccumulatedConstant != 0)
2884 Ops.push_back(getConstant(AccumulatedConstant));
2885 for (auto &MulOp : MulOpLists) {
2886 if (MulOp.first == 1) {
2887 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2888 } else if (MulOp.first != 0) {
2889 Ops.push_back(getMulExpr(
2890 getConstant(MulOp.first),
2891 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2892 SCEV::FlagAnyWrap, Depth + 1));
2893 }
2894 }
2895 if (Ops.empty())
2896 return getZero(Ty);
2897 if (Ops.size() == 1)
2898 return Ops[0];
2899 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2900 }
2901 }
2902
2903 // If we are adding something to a multiply expression, make sure the
2904 // something is not already an operand of the multiply. If so, merge it into
2905 // the multiply.
2906 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2907 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2908 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2909 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2910 if (isa<SCEVConstant>(MulOpSCEV))
2911 continue;
2912 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2913 if (MulOpSCEV == Ops[AddOp]) {
2914 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2915 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2916 if (Mul->getNumOperands() != 2) {
2917 // If the multiply has more than two operands, we must get the
2918 // Y*Z term.
2919 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2920 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2921 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2922 }
2923 const SCEV *AddOne =
2924 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2925 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2927 if (Ops.size() == 2) return OuterMul;
2928 if (AddOp < Idx) {
2929 Ops.erase(Ops.begin()+AddOp);
2930 Ops.erase(Ops.begin()+Idx-1);
2931 } else {
2932 Ops.erase(Ops.begin()+Idx);
2933 Ops.erase(Ops.begin()+AddOp-1);
2934 }
2935 Ops.push_back(OuterMul);
2936 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2937 }
2938
2939 // Check this multiply against other multiplies being added together.
2940 for (unsigned OtherMulIdx = Idx+1;
2941 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2942 ++OtherMulIdx) {
2943 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2944 // If MulOp occurs in OtherMul, we can fold the two multiplies
2945 // together.
2946 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2947 OMulOp != e; ++OMulOp)
2948 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2949 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2950 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2951 if (Mul->getNumOperands() != 2) {
2952 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2953 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2954 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2955 }
2956 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2957 if (OtherMul->getNumOperands() != 2) {
2959 OtherMul->operands().take_front(OMulOp));
2960 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2961 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2962 }
2963 const SCEV *InnerMulSum =
2964 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2965 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2967 if (Ops.size() == 2) return OuterMul;
2968 Ops.erase(Ops.begin()+Idx);
2969 Ops.erase(Ops.begin()+OtherMulIdx-1);
2970 Ops.push_back(OuterMul);
2971 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2972 }
2973 }
2974 }
2975 }
2976
2977 // If there are any add recurrences in the operands list, see if any other
2978 // added values are loop invariant. If so, we can fold them into the
2979 // recurrence.
2980 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2981 ++Idx;
2982
2983 // Scan over all recurrences, trying to fold loop invariants into them.
2984 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2985 // Scan all of the other operands to this add and add them to the vector if
2986 // they are loop invariant w.r.t. the recurrence.
2988 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2989 const Loop *AddRecLoop = AddRec->getLoop();
2990 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2991 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2992 LIOps.push_back(Ops[i]);
2993 Ops.erase(Ops.begin()+i);
2994 --i; --e;
2995 }
2996
2997 // If we found some loop invariants, fold them into the recurrence.
2998 if (!LIOps.empty()) {
2999 // Compute nowrap flags for the addition of the loop-invariant ops and
3000 // the addrec. Temporarily push it as an operand for that purpose. These
3001 // flags are valid in the scope of the addrec only.
3002 LIOps.push_back(AddRec);
3003 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
3004 LIOps.pop_back();
3005
3006 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3007 LIOps.push_back(AddRec->getStart());
3008
3009 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3010
3011 // It is not in general safe to propagate flags valid on an add within
3012 // the addrec scope to one outside it. We must prove that the inner
3013 // scope is guaranteed to execute if the outer one does to be able to
3014 // safely propagate. We know the program is undefined if poison is
3015 // produced on the inner scoped addrec. We also know that *for this use*
3016 // the outer scoped add can't overflow (because of the flags we just
3017 // computed for the inner scoped add) without the program being undefined.
3018 // Proving that entry to the outer scope neccesitates entry to the inner
3019 // scope, thus proves the program undefined if the flags would be violated
3020 // in the outer scope.
3021 SCEV::NoWrapFlags AddFlags = Flags;
3022 if (AddFlags != SCEV::FlagAnyWrap) {
3023 auto *DefI = getDefiningScopeBound(LIOps);
3024 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3025 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3026 AddFlags = SCEV::FlagAnyWrap;
3027 }
3028 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3029
3030 // Build the new addrec. Propagate the NUW and NSW flags if both the
3031 // outer add and the inner addrec are guaranteed to have no overflow.
3032 // Always propagate NW.
3033 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3034 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3035
3036 // If all of the other operands were loop invariant, we are done.
3037 if (Ops.size() == 1) return NewRec;
3038
3039 // Otherwise, add the folded AddRec by the non-invariant parts.
3040 for (unsigned i = 0;; ++i)
3041 if (Ops[i] == AddRec) {
3042 Ops[i] = NewRec;
3043 break;
3044 }
3045 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3046 }
3047
3048 // Okay, if there weren't any loop invariants to be folded, check to see if
3049 // there are multiple AddRec's with the same loop induction variable being
3050 // added together. If so, we can fold them.
3051 for (unsigned OtherIdx = Idx+1;
3052 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3053 ++OtherIdx) {
3054 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3055 // so that the 1st found AddRecExpr is dominated by all others.
3056 assert(DT.dominates(
3057 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3058 AddRec->getLoop()->getHeader()) &&
3059 "AddRecExprs are not sorted in reverse dominance order?");
3060 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3061 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3062 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3063 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3064 ++OtherIdx) {
3065 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3066 if (OtherAddRec->getLoop() == AddRecLoop) {
3067 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3068 i != e; ++i) {
3069 if (i >= AddRecOps.size()) {
3070 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3071 break;
3072 }
3073 AddRecOps[i] =
3074 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3076 }
3077 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3078 }
3079 }
3080 // Step size has changed, so we cannot guarantee no self-wraparound.
3081 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3082 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3083 }
3084 }
3085
3086 // Otherwise couldn't fold anything into this recurrence. Move onto the
3087 // next one.
3088 }
3089
3090 // Okay, it looks like we really DO need an add expr. Check to see if we
3091 // already have one, otherwise create a new one.
3092 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3093}
3094
3095const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
3096 SCEV::NoWrapFlags Flags) {
3098 ID.AddInteger(scAddExpr);
3099 for (const SCEV *Op : Ops)
3100 ID.AddPointer(Op);
3101 void *IP = nullptr;
3102 SCEVAddExpr *S =
3103 static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3104 if (!S) {
3105 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3107 S = new (SCEVAllocator)
3108 SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
3109 UniqueSCEVs.InsertNode(S, IP);
3110 S->computeAndSetCanonical(*this);
3111 registerUser(S, Ops);
3112 }
3113 S->setNoWrapFlags(Flags);
3114 return S;
3115}
3116
3117const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
3118 const Loop *L,
3119 SCEV::NoWrapFlags Flags) {
3120 FoldingSetNodeID ID;
3121 ID.AddInteger(scAddRecExpr);
3122 for (const SCEV *Op : Ops)
3123 ID.AddPointer(Op);
3124 ID.AddPointer(L);
3125 void *IP = nullptr;
3126 SCEVAddRecExpr *S =
3127 static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3128 if (!S) {
3129 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3131 S = new (SCEVAllocator)
3132 SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3133 UniqueSCEVs.InsertNode(S, IP);
3134 S->computeAndSetCanonical(*this);
3135 LoopUsers[L].push_back(S);
3136 registerUser(S, Ops);
3137 }
3138 setNoWrapFlags(S, Flags);
3139 return S;
3140}
3141
3142const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
3143 SCEV::NoWrapFlags Flags) {
3144 FoldingSetNodeID ID;
3145 ID.AddInteger(scMulExpr);
3146 for (const SCEV *Op : Ops)
3147 ID.AddPointer(Op);
3148 void *IP = nullptr;
3149 SCEVMulExpr *S =
3150 static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3151 if (!S) {
3152 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
3154 S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3155 O, Ops.size());
3156 UniqueSCEVs.InsertNode(S, IP);
3157 S->computeAndSetCanonical(*this);
3158 registerUser(S, Ops);
3159 }
3160 S->setNoWrapFlags(Flags);
3161 return S;
3162}
3163
/// Multiply two unsigned 64-bit values, setting \p Overflow if the product
/// wrapped. \p Overflow is only ever set, never cleared, so callers can
/// accumulate it across a chain of multiplications.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t Product = i * j;
  // Division-based wrap detection: if the product truncated, dividing it
  // back by a non-trivial multiplier cannot recover the original factor.
  if (j > 1 && Product / j != i)
    Overflow = true;
  return Product;
}

/// Compute the result of "n choose k", the binomial coefficient. If an
/// intermediate computation overflows, Overflow will be set and the return will
/// be garbage. Overflow is not cleared on absence of overflow.
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
  // We use the multiplicative formula:
  //   n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
  // At each iteration we take the next term of the numerator and divide by
  // the next term of the denominator; the running division always produces
  // an integral result and reduces (but does not eliminate) the chance of
  // overflow in the intermediate computations.
  if (n == 0 || n == k)
    return 1;
  if (k > n)
    return 0;

  // Exploit the symmetry C(n,k) == C(n,n-k) to minimize the iteration count.
  if (k > n / 2)
    k = n - k;

  uint64_t Result = 1;
  for (uint64_t Step = 1; Step <= k; ++Step) {
    Result = umul_ov(Result, n - (Step - 1), Overflow);
    Result /= Step;
  }
  return Result;
}
3195
3196/// Determine if any of the operands in this SCEV are a constant or if
3197/// any of the add or multiply expressions in this SCEV contain a constant.
3198static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3199 struct FindConstantInAddMulChain {
3200 bool FoundConstant = false;
3201
3202 bool follow(const SCEV *S) {
3203 FoundConstant |= isa<SCEVConstant>(S);
3204 return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3205 }
3206
3207 bool isDone() const {
3208 return FoundConstant;
3209 }
3210 };
3211
3212 FindConstantInAddMulChain F;
3214 ST.visitAll(StartExpr);
3215 return F.FoundConstant;
3216}
3217
3218/// Get a canonical multiply expression, or something simpler if possible.
3220 SCEV::NoWrapFlags OrigFlags,
3221 unsigned Depth) {
3222 assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3223 "only nuw or nsw allowed");
3224 assert(!Ops.empty() && "Cannot get empty mul!");
3225 if (Ops.size() == 1) return Ops[0];
3226#ifndef NDEBUG
3227 Type *ETy = Ops[0]->getType();
3228 assert(!ETy->isPointerTy());
3229 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3230 assert(Ops[i]->getType() == ETy &&
3231 "SCEVMulExpr operand types don't match!");
3232#endif
3233
3234 const SCEV *Folded = constantFoldAndGroupOps(
3235 *this, LI, DT, Ops,
3236 [](const APInt &C1, const APInt &C2) { return C1 * C2; },
3237 [](const APInt &C) { return C.isOne(); }, // identity
3238 [](const APInt &C) { return C.isZero(); }); // absorber
3239 if (Folded)
3240 return Folded;
3241
3242 // Delay expensive flag strengthening until necessary.
3243 auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
3244 return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3245 };
3246
3247 // Limit recursion calls depth.
3249 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3250
3251 if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3252 // Don't strengthen flags if we have no new information.
3253 SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3254 if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3255 Mul->setNoWrapFlags(ComputeFlags(Ops));
3256 return S;
3257 }
3258
3259 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3260 if (Ops.size() == 2) {
3261 // C1*(C2+V) -> C1*C2 + C1*V
3262 // If any of Add's ops are Adds or Muls with a constant, apply this
3263 // transformation as well.
3264 //
3265 // TODO: There are some cases where this transformation is not
3266 // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3267 // this transformation should be narrowed down.
3268 const SCEV *Op0, *Op1;
3269 if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
3271 const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
3272 const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
3273 return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3274 }
3275
3276 if (Ops[0]->isAllOnesValue()) {
3277 // If we have a mul by -1 of an add, try distributing the -1 among the
3278 // add operands.
3279 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3281 bool AnyFolded = false;
3282 for (const SCEV *AddOp : Add->operands()) {
3283 const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
3285 if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3286 NewOps.push_back(Mul);
3287 }
3288 if (AnyFolded)
3289 return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3290 } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3291 // Negation preserves a recurrence's no self-wrap property.
3292 SmallVector<SCEVUse, 4> Operands;
3293 for (const SCEV *AddRecOp : AddRec->operands())
3294 Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
3295 SCEV::FlagAnyWrap, Depth + 1));
3296 // Let M be the minimum representable signed value. AddRec with nsw
3297 // multiplied by -1 can have signed overflow if and only if it takes a
3298 // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
3299 // maximum signed value. In all other cases signed overflow is
3300 // impossible.
3301 auto FlagsMask = SCEV::FlagNW;
3302 if (hasFlags(AddRec->getNoWrapFlags(), SCEV::FlagNSW)) {
3303 auto MinInt =
3304 APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
3305 if (getSignedRangeMin(AddRec) != MinInt)
3306 FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
3307 }
3308 return getAddRecExpr(Operands, AddRec->getLoop(),
3309 AddRec->getNoWrapFlags(FlagsMask));
3310 }
3311 }
3312
3313 // Try to push the constant operand into a ZExt: C * zext (A + B) ->
3314 // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
3315 const SCEVAddExpr *InnerAdd;
3316 if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
3317 const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
3318 if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
3319 getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
3320 hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
3322 SCEV::FlagNUW)) {
3323 auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
3324 return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
3325 };
3326 }
3327
3328 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3329 // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3330 // of C1, fold to (D /u (C2 /u C1)).
3331 const SCEV *D;
3332 APInt C1V = LHSC->getAPInt();
3333 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
3334 // as -1 * 1, as it won't enable additional folds.
3335 if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
3336 C1V = C1V.abs();
3337 const SCEVConstant *C2;
3338 if (C1V.isPowerOf2() &&
3340 C2->getAPInt().isPowerOf2() &&
3341 C1V.logBase2() <= getMinTrailingZeros(D)) {
3342 const SCEV *NewMul = nullptr;
3343 if (C1V.uge(C2->getAPInt())) {
3344 NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3345 } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
3346 assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
3347 NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3348 }
3349 if (NewMul)
3350 return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3351 }
3352 }
3353 }
3354
3355 // Skip over the add expression until we get to a multiply.
3356 unsigned Idx = 0;
3357 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3358 ++Idx;
3359
3360 // If there are mul operands inline them all into this expression.
3361 if (Idx < Ops.size()) {
3362 bool DeletedMul = false;
3363 while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3364 if (Ops.size() > MulOpsInlineThreshold)
3365 break;
3366 // If we have an mul, expand the mul operands onto the end of the
3367 // operands list.
3368 Ops.erase(Ops.begin()+Idx);
3369 append_range(Ops, Mul->operands());
3370 DeletedMul = true;
3371 }
3372
3373 // If we deleted at least one mul, we added operands to the end of the
3374 // list, and they are not necessarily sorted. Recurse to resort and
3375 // resimplify any operands we just acquired.
3376 if (DeletedMul)
3377 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3378 }
3379
3380 // If there are any add recurrences in the operands list, see if any other
3381 // added values are loop invariant. If so, we can fold them into the
3382 // recurrence.
3383 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3384 ++Idx;
3385
3386 // Scan over all recurrences, trying to fold loop invariants into them.
3387 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3388 // Scan all of the other operands to this mul and add them to the vector
3389 // if they are loop invariant w.r.t. the recurrence.
3391 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3392 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3393 if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
3394 LIOps.push_back(Ops[i]);
3395 Ops.erase(Ops.begin()+i);
3396 --i; --e;
3397 }
3398
3399 // If we found some loop invariants, fold them into the recurrence.
3400 if (!LIOps.empty()) {
3401 // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3403 NewOps.reserve(AddRec->getNumOperands());
3404 const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3405
3406 // If both the mul and addrec are nuw, we can preserve nuw.
3407 // If both the mul and addrec are nsw, we can only preserve nsw if either
3408 // a) they are also nuw, or
3409 // b) all multiplications of addrec operands with scale are nsw.
3410 SCEV::NoWrapFlags Flags =
3411 AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));
3412
3413 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3414 NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3415 SCEV::FlagAnyWrap, Depth + 1));
3416
3417 if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
3419 Instruction::Mul, getSignedRange(Scale),
3421 if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
3422 Flags = clearFlags(Flags, SCEV::FlagNSW);
3423 }
3424 }
3425
3426 const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);
3427
3428 // If all of the other operands were loop invariant, we are done.
3429 if (Ops.size() == 1) return NewRec;
3430
3431 // Otherwise, multiply the folded AddRec by the non-invariant parts.
3432 for (unsigned i = 0;; ++i)
3433 if (Ops[i] == AddRec) {
3434 Ops[i] = NewRec;
3435 break;
3436 }
3437 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3438 }
3439
3440 // Okay, if there weren't any loop invariants to be folded, check to see
3441 // if there are multiple AddRec's with the same loop induction variable
3442 // being multiplied together. If so, we can fold them.
3443
3444 // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3445 // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3446 // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3447 // ]]],+,...up to x=2n}.
3448 // Note that the arguments to choose() are always integers with values
3449 // known at compile time, never SCEV objects.
3450 //
3451 // The implementation avoids pointless extra computations when the two
3452 // addrec's are of different length (mathematically, it's equivalent to
3453 // an infinite stream of zeros on the right).
3454 bool OpsModified = false;
3455 for (unsigned OtherIdx = Idx+1;
3456 OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3457 ++OtherIdx) {
3458 const SCEVAddRecExpr *OtherAddRec =
3459 dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3460 if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
3461 continue;
3462
3463 // Limit max number of arguments to avoid creation of unreasonably big
3464 // SCEVAddRecs with very complex operands.
3465 if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3466 MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3467 continue;
3468
3469 bool Overflow = false;
3470 Type *Ty = AddRec->getType();
3471 bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3472 SmallVector<SCEVUse, 7> AddRecOps;
3473 for (int x = 0, xe = AddRec->getNumOperands() +
3474 OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3476 for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3477 uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3478 for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3479 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3480 z < ze && !Overflow; ++z) {
3481 uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3482 uint64_t Coeff;
3483 if (LargerThan64Bits)
3484 Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3485 else
3486 Coeff = Coeff1*Coeff2;
3487 const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3488 const SCEV *Term1 = AddRec->getOperand(y-z);
3489 const SCEV *Term2 = OtherAddRec->getOperand(z);
3490 SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3491 SCEV::FlagAnyWrap, Depth + 1));
3492 }
3493 }
3494 if (SumOps.empty())
3495 SumOps.push_back(getZero(Ty));
3496 AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3497 }
3498 if (!Overflow) {
3499 const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
3501 if (Ops.size() == 2) return NewAddRec;
3502 Ops[Idx] = NewAddRec;
3503 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3504 OpsModified = true;
3505 AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3506 if (!AddRec)
3507 break;
3508 }
3509 }
3510 if (OpsModified)
3511 return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3512
3513 // Otherwise couldn't fold anything into this recurrence. Move onto the
3514 // next one.
3515 }
3516
3517 // Okay, it looks like we really DO need an mul expr. Check to see if we
3518 // already have one, otherwise create a new one.
3519 return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3520}
3521
3522/// Represents an unsigned remainder expression based on unsigned division.
3524 assert(getEffectiveSCEVType(LHS->getType()) ==
3525 getEffectiveSCEVType(RHS->getType()) &&
3526 "SCEVURemExpr operand types don't match!");
3527
3528 // Short-circuit easy cases
3529 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3530 // If constant is one, the result is trivial
3531 if (RHSC->getValue()->isOne())
3532 return getZero(LHS->getType()); // X urem 1 --> 0
3533
3534 // If constant is a power of two, fold into a zext(trunc(LHS)).
3535 if (RHSC->getAPInt().isPowerOf2()) {
3536 Type *FullTy = LHS->getType();
3537 Type *TruncTy =
3538 IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3539 return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3540 }
3541 }
3542
3543 // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3544 const SCEV *UDiv = getUDivExpr(LHS, RHS);
3545 const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3546 return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3547}
3548
3549/// Get a canonical unsigned division expression, or something simpler if
3550/// possible.
// NOTE(review): this listing omits the function signature (original line 3551,
// presumably ScalarEvolution::getUDivExpr(LHS, RHS)) and the FoldingSetNodeID
// declaration (line 3557). The body below: validates operands, consults the
// UniqueSCEVs folding set, then applies algebraic simplifications before
// creating a new SCEVUDivExpr node.
3552 assert(!LHS->getType()->isPointerTy() &&
3553 "SCEVUDivExpr operand can't be pointer!");
3554 assert(LHS->getType() == RHS->getType() &&
3555 "SCEVUDivExpr operand types don't match!");
3556
// Uniquing lookup: if this exact udiv was built before, reuse the node so
// pointer equality of SCEVs remains valid.
3558 ID.AddInteger(scUDivExpr);
3559 ID.AddPointer(LHS);
3560 ID.AddPointer(RHS);
3561 void *IP = nullptr;
3562 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3563 return S;
3564
3565 // 0 udiv Y == 0
3566 if (match(LHS, m_scev_Zero()))
3567 return LHS;
3568
3569 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3570 if (RHSC->getValue()->isOne())
3571 return LHS; // X udiv 1 --> x
3572 // If the denominator is zero, the result of the udiv is undefined. Don't
3573 // try to analyze it, because the resolution chosen here may differ from
3574 // the resolution chosen in other parts of the compiler.
3575 if (!RHSC->getValue()->isZero()) {
3576 // Determine if the division can be folded into the operands of
3577 // its operands.
3578 // TODO: Generalize this to non-constants by using known-bits information.
3579 Type *Ty = LHS->getType();
3580 unsigned LZ = RHSC->getAPInt().countl_zero();
3581 unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3582 // For non-power-of-two values, effectively round the value up to the
3583 // nearest power of two.
3584 if (!RHSC->getAPInt().isPowerOf2())
3585 ++MaxShiftAmt;
// ExtTy is wide enough that zero-extended arithmetic below cannot wrap; the
// folds are legal only when the zext of the whole expression equals the
// expression rebuilt from zext'ed parts.
3586 IntegerType *ExtTy =
3587 IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3588 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3589 if (const SCEVConstant *Step =
3590 dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3591 // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3592 const APInt &StepInt = Step->getAPInt();
3593 const APInt &DivInt = RHSC->getAPInt();
3594 if (!StepInt.urem(DivInt) &&
3595 getZeroExtendExpr(AR, ExtTy) ==
3596 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3597 getZeroExtendExpr(Step, ExtTy),
3598 AR->getLoop(), SCEV::FlagAnyWrap)) {
3599 SmallVector<SCEVUse, 4> Operands;
3600 for (const SCEV *Op : AR->operands())
3601 Operands.push_back(getUDivExpr(Op, RHS));
3602 return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3603 }
3604 /// Get a canonical UDivExpr for a recurrence.
3605 /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3606 const APInt *StartRem;
3607 if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
3608 m_scev_APInt(StartRem))) {
3609 bool NoWrap =
3610 getZeroExtendExpr(AR, ExtTy) ==
3611 getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3612 getZeroExtendExpr(Step, ExtTy), AR->getLoop(),
// NOTE(review): the listing omits original line 3613 here (the trailing
// no-wrap-flag argument closing this getAddRecExpr call).
3614
3615 // With N <= C and both N, C as powers-of-2, the transformation
3616 // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
3617 // if wrapping occurs, as the division results remain equivalent for
3618 // all offsets in [[(X - X%N), X).
3619 bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
3620 StepInt.isPowerOf2() && DivInt.isPowerOf2();
3621 // Only fold if the subtraction can be folded in the start
3622 // expression.
3623 const SCEV *NewStart =
3624 getMinusSCEV(AR->getStart(), getConstant(*StartRem));
3625 if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
3626 !isa<SCEVAddExpr>(NewStart)) {
3627 const SCEV *NewLHS =
3628 getAddRecExpr(NewStart, Step, AR->getLoop(),
3629 NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
3630 if (LHS != NewLHS) {
3631 LHS = NewLHS;
3632
3633 // Reset the ID to include the new LHS, and check if it is
3634 // already cached.
3635 ID.clear();
3636 ID.AddInteger(scUDivExpr);
3637 ID.AddPointer(LHS);
3638 ID.AddPointer(RHS);
3639 IP = nullptr;
3640 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3641 return S;
3642 }
3643 }
3644 }
3645 }
3646 // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3647 if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3648 SmallVector<SCEVUse, 4> Operands;
3649 for (const SCEV *Op : M->operands())
3650 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3651 if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3652 // Find an operand that's safely divisible.
3653 for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3654 const SCEV *Op = M->getOperand(i);
3655 const SCEV *Div = getUDivExpr(Op, RHSC);
// Divisibility is verified by multiplying back: Div * RHSC == Op.
3656 if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3657 Operands = SmallVector<SCEVUse, 4>(M->operands());
3658 Operands[i] = Div;
3659 return getMulExpr(Operands);
3660 }
3661 }
3662 }
3663
3664 // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3665 if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3666 if (auto *DivisorConstant =
3667 dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3668 bool Overflow = false;
3669 APInt NewRHS =
3670 DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
// If B*C overflows, the combined divisor exceeds any representable value,
// so the quotient is zero.
3671 if (Overflow) {
3672 return getConstant(RHSC->getType(), 0, false);
3673 }
3674 return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3675 }
3676 }
3677
3678 // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3679 if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3680 SmallVector<SCEVUse, 4> Operands;
3681 for (const SCEV *Op : A->operands())
3682 Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3683 if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3684 Operands.clear();
3685 for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3686 const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3687 if (isa<SCEVUDivExpr>(Op) ||
3688 getMulExpr(Op, RHS) != A->getOperand(i))
3689 break;
3690 Operands.push_back(Op);
3691 }
// Only fold when every addend divided cleanly.
3692 if (Operands.size() == A->getNumOperands())
3693 return getAddExpr(Operands);
3694 }
3695 }
3696
3697 // Fold if both operands are constant.
3698 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3699 return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3700 }
3701 }
3702
3703 // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
3704 const APInt *NegC, *C;
3705 if (match(LHS,
// NOTE(review): the listing omits original lines 3706-3707 here (the match
// pattern for the (-C + (C smax %x)) form binding NegC and C).
3708 NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
3709 return getZero(LHS->getType());
3710
3711 // TODO: Generalize to handle any common factors.
3712 // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
3713 const SCEV *NewLHS, *NewRHS;
3714 if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
3715 match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
3716 return getUDivExpr(NewLHS, NewRHS);
3717
3718 // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3719 // changes). Make sure we get a new one.
3720 IP = nullptr;
3721 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3722 SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3723 LHS, RHS);
3724 UniqueSCEVs.InsertNode(S, IP);
3725 S->computeAndSetCanonical(*this);
3726 registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
3727 return S;
3728}
3729
3730APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3731 APInt A = C1->getAPInt().abs();
3732 APInt B = C2->getAPInt().abs();
3733 uint32_t ABW = A.getBitWidth();
3734 uint32_t BBW = B.getBitWidth();
3735
3736 if (ABW > BBW)
3737 B = B.zext(ABW);
3738 else if (ABW < BBW)
3739 A = A.zext(BBW);
3740
3741 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3742}
3743
3744/// Get a canonical unsigned division expression, or something simpler if
3745/// possible. There is no representation for an exact udiv in SCEV IR, but we
3746/// can attempt to remove factors from the LHS and RHS. We can't do this when
3747/// it's not exact because the udiv may be clearing bits.
// NOTE(review): this listing omits the function signature (original line 3748)
// and the initialization of `Mul` (line 3753, presumably
// dyn_cast<SCEVMulExpr>(LHS)). Non-NUW multiplies fall through to the plain
// udiv below because factoring is only sound for an exact division.
3749 // TODO: we could try to find factors in all sorts of things, but for now we
3750 // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3751 // end of this file for inspiration.
3752
3754 if (!Mul || !Mul->hasNoUnsignedWrap())
3755 return getUDivExpr(LHS, RHS);
3756
3757 if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3758 // If the mulexpr multiplies by a constant, then that constant must be the
3759 // first element of the mulexpr.
3760 if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3761 if (LHSCst == RHSCst) {
// Identical constant factor on both sides: cancel it entirely.
3762 SmallVector<SCEVUse, 2> Operands(drop_begin(Mul->operands()));
3763 return getMulExpr(Operands);
3764 }
3765
3766 // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3767 // that there's a factor provided by one of the other terms. We need to
3768 // check.
3769 APInt Factor = gcd(LHSCst, RHSCst);
// Factor.isIntN(1) means the gcd is 0 or 1, i.e. nothing worth cancelling.
3770 if (!Factor.isIntN(1)) {
3771 LHSCst =
3772 cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3773 RHSCst =
3774 cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3775 SmallVector<SCEVUse, 2> Operands;
3776 Operands.push_back(LHSCst);
3777 append_range(Operands, Mul->operands().drop_front());
3778 LHS = getMulExpr(Operands);
3779 RHS = RHSCst;
// NOTE(review): the listing omits original line 3780 here (presumably
// re-initializing `Mul` from the rebuilt LHS before the null check below).
3781 if (!Mul)
3782 return getUDivExactExpr(LHS, RHS);
3783 }
3784 }
3785 }
3786
// Cancel a non-constant factor that appears verbatim in the multiply.
3787 for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3788 if (Mul->getOperand(i) == RHS) {
3789 SmallVector<SCEVUse, 2> Operands;
3790 append_range(Operands, Mul->operands().take_front(i));
3791 append_range(Operands, Mul->operands().drop_front(i + 1));
3792 return getMulExpr(Operands);
3793 }
3794 }
3795
3796 return getUDivExpr(LHS, RHS);
3797}
3798
3799/// Get an add recurrence expression for the specified loop. Simplify the
3800/// expression as much as possible.
// NOTE(review): the first signature line (original line 3801, presumably
// ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, ...) is
// missing from this listing. This overload flattens a step that is itself an
// addrec on the same loop into a single N-ary recurrence.
3802 const Loop *L,
3803 SCEV::NoWrapFlags Flags) {
3804 SmallVector<SCEVUse, 4> Operands;
3805 Operands.push_back(Start);
3806 if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3807 if (StepChrec->getLoop() == L) {
// {Start,+,{A,+,B}} over the same loop becomes {Start,+,A,+,B}; only the NW
// flag survives the flattening.
3808 append_range(Operands, StepChrec->operands());
3809 return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3810 }
3811
3812 Operands.push_back(Step);
3813 return getAddRecExpr(Operands, L, Flags);
3814}
3815
3816/// Get an add recurrence expression for the specified loop. Simplify the
3817/// expression as much as possible.
// NOTE(review): the first signature line (original line 3818, the N-ary
// overload taking the operand list) is missing from this listing, as is the
// assert call on line 3830 (availability at loop entry).
3819 const Loop *L,
3820 SCEV::NoWrapFlags Flags) {
3821 if (Operands.size() == 1) return Operands[0];
3822#ifndef NDEBUG
3823 Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3824 for (const SCEV *Op : llvm::drop_begin(Operands)) {
3825 assert(getEffectiveSCEVType(Op->getType()) == ETy &&
3826 "SCEVAddRecExpr operand types don't match!");
3827 assert(!Op->getType()->isPointerTy() && "Step must be integer");
3828 }
3829 for (const SCEV *Op : Operands)
3831 "SCEVAddRecExpr operand is not available at loop entry!");
3832#endif
3833
// Drop a trailing zero step: {X,+,0} is just X.
3834 if (Operands.back()->isZero()) {
3835 Operands.pop_back();
3836 return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3837 }
3838
3839 // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
3840 // use that information to infer NUW and NSW flags. However, computing a
3841 // BE count requires calling getAddRecExpr, so we may not yet have a
3842 // meaningful BE count at this point (and if we don't, we'd be stuck
3843 // with a SCEVCouldNotCompute as the cached BE count).
3844
3845 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3846
3847 // Canonicalize nested AddRecs in by nesting them in order of loop depth.
3848 if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3849 const Loop *NestedLoop = NestedAR->getLoop();
// Swap nesting either when the inner addrec's loop is deeper inside L, or
// when the loops are unrelated but L's header dominates the nested header.
3850 if (L->contains(NestedLoop)
3851 ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3852 : (!NestedLoop->contains(L) &&
3853 DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3854 SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
3855 Operands[0] = NestedAR->getStart();
3856 // AddRecs require their operands be loop-invariant with respect to their
3857 // loops. Don't perform this transformation if it would break this
3858 // requirement.
3859 bool AllInvariant = all_of(
3860 Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3861
3862 if (AllInvariant) {
3863 // Create a recurrence for the outer loop with the same step size.
3864 //
3865 // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3866 // inner recurrence has the same property.
3867 SCEV::NoWrapFlags OuterFlags =
3868 maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3869
3870 NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3871 AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3872 return isLoopInvariant(Op, NestedLoop);
3873 });
3874
3875 if (AllInvariant) {
3876 // Ok, both add recurrences are valid after the transformation.
3877 //
3878 // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3879 // the outer recurrence has the same property.
3880 SCEV::NoWrapFlags InnerFlags =
3881 maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3882 return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3883 }
3884 }
3885 // Reset Operands to its original state.
3886 Operands[0] = NestedAR;
3887 }
3888 }
3889
3890 // Okay, it looks like we really DO need an addrec expr. Check to see if we
3891 // already have one, otherwise create a new one.
3892 return getOrCreateAddRecExpr(Operands, L, Flags);
3893}
3894
// Builds the SCEV for a GEP instruction/operator: translates the pointer
// operand and IR no-wrap flags, then delegates to the flag-taking overload.
// NOTE(review): the first signature line (original line 3895) is missing from
// this listing.
3896 ArrayRef<SCEVUse> IndexExprs) {
3897 const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3898 // getSCEV(Base)->getType() has the same address space as Base->getType()
3899 // because SCEV::getType() preserves the address space.
3900 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
3901 if (NW != GEPNoWrapFlags::none()) {
3902 // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3903 // but to do that, we have to ensure that said flag is valid in the entire
3904 // defined scope of the SCEV.
3905 // TODO: non-instructions have global scope. We might be able to prove
3906 // some global scope cases
3907 auto *GEPI = dyn_cast<Instruction>(GEP);
3908 if (!GEPI || !isSCEVExprNeverPoison(GEPI))
3909 NW = GEPNoWrapFlags::none();
3910 }
3911
3912 return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
3913}
3914
// Computes base + sum(scaled index offsets) for a GEP in SCEV form, walking
// the indexed types (struct fields vs. array elements) to determine each
// offset. NOTE(review): this listing omits the first signature line (3915),
// the OffsetWrap declaration (3918), the Offsets vector declaration (3927),
// the non-first-iteration CurTy update (3947), and the BaseWrap computation
// (3970-3971).
3916 ArrayRef<SCEVUse> IndexExprs,
3917 Type *SrcElementTy, GEPNoWrapFlags NW) {
// Translate GEP wrap flags onto the offset arithmetic.
3919 if (NW.hasNoUnsignedSignedWrap())
3920 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
3921 if (NW.hasNoUnsignedWrap())
3922 OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);
3923
3924 Type *CurTy = BaseExpr->getType();
3925 Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3926 bool FirstIter = true;
3928 for (SCEVUse IndexExpr : IndexExprs) {
3929 // Compute the (potentially symbolic) offset in bytes for this index.
3930 if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3931 // For a struct, add the member offset.
3932 ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3933 unsigned FieldNo = Index->getZExtValue();
3934 const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3935 Offsets.push_back(FieldOffset);
3936
3937 // Update CurTy to the type of the field at Index.
3938 CurTy = STy->getTypeAtIndex(Index);
3939 } else {
3940 // Update CurTy to its element type.
3941 if (FirstIter) {
3942 assert(isa<PointerType>(CurTy) &&
3943 "The first index of a GEP indexes a pointer");
3944 CurTy = SrcElementTy;
3945 FirstIter = false;
3946 } else {
// NOTE(review): line 3947 is omitted here (presumably stepping CurTy to the
// contained element type, e.g. via GetElementPtrInst::getTypeAtIndex).
3948 }
3949 // For an array, add the element offset, explicitly scaled.
3950 const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3951 // Getelementptr indices are signed.
3952 IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3953
3954 // Multiply the index by the element size to compute the element offset.
3955 const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3956 Offsets.push_back(LocalOffset);
3957 }
3958 }
3959
3960 // Handle degenerate case of GEP without offsets.
3961 if (Offsets.empty())
3962 return BaseExpr;
3963
3964 // Add the offsets together, assuming nsw if inbounds.
3965 const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3966 // Add the base address and the offset. We cannot use the nsw flag, as the
3967 // base address is unsigned. However, if we know that the offset is
3968 // non-negative, we can use nuw.
3969 bool NUW = NW.hasNoUnsignedWrap() ||
3972 auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3973 assert(BaseExpr->getType() == GEPExpr->getType() &&
3974 "GEP should not change type mid-flight.");
3975 return GEPExpr;
3976}
3977
// Looks up an existing uniqued SCEV node of the given kind/operands in the
// UniqueSCEVs folding set; returns null if not present. NOTE(review): the
// listing omits the parameter continuation (line 3979, the Ops array) and the
// FoldingSetNodeID declaration (line 3980).
3978SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3981 ID.AddInteger(SCEVType);
3982 for (const SCEV *Op : Ops)
3983 ID.AddPointer(Op);
3984 void *IP = nullptr;
3985 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3986}
3987
// Overload of the cache lookup above, identical logic for a different operand
// array type. NOTE(review): the listing omits the parameter continuation
// (line 3989-3990) and the FoldingSetNodeID declaration, exactly as in the
// sibling overload.
3988SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3991 ID.AddInteger(SCEVType);
3992 for (const SCEV *Op : Ops)
3993 ID.AddPointer(Op);
3994 void *IP = nullptr;
3995 return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3996}
3997
// abs(Op) expressed as smax(Op, -Op). NOTE(review): line 3999 is omitted from
// this listing; it presumably derives `Flags` from IsNSW (NSW negate when the
// caller guarantees no signed wrap).
3998const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
4000 return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
4001}
4002
// Builds a canonical (commutative) min/max expression of the given Kind over
// Ops, constant-folding, flattening same-kind operands, deduplicating, and
// pruning operands provably ordered against their neighbor before creating a
// uniqued SCEVMinMaxExpr node. NOTE(review): this listing omits the function
// signature (lines 4003-4004), the GEPred/LEPred predicate definitions
// (4083-4086), the FoldingSetNodeID declaration (4112), and the
// uninitialized_copy of Ops into O (4121).
4005 assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
4006 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4007 if (Ops.size() == 1) return Ops[0];
4008#ifndef NDEBUG
4009 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4010 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4011 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4012 "Operand types don't match!");
4013 assert(Ops[0]->getType()->isPointerTy() ==
4014 Ops[i]->getType()->isPointerTy() &&
4015 "min/max should be consistently pointerish");
4016 }
4017#endif
4018
4019 bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
4020 bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
4021
// Fold constants together and group operands; the three lambdas provide the
// binary fold, the identity test, and the absorber test for this Kind.
4022 const SCEV *Folded = constantFoldAndGroupOps(
4023 *this, LI, DT, Ops,
4024 [&](const APInt &C1, const APInt &C2) {
4025 switch (Kind) {
4026 case scSMaxExpr:
4027 return APIntOps::smax(C1, C2);
4028 case scSMinExpr:
4029 return APIntOps::smin(C1, C2);
4030 case scUMaxExpr:
4031 return APIntOps::umax(C1, C2);
4032 case scUMinExpr:
4033 return APIntOps::umin(C1, C2);
4034 default:
4035 llvm_unreachable("Unknown SCEV min/max opcode");
4036 }
4037 },
4038 [&](const APInt &C) {
4039 // identity
4040 if (IsMax)
4041 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4042 else
4043 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4044 },
4045 [&](const APInt &C) {
4046 // absorber
4047 if (IsMax)
4048 return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
4049 else
4050 return IsSigned ? C.isMinSignedValue() : C.isMinValue();
4051 });
4052 if (Folded)
4053 return Folded;
4054
4055 // Check if we have created the same expression before.
4056 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
4057 return S;
4058 }
4059
4060 // Find the first operation of the same kind
4061 unsigned Idx = 0;
4062 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
4063 ++Idx;
4064
4065 // Check to see if one of the operands is of the same kind. If so, expand its
4066 // operands onto our operand list, and recurse to simplify.
4067 if (Idx < Ops.size()) {
4068 bool DeletedAny = false;
4069 while (Ops[Idx]->getSCEVType() == Kind) {
4070 const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
4071 Ops.erase(Ops.begin()+Idx);
4072 append_range(Ops, SMME->operands());
4073 DeletedAny = true;
4074 }
4075
4076 if (DeletedAny)
4077 return getMinMaxExpr(Kind, Ops);
4078 }
4079
4080 // Okay, check to see if the same value occurs in the operand list twice. If
4081 // so, delete one. Since we sorted the list, these values are required to
4082 // be adjacent.
// NOTE(review): lines 4083-4086 are omitted here; per the uses below they
// define GEPred/LEPred (the >=-style and <=-style comparison predicates
// matching IsSigned).
4087 llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
4088 llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
4089 for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
4090 if (Ops[i] == Ops[i + 1] ||
4091 isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
4092 // X op Y op Y --> X op Y
4093 // X op Y --> X, if we know X, Y are ordered appropriately
4094 Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
4095 --i;
4096 --e;
4097 } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
4098 Ops[i + 1])) {
4099 // X op Y --> Y, if we know X, Y are ordered appropriately
4100 Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
4101 --i;
4102 --e;
4103 }
4104 }
4105
4106 if (Ops.size() == 1) return Ops[0];
4107
4108 assert(!Ops.empty() && "Reduced smax down to nothing!");
4109
4110 // Okay, it looks like we really DO need an expr. Check to see if we
4111 // already have one, otherwise create a new one.
4113 ID.AddInteger(Kind);
4114 for (const SCEV *Op : Ops)
4115 ID.AddPointer(Op);
4116 void *IP = nullptr;
4117 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4118 if (ExistingSCEV)
4119 return ExistingSCEV;
4120 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4122 SCEV *S = new (SCEVAllocator)
4123 SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4124
4125 UniqueSCEVs.InsertNode(S, IP);
4126 S->computeAndSetCanonical(*this);
4127 registerUser(S, Ops);
4128 return S;
4129}
4130
4131namespace {
4132
// Visitor that rewrites a sequential min/max expression tree so that only the
// first occurrence of each operand is kept. It recurses only into expressions
// of the root's kind (or its non-sequential variant); every other expression
// kind is returned unchanged as a leaf. Returning std::nullopt from a visit
// means "operand already seen — drop it". NOTE(review): this listing omits a
// few member lines: the Base alias (4137), the SeenOps set declaration (4142),
// the assert text start in visitAnyMinMaxExpr (4151), the isa<...> selector on
// line 4167, and the local Ops vector declaration (4190).
4133class SCEVSequentialMinMaxDeduplicatingVisitor final
4134 : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
4135 std::optional<const SCEV *>> {
4136 using RetVal = std::optional<const SCEV *>;
4138
4139 ScalarEvolution &SE;
4140 const SCEVTypes RootKind; // Must be a sequential min/max expression.
4141 const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
4143
4144 bool canRecurseInto(SCEVTypes Kind) const {
4145 // We can only recurse into the SCEV expression of the same effective type
4146 // as the type of our root SCEV expression.
4147 return RootKind == Kind || NonSequentialRootKind == Kind;
4148 };
4149
4150 RetVal visitAnyMinMaxExpr(const SCEV *S) {
4152 "Only for min/max expressions.");
4153 SCEVTypes Kind = S->getSCEVType();
4154
4155 if (!canRecurseInto(Kind))
4156 return S;
4157
4158 auto *NAry = cast<SCEVNAryExpr>(S);
4159 SmallVector<SCEVUse> NewOps;
4160 bool Changed = visit(Kind, NAry->operands(), NewOps);
4161
4162 if (!Changed)
4163 return S;
// All operands were duplicates: the whole node disappears.
4164 if (NewOps.empty())
4165 return std::nullopt;
4166
// Rebuild with the deduplicated operand list, preserving sequential-ness.
4168 ? SE.getSequentialMinMaxExpr(Kind, NewOps)
4169 : SE.getMinMaxExpr(Kind, NewOps);
4170 }
4171
4172 RetVal visit(const SCEV *S) {
4173 // Has the whole operand been seen already?
4174 if (!SeenOps.insert(S).second)
4175 return std::nullopt;
4176 return Base::visit(S);
4177 }
4178
4179public:
4180 SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
4181 SCEVTypes RootKind)
4182 : SE(SE), RootKind(RootKind),
4183 NonSequentialRootKind(
4184 SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
4185 RootKind)) {}
4186
// Entry point: visits each operand in order; NewOps receives the filtered
// list only when something actually changed. Return value reports whether any
// operand was rewritten or dropped.
4187 bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
4188 SmallVectorImpl<SCEVUse> &NewOps) {
4189 bool Changed = false;
4191 Ops.reserve(OrigOps.size());
4192
4193 for (const SCEV *Op : OrigOps) {
4194 RetVal NewOp = visit(Op);
4195 if (NewOp != Op)
4196 Changed = true;
4197 if (NewOp)
4198 Ops.emplace_back(*NewOp);
4199 }
4200
4201 if (Changed)
4202 NewOps = std::move(Ops);
4203 return Changed;
4204 }
4205
// Everything below: non-min/max expression kinds are opaque leaves and are
// returned unchanged.
4206 RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4207
4208 RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }
4209
4210 RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }
4211
4212 RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4213
4214 RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4215
4216 RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4217
4218 RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4219
4220 RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4221
4222 RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4223
4224 RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4225
4226 RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4227
4228 RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4229 return visitAnyMinMaxExpr(Expr);
4230 }
4231
4232 RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4233 return visitAnyMinMaxExpr(Expr);
4234 }
4235
4236 RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4237 return visitAnyMinMaxExpr(Expr);
4238 }
4239
4240 RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4241 return visitAnyMinMaxExpr(Expr);
4242 }
4243
4244 RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4245 return visitAnyMinMaxExpr(Expr);
4246 }
4247
4248 RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4249
4250 RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4251};
4252
4253} // namespace
4254
// Classifies SCEV kinds by whether poison in any operand unconditionally makes
// the whole expression poison. NOTE(review): the listing omits the function
// signature (line 4255) — per the switch semantics this is presumably the
// poison-propagation predicate over SCEVTypes — and the case label on line
// 4275 (the sequential umin kind, which only propagates poison from its first
// operand, hence the FIXME below).
4256 switch (Kind) {
4257 case scConstant:
4258 case scVScale:
4259 case scTruncate:
4260 case scZeroExtend:
4261 case scSignExtend:
4262 case scPtrToAddr:
4263 case scPtrToInt:
4264 case scAddExpr:
4265 case scMulExpr:
4266 case scUDivExpr:
4267 case scAddRecExpr:
4268 case scUMaxExpr:
4269 case scSMaxExpr:
4270 case scUMinExpr:
4271 case scSMinExpr:
4272 case scUnknown:
4273 // If any operand is poison, the whole expression is poison.
4274 return true;
4276 // FIXME: if the *first* operand is poison, the whole expression is poison.
4277 return false; // Pessimistically, say that it does not propagate poison.
4278 case scCouldNotCompute:
4279 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
4280 }
4281 llvm_unreachable("Unknown SCEV kind!");
4282}
4283
4284namespace {
4285// The only way poison may be introduced in a SCEV expression is from a
4286// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4287// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4288// introduce poison -- they encode guaranteed, non-speculated knowledge.
4289//
4290// Additionally, all SCEV nodes propagate poison from inputs to outputs,
4291// with the notable exception of umin_seq, where only poison from the first
4292// operand is (unconditionally) propagated.
// SCEV tree visitor state: records every SCEVUnknown whose underlying Value is
// not guaranteed non-poison. When LookThroughMaybePoisonBlocking is false,
// traversal stops at poison-blocking expression kinds (see the omitted check
// below). Used with visitAll(); isDone() never short-circuits the walk.
4293struct SCEVPoisonCollector {
4294 bool LookThroughMaybePoisonBlocking;
4295 SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
4296 SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
4297 : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}
4298
4299 bool follow(const SCEV *S) {
4300 if (!LookThroughMaybePoisonBlocking &&
// NOTE(review): line 4301 is omitted from this listing (presumably the test
// for a poison-blocking SCEV kind such as the sequential umin).
4302 return false;
4303
// Track unknowns that might be poison at runtime.
4304 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4305 if (!isGuaranteedNotToBePoison(SU->getValue()))
4306 MaybePoison.insert(SU);
4307 }
4308 return true;
4309 }
4310 bool isDone() const { return false; }
4311};
4312} // namespace
4313
4314/// Return true if V is poison given that AssumedPoison is already poison.
4315static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4316 // First collect all SCEVs that might result in AssumedPoison to be poison.
4317 // We need to look through potentially poison-blocking operations here,
4318 // because we want to find all SCEVs that *might* result in poison, not only
4319 // those that are *required* to.
4320 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4321 visitAll(AssumedPoison, PC1);
4322
4323 // AssumedPoison is never poison. As the assumption is false, the implication
4324 // is true. Don't bother walking the other SCEV in this case.
4325 if (PC1.MaybePoison.empty())
4326 return true;
4327
4328 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4329 // as well. We cannot look through potentially poison-blocking operations
4330 // here, as their arguments only *may* make the result poison.
4331 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4332 visitAll(S, PC2);
4333
4334 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4335 // it will also make S poison by being part of PC2.MaybePoison.
4336 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4337}
4338
// Collects into Result the IR Values underlying the SCEVUnknowns in S that are
// not guaranteed non-poison (i.e. S's potential poison sources), using a
// non-look-through poison collector. NOTE(review): the first signature line
// (original line 4339) is missing from this listing.
4340 SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
4341 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
4342 visitAll(S, PC);
4343 for (const SCEVUnknown *SU : PC.MaybePoison)
4344 Result.insert(SU->getValue());
4345}
4346
// Decides whether instruction I can be reused to materialize SCEV S without
// introducing more poison than S itself could produce; instructions whose
// poison-generating annotations must be stripped are appended to
// DropPoisonGeneratingInsts. NOTE(review): this listing omits the first
// signature line (4347), the early not-poison check on line 4351, the
// PoisonVals set declaration (4358), and the Visited set declaration (4362).
4348 const SCEV *S, Instruction *I,
4349 SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
4350 // If the instruction cannot be poison, it's always safe to reuse.
4352 return true;
4353
4354 // Otherwise, it is possible that I is more poisonous that S. Collect the
4355 // poison-contributors of S, and then check whether I has any additional
4356 // poison-contributors. Poison that is contributed through poison-generating
4357 // flags is handled by dropping those flags instead.
4359 getPoisonGeneratingValues(PoisonVals, S);
4360
// DFS over I's operand graph looking for a poison source not shared with S.
4361 SmallVector<Value *> Worklist;
4363 Worklist.push_back(I);
4364 while (!Worklist.empty()) {
4365 Value *V = Worklist.pop_back_val();
4366 if (!Visited.insert(V).second)
4367 continue;
4368
4369 // Avoid walking large instruction graphs.
4370 if (Visited.size() > 16)
4371 return false;
4372
4373 // Either the value can't be poison, or the S would also be poison if it
4374 // is.
4375 if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
4376 continue;
4377
4378 auto *I = dyn_cast<Instruction>(V);
4379 if (!I)
4380 return false;
4381
4382 // Disjoint or instructions are interpreted as adds by SCEV. However, we
4383 // can't replace an arbitrary add with disjoint or, even if we drop the
4384 // flag. We would need to convert the or into an add.
4385 if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
4386 if (PDI->isDisjoint())
4387 return false;
4388
4389 // FIXME: Ignore vscale, even though it technically could be poison. Do this
4390 // because SCEV currently assumes it can't be poison. Remove this special
4391 // case once we proper model when vscale can be poison.
4392 if (auto *II = dyn_cast<IntrinsicInst>(I);
4393 II && II->getIntrinsicID() == Intrinsic::vscale)
4394 continue;
4395
4396 if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
4397 return false;
4398
4399 // If the instruction can't create poison, we can recurse to its operands.
4400 if (I->hasPoisonGeneratingAnnotations())
4401 DropPoisonGeneratingInsts.push_back(I);
4402
4403 llvm::append_range(Worklist, I->operands());
4404 }
4405 return true;
4406}
4407
4408const SCEV *
// Builds a canonical *sequential* (left-to-right, poison/UB-sensitive)
// min/max expression — currently only the sequential umin kind, per the
// switch below. Because evaluation order matters, operands are never sorted;
// simplification is limited to deduplication, flattening, and
// conversion of adjacent pairs to the non-sequential form when provably safe.
// NOTE(review): this listing omits the signature (lines 4409-4410), the Pred
// declaration (4466), the scSequentialUMinExpr case label (4468), the
// non-sequential kind selector (4487), the FoldingSetNodeID declaration
// (4502), and the uninitialized_copy into O (4512).
4411 assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4412 "Not a SCEVSequentialMinMaxExpr!");
4413 assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4414 if (Ops.size() == 1)
4415 return Ops[0];
4416#ifndef NDEBUG
4417 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4418 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4419 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4420 "Operand types don't match!");
4421 assert(Ops[0]->getType()->isPointerTy() ==
4422 Ops[i]->getType()->isPointerTy() &&
4423 "min/max should be consistently pointerish");
4424 }
4425#endif
4426
4427 // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4428 // so we can *NOT* do any kind of sorting of the expressions!
4429
4430 // Check if we have created the same expression before.
4431 if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4432 return S;
4433
4434 // FIXME: there are *some* simplifications that we can do here.
4435
4436 // Keep only the first instance of an operand.
4437 {
4438 SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4439 bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4440 if (Changed)
4441 return getSequentialMinMaxExpr(Kind, Ops);
4442 }
4443
4444 // Check to see if one of the operands is of the same kind. If so, expand its
4445 // operands onto our operand list, and recurse to simplify.
4446 {
4447 unsigned Idx = 0;
4448 bool DeletedAny = false;
4449 while (Idx < Ops.size()) {
4450 if (Ops[Idx]->getSCEVType() != Kind) {
4451 ++Idx;
4452 continue;
4453 }
// Splice the nested sequential expression's operands in place, preserving
// their relative order (order is semantically significant here).
4454 const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4455 Ops.erase(Ops.begin() + Idx);
4456 Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
4457 SMME->operands().end());
4458 DeletedAny = true;
4459 }
4460
4461 if (DeletedAny)
4462 return getSequentialMinMaxExpr(Kind, Ops);
4463 }
4464
4465 const SCEV *SaturationPoint;
4467 switch (Kind) {
4469 SaturationPoint = getZero(Ops[0]->getType());
4470 Pred = ICmpInst::ICMP_ULE;
4471 break;
4472 default:
4473 llvm_unreachable("Not a sequential min/max type.");
4474 }
4475
4476 for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4477 if (!isGuaranteedNotToCauseUB(Ops[i]))
4478 continue;
4479 // We can replace %x umin_seq %y with %x umin %y if either:
4480 // * %y being poison implies %x is also poison.
4481 // * %x cannot be the saturating value (e.g. zero for umin).
4482 if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4483 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4484 SaturationPoint)) {
4485 SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
4486 Ops[i - 1] = getMinMaxExpr(
4488 SeqOps);
4489 Ops.erase(Ops.begin() + i);
4490 return getSequentialMinMaxExpr(Kind, Ops);
4491 }
4492 // Fold %x umin_seq %y to %x if %x ule %y.
4493 // TODO: We might be able to prove the predicate for a later operand.
4494 if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4495 Ops.erase(Ops.begin() + i);
4496 return getSequentialMinMaxExpr(Kind, Ops);
4497 }
4498 }
4499
4500 // Okay, it looks like we really DO need an expr. Check to see if we
4501 // already have one, otherwise create a new one.
4503 ID.AddInteger(Kind);
4504 for (const SCEV *Op : Ops)
4505 ID.AddPointer(Op);
4506 void *IP = nullptr;
4507 const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4508 if (ExistingSCEV)
4509 return ExistingSCEV;
4510
4511 SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
4513 SCEV *S = new (SCEVAllocator)
4514 SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4515
4516 UniqueSCEVs.InsertNode(S, IP);
4517 S->computeAndSetCanonical(*this);
4518 registerUser(S, Ops);
4519 return S;
4520}
4521
4526
4530
4535
4539
4544
4548
4550 bool Sequential) {
4551 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4552 return getUMinExpr(Ops, Sequential);
4553}
4554
4560
4561const SCEV *
4563 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4564 if (Size.isScalable())
4565 Res = getMulExpr(Res, getVScale(IntTy));
4566 return Res;
4567}
4568
4570 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4571}
4572
4574 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4575}
4576
4578 StructType *STy,
4579 unsigned FieldNo) {
4580 // We can bypass creating a target-independent constant expression and then
4581 // folding it back into a ConstantInt. This is just a compile-time
4582 // optimization.
4583 const StructLayout *SL = getDataLayout().getStructLayout(STy);
4584 assert(!SL->getSizeInBits().isScalable() &&
4585 "Cannot get offset for structure containing scalable vector types");
4586 return getConstant(IntTy, SL->getElementOffset(FieldNo));
4587}
4588
4590 // Don't attempt to do anything other than create a SCEVUnknown object
4591 // here. createSCEV only calls getUnknown after checking for all other
4592 // interesting possibilities, and any other code that calls getUnknown
4593 // is doing so in order to hide a value from SCEV canonicalization.
4594
4596 ID.AddInteger(scUnknown);
4597 ID.AddPointer(V);
4598 void *IP = nullptr;
4599 if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4600 assert(cast<SCEVUnknown>(S)->getValue() == V &&
4601 "Stale SCEVUnknown in uniquing map!");
4602 return S;
4603 }
4604 SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4605 FirstUnknown);
4606 FirstUnknown = cast<SCEVUnknown>(S);
4607 UniqueSCEVs.InsertNode(S, IP);
4608 S->computeAndSetCanonical(*this);
4609 return S;
4610}
4611
4612//===----------------------------------------------------------------------===//
4613// Basic SCEV Analysis and PHI Idiom Recognition Code
4614//
4615
4616/// Test if values of the given type are analyzable within the SCEV
4617/// framework. This primarily includes integer types, and it can optionally
4618/// include pointer types if the ScalarEvolution class has access to
4619/// target-specific information.
4621 // Integers and pointers are always SCEVable.
4622 return Ty->isIntOrPtrTy();
4623}
4624
4625/// Return the size in bits of the specified type, for which isSCEVable must
4626/// return true.
4628 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4629 if (Ty->isPointerTy())
4631 return getDataLayout().getTypeSizeInBits(Ty);
4632}
4633
4634/// Return a type with the same bitwidth as the given type and which represents
4635/// how SCEV will treat the given type, for which isSCEVable must return
4636/// true. For pointer types, this is the pointer index sized integer type.
4638 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4639
4640 if (Ty->isIntegerTy())
4641 return Ty;
4642
4643 // The only other support type is pointer.
4644 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4645 return getDataLayout().getIndexType(Ty);
4646}
4647
4649 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4650}
4651
4653 const SCEV *B) {
4654 /// For a valid use point to exist, the defining scope of one operand
4655 /// must dominate the other.
4656 bool PreciseA, PreciseB;
4657 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4658 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4659 if (!PreciseA || !PreciseB)
4660 // Can't tell.
4661 return false;
4662 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4663 DT.dominates(ScopeB, ScopeA);
4664}
4665
4667 return CouldNotCompute.get();
4668}
4669
4670bool ScalarEvolution::checkValidity(const SCEV *S) const {
4671 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4672 auto *SU = dyn_cast<SCEVUnknown>(S);
4673 return SU && SU->getValue() == nullptr;
4674 });
4675
4676 return !ContainsNulls;
4677}
4678
4680 HasRecMapType::iterator I = HasRecMap.find(S);
4681 if (I != HasRecMap.end())
4682 return I->second;
4683
4684 bool FoundAddRec =
4685 SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4686 HasRecMap.insert({S, FoundAddRec});
4687 return FoundAddRec;
4688}
4689
4690/// Return the ValueOffsetPair set for \p S. \p S can be represented
4691/// by the value and offset from any ValueOffsetPair in the set.
4692ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4693 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4694 if (SI == ExprValueMap.end())
4695 return {};
4696 return SI->second.getArrayRef();
4697}
4698
4699/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4700/// cannot be used separately. eraseValueFromMap should be used to remove
4701/// V from ValueExprMap and ExprValueMap at the same time.
4702void ScalarEvolution::eraseValueFromMap(Value *V) {
4703 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4704 if (I != ValueExprMap.end()) {
4705 auto EVIt = ExprValueMap.find(I->second);
4706 bool Removed = EVIt->second.remove(V);
4707 (void) Removed;
4708 assert(Removed && "Value not in ExprValueMap?");
4709 ValueExprMap.erase(I);
4710 }
4711}
4712
4713void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4714 // A recursive query may have already computed the SCEV. It should be
4715 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4716 // inferred nowrap flags.
4717 auto It = ValueExprMap.find_as(V);
4718 if (It == ValueExprMap.end()) {
4719 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4720 ExprValueMap[S].insert(V);
4721 }
4722}
4723
4724/// Return an existing SCEV if it exists, otherwise analyze the expression and
4725/// create a new one.
4727 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4728
4729 if (const SCEV *S = getExistingSCEV(V))
4730 return S;
4731 return createSCEVIter(V);
4732}
4733
4735 assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4736
4737 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4738 if (I != ValueExprMap.end()) {
4739 const SCEV *S = I->second;
4740 assert(checkValidity(S) &&
4741 "existing SCEV has not been properly invalidated");
4742 return S;
4743 }
4744 return nullptr;
4745}
4746
4747/// Return a SCEV corresponding to -V = -1*V
4749 SCEV::NoWrapFlags Flags) {
4750 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4751 return getConstant(
4752 cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4753
4754 Type *Ty = V->getType();
4755 Ty = getEffectiveSCEVType(Ty);
4756 return getMulExpr(V, getMinusOne(Ty), Flags);
4757}
4758
4759/// If Expr computes ~A, return A else return nullptr
4760static const SCEV *MatchNotExpr(const SCEV *Expr) {
4761 const SCEV *MulOp;
4762 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4763 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4764 return MulOp;
4765 return nullptr;
4766}
4767
4768/// Return a SCEV corresponding to ~V = -1-V
4770 assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4771
4772 if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4773 return getConstant(
4774 cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4775
4776 // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4777 if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4778 auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4779 SmallVector<SCEVUse, 2> MatchedOperands;
4780 for (const SCEV *Operand : MME->operands()) {
4781 const SCEV *Matched = MatchNotExpr(Operand);
4782 if (!Matched)
4783 return (const SCEV *)nullptr;
4784 MatchedOperands.push_back(Matched);
4785 }
4786 return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4787 MatchedOperands);
4788 };
4789 if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4790 return Replaced;
4791 }
4792
4793 Type *Ty = V->getType();
4794 Ty = getEffectiveSCEVType(Ty);
4795 return getMinusSCEV(getMinusOne(Ty), V);
4796}
4797
4799 assert(P->getType()->isPointerTy());
4800
4801 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4802 // The base of an AddRec is the first operand.
4803 SmallVector<SCEVUse> Ops{AddRec->operands()};
4804 Ops[0] = removePointerBase(Ops[0]);
4805 // Don't try to transfer nowrap flags for now. We could in some cases
4806 // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4807 return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4808 }
4809 if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4810 // The base of an Add is the pointer operand.
4811 SmallVector<SCEVUse> Ops{Add->operands()};
4812 SCEVUse *PtrOp = nullptr;
4813 for (SCEVUse &AddOp : Ops) {
4814 if (AddOp->getType()->isPointerTy()) {
4815 assert(!PtrOp && "Cannot have multiple pointer ops");
4816 PtrOp = &AddOp;
4817 }
4818 }
4819 *PtrOp = removePointerBase(*PtrOp);
4820 // Don't try to transfer nowrap flags for now. We could in some cases
4821 // (for example, if the pointer operand of the Add is a SCEVUnknown).
4822 return getAddExpr(Ops);
4823 }
4824 // Any other expression must be a pointer base.
4825 return getZero(P->getType());
4826}
4827
4829 SCEV::NoWrapFlags Flags,
4830 unsigned Depth) {
4831 // Fast path: X - X --> 0.
4832 if (LHS == RHS)
4833 return getZero(LHS->getType());
4834
4835 // If we subtract two pointers with different pointer bases, bail.
4836 // Eventually, we're going to add an assertion to getMulExpr that we
4837 // can't multiply by a pointer.
4838 if (RHS->getType()->isPointerTy()) {
4839 if (!LHS->getType()->isPointerTy() ||
4840 getPointerBase(LHS) != getPointerBase(RHS))
4841 return getCouldNotCompute();
4842 LHS = removePointerBase(LHS);
4843 RHS = removePointerBase(RHS);
4844 }
4845
4846 // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4847 // makes it so that we cannot make much use of NUW.
4848 auto AddFlags = SCEV::FlagAnyWrap;
4849 const bool RHSIsNotMinSigned =
4851 if (hasFlags(Flags, SCEV::FlagNSW)) {
4852 // Let M be the minimum representable signed value. Then (-1)*RHS
4853 // signed-wraps if and only if RHS is M. That can happen even for
4854 // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4855 // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4856 // (-1)*RHS, we need to prove that RHS != M.
4857 //
4858 // If LHS is non-negative and we know that LHS - RHS does not
4859 // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4860 // either by proving that RHS > M or that LHS >= 0.
4861 if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4862 AddFlags = SCEV::FlagNSW;
4863 }
4864 }
4865
4866 // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4867 // RHS is NSW and LHS >= 0.
4868 //
4869 // The difficulty here is that the NSW flag may have been proven
4870 // relative to a loop that is to be found in a recurrence in LHS and
4871 // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4872 // larger scope than intended.
4873 auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4874
4875 return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4876}
4877
4879 unsigned Depth) {
4880 Type *SrcTy = V->getType();
4881 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4882 "Cannot truncate or zero extend with non-integer arguments!");
4883 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4884 return V; // No conversion
4885 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4886 return getTruncateExpr(V, Ty, Depth);
4887 return getZeroExtendExpr(V, Ty, Depth);
4888}
4889
4891 unsigned Depth) {
4892 Type *SrcTy = V->getType();
4893 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4894 "Cannot truncate or zero extend with non-integer arguments!");
4895 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4896 return V; // No conversion
4897 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4898 return getTruncateExpr(V, Ty, Depth);
4899 return getSignExtendExpr(V, Ty, Depth);
4900}
4901
4902const SCEV *
4904 Type *SrcTy = V->getType();
4905 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4906 "Cannot noop or zero extend with non-integer arguments!");
4908 "getNoopOrZeroExtend cannot truncate!");
4909 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4910 return V; // No conversion
4911 return getZeroExtendExpr(V, Ty);
4912}
4913
4914const SCEV *
4916 Type *SrcTy = V->getType();
4917 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4918 "Cannot noop or sign extend with non-integer arguments!");
4920 "getNoopOrSignExtend cannot truncate!");
4921 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4922 return V; // No conversion
4923 return getSignExtendExpr(V, Ty);
4924}
4925
4926const SCEV *
4928 Type *SrcTy = V->getType();
4929 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4930 "Cannot noop or any extend with non-integer arguments!");
4932 "getNoopOrAnyExtend cannot truncate!");
4933 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4934 return V; // No conversion
4935 return getAnyExtendExpr(V, Ty);
4936}
4937
4938const SCEV *
4940 Type *SrcTy = V->getType();
4941 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4942 "Cannot truncate or noop with non-integer arguments!");
4944 "getTruncateOrNoop cannot extend!");
4945 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4946 return V; // No conversion
4947 return getTruncateExpr(V, Ty);
4948}
4949
4951 const SCEV *RHS) {
4952 const SCEV *PromotedLHS = LHS;
4953 const SCEV *PromotedRHS = RHS;
4954
4955 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4956 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4957 else
4958 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4959
4960 return getUMaxExpr(PromotedLHS, PromotedRHS);
4961}
4962
4964 const SCEV *RHS,
4965 bool Sequential) {
4966 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4967 return getUMinFromMismatchedTypes(Ops, Sequential);
4968}
4969
4970const SCEV *
4972 bool Sequential) {
4973 assert(!Ops.empty() && "At least one operand must be!");
4974 // Trivial case.
4975 if (Ops.size() == 1)
4976 return Ops[0];
4977
4978 // Find the max type first.
4979 Type *MaxType = nullptr;
4980 for (SCEVUse S : Ops)
4981 if (MaxType)
4982 MaxType = getWiderType(MaxType, S->getType());
4983 else
4984 MaxType = S->getType();
4985 assert(MaxType && "Failed to find maximum type!");
4986
4987 // Extend all ops to max type.
4988 SmallVector<SCEVUse, 2> PromotedOps;
4989 for (SCEVUse S : Ops)
4990 PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4991
4992 // Generate umin.
4993 return getUMinExpr(PromotedOps, Sequential);
4994}
4995
4997 // A pointer operand may evaluate to a nonpointer expression, such as null.
4998 if (!V->getType()->isPointerTy())
4999 return V;
5000
5001 while (true) {
5002 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
5003 V = AddRec->getStart();
5004 } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
5005 const SCEV *PtrOp = nullptr;
5006 for (const SCEV *AddOp : Add->operands()) {
5007 if (AddOp->getType()->isPointerTy()) {
5008 assert(!PtrOp && "Cannot have multiple pointer ops");
5009 PtrOp = AddOp;
5010 }
5011 }
5012 assert(PtrOp && "Must have pointer op");
5013 V = PtrOp;
5014 } else // Not something we can look further into.
5015 return V;
5016 }
5017}
5018
5019/// Push users of the given Instruction onto the given Worklist.
5023 // Push the def-use children onto the Worklist stack.
5024 for (User *U : I->users()) {
5025 auto *UserInsn = cast<Instruction>(U);
5026 if (Visited.insert(UserInsn).second)
5027 Worklist.push_back(UserInsn);
5028 }
5029}
5030
5031namespace {
5032
5033/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5034/// expression in case its Loop is L. If it is not L then
5035/// if IgnoreOtherLoops is true then use AddRec itself
5036/// otherwise rewrite cannot be done.
5037/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5038class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5039public:
5040 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5041 bool IgnoreOtherLoops = true) {
5042 SCEVInitRewriter Rewriter(L, SE);
5043 const SCEV *Result = Rewriter.visit(S);
5044 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5045 return SE.getCouldNotCompute();
5046 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5047 ? SE.getCouldNotCompute()
5048 : Result;
5049 }
5050
5051 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5052 if (!SE.isLoopInvariant(Expr, L))
5053 SeenLoopVariantSCEVUnknown = true;
5054 return Expr;
5055 }
5056
5057 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5058 // Only re-write AddRecExprs for this loop.
5059 if (Expr->getLoop() == L)
5060 return Expr->getStart();
5061 SeenOtherLoops = true;
5062 return Expr;
5063 }
5064
5065 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5066
5067 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5068
5069private:
5070 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5071 : SCEVRewriteVisitor(SE), L(L) {}
5072
5073 const Loop *L;
5074 bool SeenLoopVariantSCEVUnknown = false;
5075 bool SeenOtherLoops = false;
5076};
5077
5078/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5079/// increment expression in case its Loop is L. If it is not L then
5080/// use AddRec itself.
5081/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5082class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5083public:
5084 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5085 SCEVPostIncRewriter Rewriter(L, SE);
5086 const SCEV *Result = Rewriter.visit(S);
5087 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5088 ? SE.getCouldNotCompute()
5089 : Result;
5090 }
5091
5092 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5093 if (!SE.isLoopInvariant(Expr, L))
5094 SeenLoopVariantSCEVUnknown = true;
5095 return Expr;
5096 }
5097
5098 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5099 // Only re-write AddRecExprs for this loop.
5100 if (Expr->getLoop() == L)
5101 return Expr->getPostIncExpr(SE);
5102 SeenOtherLoops = true;
5103 return Expr;
5104 }
5105
5106 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5107
5108 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5109
5110private:
5111 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5112 : SCEVRewriteVisitor(SE), L(L) {}
5113
5114 const Loop *L;
5115 bool SeenLoopVariantSCEVUnknown = false;
5116 bool SeenOtherLoops = false;
5117};
5118
5119/// This class evaluates the compare condition by matching it against the
5120/// condition of loop latch. If there is a match we assume a true value
5121/// for the condition while building SCEV nodes.
5122class SCEVBackedgeConditionFolder
5123 : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
5124public:
5125 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5126 ScalarEvolution &SE) {
5127 bool IsPosBECond = false;
5128 Value *BECond = nullptr;
5129 if (BasicBlock *Latch = L->getLoopLatch()) {
5130 if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
5131 assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
5132 "Both outgoing branches should not target same header!");
5133 BECond = BI->getCondition();
5134 IsPosBECond = BI->getSuccessor(0) == L->getHeader();
5135 } else {
5136 return S;
5137 }
5138 }
5139 SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
5140 return Rewriter.visit(S);
5141 }
5142
5143 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5144 const SCEV *Result = Expr;
5145 bool InvariantF = SE.isLoopInvariant(Expr, L);
5146
5147 if (!InvariantF) {
5149 switch (I->getOpcode()) {
5150 case Instruction::Select: {
5151 SelectInst *SI = cast<SelectInst>(I);
5152 std::optional<const SCEV *> Res =
5153 compareWithBackedgeCondition(SI->getCondition());
5154 if (Res) {
5155 bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
5156 Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
5157 }
5158 break;
5159 }
5160 default: {
5161 std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
5162 if (Res)
5163 Result = *Res;
5164 break;
5165 }
5166 }
5167 }
5168 return Result;
5169 }
5170
5171private:
5172 explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
5173 bool IsPosBECond, ScalarEvolution &SE)
5174 : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
5175 IsPositiveBECond(IsPosBECond) {}
5176
5177 std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
5178
5179 const Loop *L;
5180 /// Loop back condition.
5181 Value *BackedgeCond = nullptr;
5182 /// Set to true if loop back is on positive branch condition.
5183 bool IsPositiveBECond;
5184};
5185
5186std::optional<const SCEV *>
5187SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
5188
5189 // If value matches the backedge condition for loop latch,
5190 // then return a constant evolution node based on loopback
5191 // branch taken.
5192 if (BackedgeCond == IC)
5193 return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
5195 return std::nullopt;
5196}
5197
5198class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5199public:
5200 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5201 ScalarEvolution &SE) {
5202 SCEVShiftRewriter Rewriter(L, SE);
5203 const SCEV *Result = Rewriter.visit(S);
5204 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5205 }
5206
5207 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5208 // Only allow AddRecExprs for this loop.
5209 if (!SE.isLoopInvariant(Expr, L))
5210 Valid = false;
5211 return Expr;
5212 }
5213
5214 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5215 if (Expr->getLoop() == L && Expr->isAffine())
5216 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5217 Valid = false;
5218 return Expr;
5219 }
5220
5221 bool isValid() { return Valid; }
5222
5223private:
5224 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5225 : SCEVRewriteVisitor(SE), L(L) {}
5226
5227 const Loop *L;
5228 bool Valid = true;
5229};
5230
5231} // end anonymous namespace
5232
5234ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
5235 if (!AR->isAffine())
5236 return SCEV::FlagAnyWrap;
5237
5238 using OBO = OverflowingBinaryOperator;
5239
5241
5242 if (!AR->hasNoSelfWrap()) {
5243 const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
5244 if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
5245 ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
5246 const APInt &BECountAP = BECountMax->getAPInt();
5247 unsigned NoOverflowBitWidth =
5248 BECountAP.getActiveBits() + StepCR.getMinSignedBits();
5249 if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
5251 }
5252 }
5253
5254 if (!AR->hasNoSignedWrap()) {
5255 ConstantRange AddRecRange = getSignedRange(AR);
5256 ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
5257
5259 Instruction::Add, IncRange, OBO::NoSignedWrap);
5260 if (NSWRegion.contains(AddRecRange))
5262 }
5263
5264 if (!AR->hasNoUnsignedWrap()) {
5265 ConstantRange AddRecRange = getUnsignedRange(AR);
5266 ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
5267
5269 Instruction::Add, IncRange, OBO::NoUnsignedWrap);
5270 if (NUWRegion.contains(AddRecRange))
5272 }
5273
5274 return Result;
5275}
5276
5278ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5280
5281 if (AR->hasNoSignedWrap())
5282 return Result;
5283
5284 if (!AR->isAffine())
5285 return Result;
5286
5287 // This function can be expensive, only try to prove NSW once per AddRec.
5288 if (!SignedWrapViaInductionTried.insert(AR).second)
5289 return Result;
5290
5291 const SCEV *Step = AR->getStepRecurrence(*this);
5292 const Loop *L = AR->getLoop();
5293
5294 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5295 // Note that this serves two purposes: It filters out loops that are
5296 // simply not analyzable, and it covers the case where this code is
5297 // being called from within backedge-taken count analysis, such that
5298 // attempting to ask for the backedge-taken count would likely result
5299 // in infinite recursion. In the later case, the analysis code will
5300 // cope with a conservative value, and it will take care to purge
5301 // that value once it has finished.
5302 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5303
5304 // Normally, in the cases we can prove no-overflow via a
5305 // backedge guarding condition, we can also compute a backedge
5306 // taken count for the loop. The exceptions are assumptions and
5307 // guards present in the loop -- SCEV is not great at exploiting
5308 // these to compute max backedge taken counts, but can still use
5309 // these to prove lack of overflow. Use this fact to avoid
5310 // doing extra work that may not pay off.
5311
5312 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5313 AC.assumptions().empty())
5314 return Result;
5315
5316 // If the backedge is guarded by a comparison with the pre-inc value the
5317 // addrec is safe. Also, if the entry is guarded by a comparison with the
5318 // start value and the backedge is guarded by a comparison with the post-inc
5319 // value, the addrec is safe.
5321 const SCEV *OverflowLimit =
5322 getSignedOverflowLimitForStep(Step, &Pred, this);
5323 if (OverflowLimit &&
5324 (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
5325 isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
5326 Result = setFlags(Result, SCEV::FlagNSW);
5327 }
5328 return Result;
5329}
5331ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
5333
5334 if (AR->hasNoUnsignedWrap())
5335 return Result;
5336
5337 if (!AR->isAffine())
5338 return Result;
5339
5340 // This function can be expensive, only try to prove NUW once per AddRec.
5341 if (!UnsignedWrapViaInductionTried.insert(AR).second)
5342 return Result;
5343
5344 const SCEV *Step = AR->getStepRecurrence(*this);
5345 unsigned BitWidth = getTypeSizeInBits(AR->getType());
5346 const Loop *L = AR->getLoop();
5347
5348 // Check whether the backedge-taken count is SCEVCouldNotCompute.
5349 // Note that this serves two purposes: It filters out loops that are
5350 // simply not analyzable, and it covers the case where this code is
5351 // being called from within backedge-taken count analysis, such that
5352 // attempting to ask for the backedge-taken count would likely result
5353 // in infinite recursion. In the later case, the analysis code will
5354 // cope with a conservative value, and it will take care to purge
5355 // that value once it has finished.
5356 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
5357
5358 // Normally, in the cases we can prove no-overflow via a
5359 // backedge guarding condition, we can also compute a backedge
5360 // taken count for the loop. The exceptions are assumptions and
5361 // guards present in the loop -- SCEV is not great at exploiting
5362 // these to compute max backedge taken counts, but can still use
5363 // these to prove lack of overflow. Use this fact to avoid
5364 // doing extra work that may not pay off.
5365
5366 if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
5367 AC.assumptions().empty())
5368 return Result;
5369
5370 // If the backedge is guarded by a comparison with the pre-inc value the
5371 // addrec is safe. Also, if the entry is guarded by a comparison with the
5372 // start value and the backedge is guarded by a comparison with the post-inc
5373 // value, the addrec is safe.
5374 if (isKnownPositive(Step)) {
5375 const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
5376 getUnsignedRangeMax(Step));
5379 Result = setFlags(Result, SCEV::FlagNUW);
5380 }
5381 }
5382
5383 return Result;
5384}
5385
5386namespace {
5387
5388/// Represents an abstract binary operation. This may exist as a
5389/// normal instruction or constant expression, or may have been
5390/// derived from an expression tree.
struct BinaryOp {
  // Instruction opcode this operation behaves like (Instruction::Add, etc.).
  unsigned Opcode;
  Value *LHS;
  Value *RHS;
  // Wrap flags known to hold for this operation.
  bool IsNSW = false;
  bool IsNUW = false;

  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
  /// constant expression.
  Operator *Op = nullptr;

  /// Wrap an existing IR operator; opcode, operands, and (when the operator
  /// carries them) nsw/nuw flags are harvested from \p Op.
  explicit BinaryOp(Operator *Op)
      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
        Op(Op) {
    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
      IsNSW = OBO->hasNoSignedWrap();
      IsNUW = OBO->hasNoUnsignedWrap();
    }
  }

  /// Build a synthesized operation that is not backed by an IR operator
  /// (\c Op stays null).
  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
                    bool IsNUW = false)
      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};
5415
5416} // end anonymous namespace
5417
/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
/// Several IR idioms are normalized to a simpler equivalent op (disjoint or
/// -> add nuw nsw, xor by signmask -> add, lshr by constant -> udiv, guarded
/// with.overflow intrinsics -> nowrap arithmetic).
static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
                                             AssumptionCache &AC,
                                             const DominatorTree &DT,
                                             const Instruction *CxtI) {
  auto *Op = dyn_cast<Operator>(V);
  if (!Op)
    return std::nullopt;

  // Implementation detail: all the cleverness here should happen without
  // creating new SCEV expressions -- our caller knowns tricks to avoid creating
  // SCEV expressions when possible, and we should not break that.

  switch (Op->getOpcode()) {
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::And:
  case Instruction::AShr:
  case Instruction::Shl:
    // These map directly; any wrap flags are picked up by the BinaryOp ctor.
    return BinaryOp(Op);

  case Instruction::Or: {
    // Convert or disjoint into add nuw nsw.
    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
                      /*IsNSW=*/true, /*IsNUW=*/true);
    return BinaryOp(Op);
  }

  case Instruction::Xor:
    // NOTE: the inner if binds to the RHSC check only; the i1 check below
    // runs whether or not the RHS is a constant.
    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
      // If the RHS of the xor is a signmask, then this is just an add.
      // Instcombine turns add of signmask into xor as a strength reduction step.
      if (RHSC->getValue().isSignMask())
        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    // Binary `xor` is a bit-wise `add`.
    if (V->getType()->isIntegerTy(1))
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    return BinaryOp(Op);

  case Instruction::LShr:
    // Turn logical shift right of a constant into a unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();

      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined. Don't try to analyze it, because the
      // resolution chosen here may differ from the resolution chosen in
      // other parts of the compiler.
      if (SA->getValue().ult(BitWidth)) {
        Constant *X =
            ConstantInt::get(SA->getContext(),
                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
      }
    }
    return BinaryOp(Op);

  case Instruction::ExtractValue: {
    // Only the arithmetic-result component (index 0) of a with.overflow
    // intrinsic is interesting here.
    auto *EVI = cast<ExtractValueInst>(Op);
    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
      break;

    auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
    if (!WO)
      break;

    Instruction::BinaryOps BinOp = WO->getBinaryOp();
    bool Signed = WO->isSigned();
    // TODO: Should add nuw/nsw flags for mul as well.
    if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
      return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());

    // Now that we know that all uses of the arithmetic-result component of
    // CI are guarded by the overflow check, we can go ahead and pretend
    // that the arithmetic is non-overflowing.
    return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
                    /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
  }

  default:
    break;
  }

  // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
  // semantics as a Sub, return a binary sub expression.
  if (auto *II = dyn_cast<IntrinsicInst>(V))
    if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
      return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));

  // NOTE(review): DL, AC, and CxtI are not used in this body — presumably kept
  // for call-site uniformity with related helpers; confirm before removing.
  return std::nullopt;
}
5513
5514/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5515/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5516/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5517/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5518/// follows one of the following patterns:
5519/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5520/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5521/// If the SCEV expression of \p Op conforms with one of the expected patterns
5522/// we return the type of the truncation operation, and indicate whether the
5523/// truncated type should be treated as signed/unsigned by setting
5524/// \p Signed to true/false, respectively.
5525static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5526 bool &Signed, ScalarEvolution &SE) {
5527 // The case where Op == SymbolicPHI (that is, with no type conversions on
5528 // the way) is handled by the regular add recurrence creating logic and
5529 // would have already been triggered in createAddRecForPHI. Reaching it here
5530 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5531 // because one of the other operands of the SCEVAddExpr updating this PHI is
5532 // not invariant).
5533 //
5534 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5535 // this case predicates that allow us to prove that Op == SymbolicPHI will
5536 // be added.
5537 if (Op == SymbolicPHI)
5538 return nullptr;
5539
5540 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5541 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5542 if (SourceBits != NewBits)
5543 return nullptr;
5544
5545 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5546 Signed = true;
5547 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5548 }
5549 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5550 Signed = false;
5551 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5552 }
5553 return nullptr;
5554}
5555
5556static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5557 if (!PN->getType()->isIntegerTy())
5558 return nullptr;
5559 const Loop *L = LI.getLoopFor(PN->getParent());
5560 if (!L || L->getHeader() != PN->getParent())
5561 return nullptr;
5562 return L;
5563}
5564
5565// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5566// computation that updates the phi follows the following pattern:
5567// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5568// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5569// If so, try to see if it can be rewritten as an AddRecExpr under some
5570// Predicates. If successful, return them as a pair. Also cache the results
5571// of the analysis.
5572//
5573// Example usage scenario:
5574// Say the Rewriter is called for the following SCEV:
5575// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5576// where:
5577// %X = phi i64 (%Start, %BEValue)
5578// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5579// and call this function with %SymbolicPHI = %X.
5580//
5581// The analysis will find that the value coming around the backedge has
5582// the following SCEV:
5583// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5584// Upon concluding that this matches the desired pattern, the function
5585// will return the pair {NewAddRec, SmallPredsVec} where:
5586// NewAddRec = {%Start,+,%Step}
5587// SmallPredsVec = {P1, P2, P3} as follows:
5588// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5589// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5590// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5591// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5592// under the predicates {P1,P2,P3}.
5593// This predicated rewrite will be cached in PredicatedSCEVRewrites:
5594// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
5595//
5596// TODO's:
5597//
5598// 1) Extend the Induction descriptor to also support inductions that involve
5599// casts: When needed (namely, when we are called in the context of the
5600// vectorizer induction analysis), a Set of cast instructions will be
5601// populated by this method, and provided back to isInductionPHI. This is
5602// needed to allow the vectorizer to properly record them to be ignored by
5603// the cost model and to avoid vectorizing them (otherwise these casts,
5604// which are redundant under the runtime overflow checks, will be
5605// vectorized, which can be costly).
5606//
5607// 2) Support additional induction/PHISCEV patterns: We also want to support
5608// inductions where the sext-trunc / zext-trunc operations (partly) occur
5609// after the induction update operation (the induction increment):
5610//
5611// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5612// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5613//
5614// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5615// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5616//
5617// 3) Outline common code with createAddRecFromPHI to avoid duplication.
5618std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5619ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
5621
5622 // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
5623 // return an AddRec expression under some predicate.
5624
5625 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5626 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5627 assert(L && "Expecting an integer loop header phi");
5628
5629 // The loop may have multiple entrances or multiple exits; we can analyze
5630 // this phi as an addrec if it has a unique entry value and a unique
5631 // backedge value.
5632 Value *BEValueV = nullptr, *StartValueV = nullptr;
5633 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5634 Value *V = PN->getIncomingValue(i);
5635 if (L->contains(PN->getIncomingBlock(i))) {
5636 if (!BEValueV) {
5637 BEValueV = V;
5638 } else if (BEValueV != V) {
5639 BEValueV = nullptr;
5640 break;
5641 }
5642 } else if (!StartValueV) {
5643 StartValueV = V;
5644 } else if (StartValueV != V) {
5645 StartValueV = nullptr;
5646 break;
5647 }
5648 }
5649 if (!BEValueV || !StartValueV)
5650 return std::nullopt;
5651
5652 const SCEV *BEValue = getSCEV(BEValueV);
5653
5654 // If the value coming around the backedge is an add with the symbolic
5655 // value we just inserted, possibly with casts that we can ignore under
5656 // an appropriate runtime guard, then we found a simple induction variable!
5657 const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
5658 if (!Add)
5659 return std::nullopt;
5660
5661 // If there is a single occurrence of the symbolic value, possibly
5662 // casted, replace it with a recurrence.
5663 unsigned FoundIndex = Add->getNumOperands();
5664 Type *TruncTy = nullptr;
5665 bool Signed;
5666 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5667 if ((TruncTy =
5668 isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
5669 if (FoundIndex == e) {
5670 FoundIndex = i;
5671 break;
5672 }
5673
5674 if (FoundIndex == Add->getNumOperands())
5675 return std::nullopt;
5676
5677 // Create an add with everything but the specified operand.
5679 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
5680 if (i != FoundIndex)
5681 Ops.push_back(Add->getOperand(i));
5682 const SCEV *Accum = getAddExpr(Ops);
5683
5684 // The runtime checks will not be valid if the step amount is
5685 // varying inside the loop.
5686 if (!isLoopInvariant(Accum, L))
5687 return std::nullopt;
5688
5689 // *** Part2: Create the predicates
5690
5691 // Analysis was successful: we have a phi-with-cast pattern for which we
5692 // can return an AddRec expression under the following predicates:
5693 //
5694 // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
5695 // fits within the truncated type (does not overflow) for i = 0 to n-1.
5696 // P2: An Equal predicate that guarantees that
5697 // Start = (Ext ix (Trunc iy (Start) to ix) to iy)
5698 // P3: An Equal predicate that guarantees that
5699 // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
5700 //
5701 // As we next prove, the above predicates guarantee that:
5702 // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
5703 //
5704 //
5705 // More formally, we want to prove that:
5706 // Expr(i+1) = Start + (i+1) * Accum
5707 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5708 //
5709 // Given that:
5710 // 1) Expr(0) = Start
5711 // 2) Expr(1) = Start + Accum
5712 // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
5713 // 3) Induction hypothesis (step i):
5714 // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
5715 //
5716 // Proof:
5717 // Expr(i+1) =
5718 // = Start + (i+1)*Accum
5719 // = (Start + i*Accum) + Accum
5720 // = Expr(i) + Accum
5721 // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
5722 // :: from step i
5723 //
5724 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
5725 //
5726 // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
5727 // + (Ext ix (Trunc iy (Accum) to ix) to iy)
5728 // + Accum :: from P3
5729 //
5730 // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
5731 // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
5732 //
5733 // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
5734 // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
5735 //
5736 // By induction, the same applies to all iterations 1<=i<n:
5737 //
5738
5739 // Create a truncated addrec for which we will add a no overflow check (P1).
5740 const SCEV *StartVal = getSCEV(StartValueV);
5741 const SCEV *PHISCEV =
5742 getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
5743 getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
5744
5745 // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
5746 // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
5747 // will be constant.
5748 //
5749 // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
5750 // add P1.
5751 if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5755 const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
5756 Predicates.push_back(AddRecPred);
5757 }
5758
5759 // Create the Equal Predicates P2,P3:
5760
5761 // It is possible that the predicates P2 and/or P3 are computable at
5762 // compile time due to StartVal and/or Accum being constants.
5763 // If either one is, then we can check that now and escape if either P2
5764 // or P3 is false.
5765
5766 // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
5767 // for each of StartVal and Accum
5768 auto getExtendedExpr = [&](const SCEV *Expr,
5769 bool CreateSignExtend) -> const SCEV * {
5770 assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
5771 const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
5772 const SCEV *ExtendedExpr =
5773 CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
5774 : getZeroExtendExpr(TruncatedExpr, Expr->getType());
5775 return ExtendedExpr;
5776 };
5777
5778 // Given:
5779 // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
5780 // = getExtendedExpr(Expr)
5781 // Determine whether the predicate P: Expr == ExtendedExpr
5782 // is known to be false at compile time
5783 auto PredIsKnownFalse = [&](const SCEV *Expr,
5784 const SCEV *ExtendedExpr) -> bool {
5785 return Expr != ExtendedExpr &&
5786 isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
5787 };
5788
5789 const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
5790 if (PredIsKnownFalse(StartVal, StartExtended)) {
5791 LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
5792 return std::nullopt;
5793 }
5794
5795 // The Step is always Signed (because the overflow checks are either
5796 // NSSW or NUSW)
5797 const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
5798 if (PredIsKnownFalse(Accum, AccumExtended)) {
5799 LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
5800 return std::nullopt;
5801 }
5802
5803 auto AppendPredicate = [&](const SCEV *Expr,
5804 const SCEV *ExtendedExpr) -> void {
5805 if (Expr != ExtendedExpr &&
5806 !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
5807 const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
5808 LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
5809 Predicates.push_back(Pred);
5810 }
5811 };
5812
5813 AppendPredicate(StartVal, StartExtended);
5814 AppendPredicate(Accum, AccumExtended);
5815
5816 // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
5817 // which the casts had been folded away. The caller can rewrite SymbolicPHI
5818 // into NewAR if it will also add the runtime overflow checks specified in
5819 // Predicates.
5820 auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
5821
5822 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
5823 std::make_pair(NewAR, Predicates);
5824 // Remember the result of the analysis for this SCEV at this locayyytion.
5825 PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
5826 return PredRewrite;
5827}
5828
5829std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5831 auto *PN = cast<PHINode>(SymbolicPHI->getValue());
5832 const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
5833 if (!L)
5834 return std::nullopt;
5835
5836 // Check to see if we already analyzed this PHI.
5837 auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
5838 if (I != PredicatedSCEVRewrites.end()) {
5839 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
5840 I->second;
5841 // Analysis was done before and failed to create an AddRec:
5842 if (Rewrite.first == SymbolicPHI)
5843 return std::nullopt;
5844 // Analysis was done before and succeeded to create an AddRec under
5845 // a predicate:
5846 assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
5847 assert(!(Rewrite.second).empty() && "Expected to find Predicates");
5848 return Rewrite;
5849 }
5850
5851 std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
5852 Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);
5853
5854 // Record in the cache that the analysis failed
5855 if (!Rewrite) {
5857 PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
5858 return std::nullopt;
5859 }
5860
5861 return Rewrite;
5862}
5863
5864// FIXME: This utility is currently required because the Rewriter currently
5865// does not rewrite this expression:
5866// {0, +, (sext ix (trunc iy to ix) to iy)}
5867// into {0, +, %step},
5868// even when the following Equal predicate exists:
5869// "%step == (sext ix (trunc iy to ix) to iy)".
5871 const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
5872 if (AR1 == AR2)
5873 return true;
5874
5875 auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
5876 if (Expr1 != Expr2 &&
5877 !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
5878 !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
5879 return false;
5880 return true;
5881 };
5882
5883 if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
5884 !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
5885 return false;
5886 return true;
5887}
5888
5889/// A helper function for createAddRecFromPHI to handle simple cases.
5890///
5891/// This function tries to find an AddRec expression for the simplest (yet most
5892/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
5893/// If it fails, createAddRecFromPHI will use a more general, but slow,
5894/// technique for finding the AddRec expression.
5895const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
5896 Value *BEValueV,
5897 Value *StartValueV) {
5898 const Loop *L = LI.getLoopFor(PN->getParent());
5899 assert(L && L->getHeader() == PN->getParent());
5900 assert(BEValueV && StartValueV);
5901
5902 auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
5903 if (!BO)
5904 return nullptr;
5905
5906 if (BO->Opcode != Instruction::Add)
5907 return nullptr;
5908
5909 const SCEV *Accum = nullptr;
5910 if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
5911 Accum = getSCEV(BO->RHS);
5912 else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
5913 Accum = getSCEV(BO->LHS);
5914
5915 if (!Accum)
5916 return nullptr;
5917
5919 if (BO->IsNUW)
5920 Flags = setFlags(Flags, SCEV::FlagNUW);
5921 if (BO->IsNSW)
5922 Flags = setFlags(Flags, SCEV::FlagNSW);
5923
5924 const SCEV *StartVal = getSCEV(StartValueV);
5925 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
5926 insertValueToMap(PN, PHISCEV);
5927
5928 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
5929 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
5931 proveNoWrapViaConstantRanges(AR)));
5932 }
5933
5934 // We can add Flags to the post-inc expression only if we
5935 // know that it is *undefined behavior* for BEValueV to
5936 // overflow.
5937 if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
5938 assert(isLoopInvariant(Accum, L) &&
5939 "Accum is defined outside L, but is not invariant?");
5940 if (isAddRecNeverPoison(BEInst, L))
5941 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
5942 }
5943
5944 return PHISCEV;
5945}
5946
5947const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
5948 const Loop *L = LI.getLoopFor(PN->getParent());
5949 if (!L || L->getHeader() != PN->getParent())
5950 return nullptr;
5951
5952 // The loop may have multiple entrances or multiple exits; we can analyze
5953 // this phi as an addrec if it has a unique entry value and a unique
5954 // backedge value.
5955 Value *BEValueV = nullptr, *StartValueV = nullptr;
5956 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
5957 Value *V = PN->getIncomingValue(i);
5958 if (L->contains(PN->getIncomingBlock(i))) {
5959 if (!BEValueV) {
5960 BEValueV = V;
5961 } else if (BEValueV != V) {
5962 BEValueV = nullptr;
5963 break;
5964 }
5965 } else if (!StartValueV) {
5966 StartValueV = V;
5967 } else if (StartValueV != V) {
5968 StartValueV = nullptr;
5969 break;
5970 }
5971 }
5972 if (!BEValueV || !StartValueV)
5973 return nullptr;
5974
5975 assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
5976 "PHI node already processed?");
5977
5978 // First, try to find AddRec expression without creating a fictituos symbolic
5979 // value for PN.
5980 if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
5981 return S;
5982
5983 // Handle PHI node value symbolically.
5984 const SCEV *SymbolicName = getUnknown(PN);
5985 insertValueToMap(PN, SymbolicName);
5986
5987 // Using this symbolic name for the PHI, analyze the value coming around
5988 // the back-edge.
5989 const SCEV *BEValue = getSCEV(BEValueV);
5990
5991 // NOTE: If BEValue is loop invariant, we know that the PHI node just
5992 // has a special value for the first iteration of the loop.
5993
5994 // If the value coming around the backedge is an add with the symbolic
5995 // value we just inserted, then we found a simple induction variable!
5996 if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
5997 // If there is a single occurrence of the symbolic value, replace it
5998 // with a recurrence.
5999 unsigned FoundIndex = Add->getNumOperands();
6000 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6001 if (Add->getOperand(i) == SymbolicName)
6002 if (FoundIndex == e) {
6003 FoundIndex = i;
6004 break;
6005 }
6006
6007 if (FoundIndex != Add->getNumOperands()) {
6008 // Create an add with everything but the specified operand.
6010 for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
6011 if (i != FoundIndex)
6012 Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
6013 L, *this));
6014 const SCEV *Accum = getAddExpr(Ops);
6015
6016 // This is not a valid addrec if the step amount is varying each
6017 // loop iteration, but is not itself an addrec in this loop.
6018 if (isLoopInvariant(Accum, L) ||
6019 (isa<SCEVAddRecExpr>(Accum) &&
6020 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
6022
6023 if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
6024 if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
6025 if (BO->IsNUW)
6026 Flags = setFlags(Flags, SCEV::FlagNUW);
6027 if (BO->IsNSW)
6028 Flags = setFlags(Flags, SCEV::FlagNSW);
6029 }
6030 } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
6031 if (GEP->getOperand(0) == PN) {
6032 GEPNoWrapFlags NW = GEP->getNoWrapFlags();
6033 // If the increment has any nowrap flags, then we know the address
6034 // space cannot be wrapped around.
6035 if (NW != GEPNoWrapFlags::none())
6036 Flags = setFlags(Flags, SCEV::FlagNW);
6037 // If the GEP is nuw or nusw with non-negative offset, we know that
6038 // no unsigned wrap occurs. We cannot set the nsw flag as only the
6039 // offset is treated as signed, while the base is unsigned.
6040 if (NW.hasNoUnsignedWrap() ||
6042 Flags = setFlags(Flags, SCEV::FlagNUW);
6043 }
6044
6045 // We cannot transfer nuw and nsw flags from subtraction
6046 // operations -- sub nuw X, Y is not the same as add nuw X, -Y
6047 // for instance.
6048 }
6049
6050 const SCEV *StartVal = getSCEV(StartValueV);
6051 const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
6052
6053 // Okay, for the entire analysis of this edge we assumed the PHI
6054 // to be symbolic. We now need to go back and purge all of the
6055 // entries for the scalars that use the symbolic expression.
6056 forgetMemoizedResults({SymbolicName});
6057 insertValueToMap(PN, PHISCEV);
6058
6059 if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
6060 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
6062 proveNoWrapViaConstantRanges(AR)));
6063 }
6064
6065 // We can add Flags to the post-inc expression only if we
6066 // know that it is *undefined behavior* for BEValueV to
6067 // overflow.
6068 if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
6069 if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
6070 (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
6071
6072 return PHISCEV;
6073 }
6074 }
6075 } else {
6076 // Otherwise, this could be a loop like this:
6077 // i = 0; for (j = 1; ..; ++j) { .... i = j; }
6078 // In this case, j = {1,+,1} and BEValue is j.
6079 // Because the other in-value of i (0) fits the evolution of BEValue
6080 // i really is an addrec evolution.
6081 //
6082 // We can generalize this saying that i is the shifted value of BEValue
6083 // by one iteration:
6084 // PHI(f(0), f({1,+,1})) --> f({0,+,1})
6085
6086 // Do not allow refinement in rewriting of BEValue.
6087 const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
6088 const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
6089 if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
6090 isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
6091 const SCEV *StartVal = getSCEV(StartValueV);
6092 if (Start == StartVal) {
6093 // Okay, for the entire analysis of this edge we assumed the PHI
6094 // to be symbolic. We now need to go back and purge all of the
6095 // entries for the scalars that use the symbolic expression.
6096 forgetMemoizedResults({SymbolicName});
6097 insertValueToMap(PN, Shifted);
6098 return Shifted;
6099 }
6100 }
6101 }
6102
6103 // Remove the temporary PHI node SCEV that has been inserted while intending
6104 // to create an AddRecExpr for this PHI node. We can not keep this temporary
6105 // as it will prevent later (possibly simpler) SCEV expressions to be added
6106 // to the ValueExprMap.
6107 eraseValueFromMap(PN);
6108
6109 return nullptr;
6110}
6111
6112// Try to match a control flow sequence that branches out at BI and merges back
6113// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
6114// match.
6116 Value *&C, Value *&LHS, Value *&RHS) {
6117 C = BI->getCondition();
6118
6119 BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
6120 BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));
6121
6122 Use &LeftUse = Merge->getOperandUse(0);
6123 Use &RightUse = Merge->getOperandUse(1);
6124
6125 if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
6126 LHS = LeftUse;
6127 RHS = RightUse;
6128 return true;
6129 }
6130
6131 if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
6132 LHS = RightUse;
6133 RHS = LeftUse;
6134 return true;
6135 }
6136
6137 return false;
6138}
6139
6141 Value *&Cond, Value *&LHS,
6142 Value *&RHS) {
6143 auto IsReachable =
6144 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6145 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6146 // Try to match
6147 //
6148 // br %cond, label %left, label %right
6149 // left:
6150 // br label %merge
6151 // right:
6152 // br label %merge
6153 // merge:
6154 // V = phi [ %x, %left ], [ %y, %right ]
6155 //
6156 // as "select %cond, %x, %y"
6157
6158 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6159 assert(IDom && "At least the entry block should dominate PN");
6160
6161 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6162 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6163 }
6164 return false;
6165}
6166
6167const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
6168 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
6169 if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
6172 return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);
6173
6174 return nullptr;
6175}
6176
6178 BinaryOperator *CommonInst = nullptr;
6179 // Check if instructions are identical.
6180 for (Value *Incoming : PN->incoming_values()) {
6181 auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
6182 if (!IncomingInst)
6183 return nullptr;
6184 if (CommonInst) {
6185 if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
6186 return nullptr; // Not identical, give up
6187 } else {
6188 // Remember binary operator
6189 CommonInst = IncomingInst;
6190 }
6191 }
6192 return CommonInst;
6193}
6194
6195/// Returns SCEV for the first operand of a phi if all phi operands have
6196/// identical opcodes and operands
6197/// eg.
6198/// a: %add = %a + %b
6199/// br %c
6200/// b: %add1 = %a + %b
6201/// br %c
6202/// c: %phi = phi [%add, a], [%add1, b]
6203/// scev(%phi) => scev(%add)
6204const SCEV *
6205ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
6206 BinaryOperator *CommonInst = getCommonInstForPHI(PN);
6207 if (!CommonInst)
6208 return nullptr;
6209
6210 // Check if SCEV exprs for instructions are identical.
6211 const SCEV *CommonSCEV = getSCEV(CommonInst);
6212 bool SCEVExprsIdentical =
6214 [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
6215 return SCEVExprsIdentical ? CommonSCEV : nullptr;
6216}
6217
6218const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6219 if (const SCEV *S = createAddRecFromPHI(PN))
6220 return S;
6221
6222 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6223 // phi node for X.
6224 if (Value *V = simplifyInstruction(
6225 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6226 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6227 return getSCEV(V);
6228
6229 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6230 return S;
6231
6232 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6233 return S;
6234
6235 // If it's not a loop phi, we can't handle it yet.
6236 return getUnknown(PN);
6237}
6238
6239bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
6240 SCEVTypes RootKind) {
6241 struct FindClosure {
6242 const SCEV *OperandToFind;
6243 const SCEVTypes RootKind; // Must be a sequential min/max expression.
6244 const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
6245
6246 bool Found = false;
6247
6248 bool canRecurseInto(SCEVTypes Kind) const {
6249 // We can only recurse into the SCEV expression of the same effective type
6250 // as the type of our root SCEV expression, and into zero-extensions.
6251 return RootKind == Kind || NonSequentialRootKind == Kind ||
6252 scZeroExtend == Kind;
6253 };
6254
6255 FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
6256 : OperandToFind(OperandToFind), RootKind(RootKind),
6257 NonSequentialRootKind(
6259 RootKind)) {}
6260
6261 bool follow(const SCEV *S) {
6262 Found = S == OperandToFind;
6263
6264 return !isDone() && canRecurseInto(S->getSCEVType());
6265 }
6266
6267 bool isDone() const { return Found; }
6268 };
6269
6270 FindClosure FC(OperandToFind, RootKind);
6271 visitAll(Root, FC);
6272 return FC.Found;
6273}
6274
6275std::optional<const SCEV *>
6276ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
6277 ICmpInst *Cond,
6278 Value *TrueVal,
6279 Value *FalseVal) {
6280 // Try to match some simple smax or umax patterns.
6281 auto *ICI = Cond;
6282
6283 Value *LHS = ICI->getOperand(0);
6284 Value *RHS = ICI->getOperand(1);
6285
6286 switch (ICI->getPredicate()) {
6287 case ICmpInst::ICMP_SLT:
6288 case ICmpInst::ICMP_SLE:
6289 case ICmpInst::ICMP_ULT:
6290 case ICmpInst::ICMP_ULE:
6291 std::swap(LHS, RHS);
6292 [[fallthrough]];
6293 case ICmpInst::ICMP_SGT:
6294 case ICmpInst::ICMP_SGE:
6295 case ICmpInst::ICMP_UGT:
6296 case ICmpInst::ICMP_UGE:
6297 // a > b ? a+x : b+x -> max(a, b)+x
6298 // a > b ? b+x : a+x -> min(a, b)+x
6300 bool Signed = ICI->isSigned();
6301 const SCEV *LA = getSCEV(TrueVal);
6302 const SCEV *RA = getSCEV(FalseVal);
6303 const SCEV *LS = getSCEV(LHS);
6304 const SCEV *RS = getSCEV(RHS);
6305 if (LA->getType()->isPointerTy()) {
6306 // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
6307 // Need to make sure we can't produce weird expressions involving
6308 // negated pointers.
6309 if (LA == LS && RA == RS)
6310 return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
6311 if (LA == RS && RA == LS)
6312 return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
6313 }
6314 auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
6315 if (Op->getType()->isPointerTy()) {
6318 return Op;
6319 }
6320 if (Signed)
6321 Op = getNoopOrSignExtend(Op, Ty);
6322 else
6323 Op = getNoopOrZeroExtend(Op, Ty);
6324 return Op;
6325 };
6326 LS = CoerceOperand(LS);
6327 RS = CoerceOperand(RS);
6329 break;
6330 const SCEV *LDiff = getMinusSCEV(LA, LS);
6331 const SCEV *RDiff = getMinusSCEV(RA, RS);
6332 if (LDiff == RDiff)
6333 return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
6334 LDiff);
6335 LDiff = getMinusSCEV(LA, RS);
6336 RDiff = getMinusSCEV(RA, LS);
6337 if (LDiff == RDiff)
6338 return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
6339 LDiff);
6340 }
6341 break;
6342 case ICmpInst::ICMP_NE:
6343 // x != 0 ? x+y : C+y -> x == 0 ? C+y : x+y
6344 std::swap(TrueVal, FalseVal);
6345 [[fallthrough]];
6346 case ICmpInst::ICMP_EQ:
6347 // x == 0 ? C+y : x+y -> umax(x, C)+y iff C u<= 1
6350 const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
6351 const SCEV *TrueValExpr = getSCEV(TrueVal); // C+y
6352 const SCEV *FalseValExpr = getSCEV(FalseVal); // x+y
6353 const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
6354 const SCEV *C = getMinusSCEV(TrueValExpr, Y); // C = (C+y)-y
6355 if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
6356 return getAddExpr(getUMaxExpr(X, C), Y);
6357 }
6358 // x == 0 ? 0 : umin (..., x, ...) -> umin_seq(x, umin (...))
6359 // x == 0 ? 0 : umin_seq(..., x, ...) -> umin_seq(x, umin_seq(...))
6360 // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
6361 // -> umin_seq(x, umin (..., umin_seq(...), ...))
6363 isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
6364 const SCEV *X = getSCEV(LHS);
6365 while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
6366 X = ZExt->getOperand();
6367 if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
6368 const SCEV *FalseValExpr = getSCEV(FalseVal);
6369 if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
6370 return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
6371 /*Sequential=*/true);
6372 }
6373 }
6374 break;
6375 default:
6376 break;
6377 }
6378
6379 return std::nullopt;
6380}
6381
6382static std::optional<const SCEV *>
6384 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6385 assert(CondExpr->getType()->isIntegerTy(1) &&
6386 TrueExpr->getType() == FalseExpr->getType() &&
6387 TrueExpr->getType()->isIntegerTy(1) &&
6388 "Unexpected operands of a select.");
6389
6390 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6391 // --> C + (umin_seq cond, x - C)
6392 //
6393 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6394 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6395 // --> C + (umin_seq ~cond, x - C)
6396
6397 // FIXME: while we can't legally model the case where both of the hands
6398 // are fully variable, we only require that the *difference* is constant.
6399 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6400 return std::nullopt;
6401
6402 const SCEV *X, *C;
6403 if (isa<SCEVConstant>(TrueExpr)) {
6404 CondExpr = SE->getNotSCEV(CondExpr);
6405 X = FalseExpr;
6406 C = TrueExpr;
6407 } else {
6408 X = TrueExpr;
6409 C = FalseExpr;
6410 }
6411 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6412 /*Sequential=*/true));
6413}
6414
6415static std::optional<const SCEV *>
6417 Value *FalseVal) {
6418 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6419 return std::nullopt;
6420
6421 const auto *SECond = SE->getSCEV(Cond);
6422 const auto *SETrue = SE->getSCEV(TrueVal);
6423 const auto *SEFalse = SE->getSCEV(FalseVal);
6424 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6425}
6426
6427const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6428 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6429 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6430 assert(TrueVal->getType() == FalseVal->getType() &&
6431 V->getType() == TrueVal->getType() &&
6432 "Types of select hands and of the result must match.");
6433
6434 // For now, only deal with i1-typed `select`s.
6435 if (!V->getType()->isIntegerTy(1))
6436 return getUnknown(V);
6437
6438 if (std::optional<const SCEV *> S =
6439 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6440 return *S;
6441
6442 return getUnknown(V);
6443}
6444
6445const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6446 Value *TrueVal,
6447 Value *FalseVal) {
6448 // Handle "constant" branch or select. This can occur for instance when a
6449 // loop pass transforms an inner loop and moves on to process the outer loop.
6450 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6451 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6452
6453 if (auto *I = dyn_cast<Instruction>(V)) {
6454 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6455 if (std::optional<const SCEV *> S =
6456 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6457 TrueVal, FalseVal))
6458 return *S;
6459 }
6460 }
6461
6462 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6463}
6464
6465/// Expand GEP instructions into add and multiply operations. This allows them
6466/// to be analyzed by regular SCEV code.
6467const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6468 assert(GEP->getSourceElementType()->isSized() &&
6469 "GEP source element type must be sized");
6470
6471 SmallVector<SCEVUse, 4> IndexExprs;
6472 for (Value *Index : GEP->indices())
6473 IndexExprs.push_back(getSCEV(Index));
6474 return getGEPExpr(GEP, IndexExprs);
6475}
6476
6477APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6478 const Instruction *CtxI) {
6479 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6480 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6481 return TrailingZeros >= BitWidth
6483 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6484 };
6485 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6486 // The result is GCD of all operands results.
6487 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6488 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6490 Res, getConstantMultiple(N->getOperand(I), CtxI));
6491 return Res;
6492 };
6493
6494 switch (S->getSCEVType()) {
6495 case scConstant:
6496 return cast<SCEVConstant>(S)->getAPInt();
6497 case scPtrToAddr:
6498 case scPtrToInt:
6499 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6500 case scUDivExpr:
6501 case scVScale:
6502 return APInt(BitWidth, 1);
6503 case scTruncate: {
6504 // Only multiples that are a power of 2 will hold after truncation.
6505 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6506 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6507 return GetShiftedByZeros(TZ);
6508 }
6509 case scZeroExtend: {
6510 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6511 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6512 }
6513 case scSignExtend: {
6514 // Only multiples that are a power of 2 will hold after sext.
6515 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6516 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6517 return GetShiftedByZeros(TZ);
6518 }
6519 case scMulExpr: {
6520 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6521 if (M->hasNoUnsignedWrap()) {
6522 // The result is the product of all operand results.
6523 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6524 for (const SCEV *Operand : M->operands().drop_front())
6525 Res = Res * getConstantMultiple(Operand, CtxI);
6526 return Res;
6527 }
6528
6529 // If there are no wrap guarentees, find the trailing zeros, which is the
6530 // sum of trailing zeros for all its operands.
6531 uint32_t TZ = 0;
6532 for (const SCEV *Operand : M->operands())
6533 TZ += getMinTrailingZeros(Operand, CtxI);
6534 return GetShiftedByZeros(TZ);
6535 }
6536 case scAddExpr:
6537 case scAddRecExpr: {
6538 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6539 if (N->hasNoUnsignedWrap())
6540 return GetGCDMultiple(N);
6541 // Find the trailing bits, which is the minimum of its operands.
6542 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6543 for (const SCEV *Operand : N->operands().drop_front())
6544 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6545 return GetShiftedByZeros(TZ);
6546 }
6547 case scUMaxExpr:
6548 case scSMaxExpr:
6549 case scUMinExpr:
6550 case scSMinExpr:
6552 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6553 case scUnknown: {
6554 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6555 // the point their underlying IR instruction has been defined. If CtxI was
6556 // not provided, use:
6557 // * the first instruction in the entry block if it is an argument
6558 // * the instruction itself otherwise.
6559 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6560 if (!CtxI) {
6561 if (isa<Argument>(U->getValue()))
6562 CtxI = &*F.getEntryBlock().begin();
6563 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6564 CtxI = I;
6565 }
6566 unsigned Known =
6567 computeKnownBits(U->getValue(), getDataLayout(), &AC, CtxI, &DT)
6568 .countMinTrailingZeros();
6569 return GetShiftedByZeros(Known);
6570 }
6571 case scCouldNotCompute:
6572 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6573 }
6574 llvm_unreachable("Unknown SCEV kind!");
6575}
6576
6578 const Instruction *CtxI) {
6579 // Skip looking up and updating the cache if there is a context instruction,
6580 // as the result will only be valid in the specified context.
6581 if (CtxI)
6582 return getConstantMultipleImpl(S, CtxI);
6583
6584 auto I = ConstantMultipleCache.find(S);
6585 if (I != ConstantMultipleCache.end())
6586 return I->second;
6587
6588 APInt Result = getConstantMultipleImpl(S, CtxI);
6589 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6590 assert(InsertPair.second && "Should insert a new key");
6591 return InsertPair.first->second;
6592}
6593
6595 APInt Multiple = getConstantMultiple(S);
6596 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6597}
6598
6600 const Instruction *CtxI) {
6601 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6602 (unsigned)getTypeSizeInBits(S->getType()));
6603}
6604
6605/// Helper method to assign a range to V from metadata present in the IR.
6606static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6608 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6609 return getConstantRangeFromMetadata(*MD);
6610 if (const auto *CB = dyn_cast<CallBase>(V))
6611 if (std::optional<ConstantRange> Range = CB->getRange())
6612 return Range;
6613 }
6614 if (auto *A = dyn_cast<Argument>(V))
6615 if (std::optional<ConstantRange> Range = A->getRange())
6616 return Range;
6617
6618 return std::nullopt;
6619}
6620
6622 SCEV::NoWrapFlags Flags) {
6623 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6624 AddRec->setNoWrapFlags(Flags);
6625 UnsignedRanges.erase(AddRec);
6626 SignedRanges.erase(AddRec);
6627 ConstantMultipleCache.erase(AddRec);
6628 }
6629}
6630
/// Compute a conservative range for a SCEVUnknown whose underlying IR value
/// is a PHI forming a shift recurrence (shl/lshr/ashr), by bounding the
/// cumulative shift with the loop's small constant max trip count. Returns
/// the full set whenever no sound bound can be derived.
ConstantRange ScalarEvolution::
getRangeForUnknownRecurrence(const SCEVUnknown *U) {
  const DataLayout &DL = getDataLayout();

  unsigned BitWidth = getTypeSizeInBits(U->getType());
  const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);

  // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
  // use information about the trip count to improve our available range. Note
  // that the trip count independent cases are already handled by known bits.
  // WARNING: The definition of recurrence used here is subtly different than
  // the one used by AddRec (and thus most of this file). Step is allowed to
  // be arbitrarily loop varying here, where AddRec allows only loop invariant
  // and other addrecs in the same loop (for non-affine addrecs). The code
  // below intentionally handles the case where step is not loop invariant.
  auto *P = dyn_cast<PHINode>(U->getValue());
  if (!P)
    return FullSet;

  // Make sure that no Phi input comes from an unreachable block. Otherwise,
  // even the values that are not available in these blocks may come from them,
  // and this leads to false-positive recurrence test.
  for (auto *Pred : predecessors(P->getParent()))
    if (!DT.isReachableFromEntry(Pred))
      return FullSet;

  BinaryOperator *BO;
  Value *Start, *Step;
  if (!matchSimpleRecurrence(P, BO, Start, Step))
    return FullSet;

  // If we found a recurrence in reachable code, we must be in a loop. Note
  // that BO might be in some subloop of L, and that's completely okay.
  auto *L = LI.getLoopFor(P->getParent());
  assert(L && L->getHeader() == P->getParent());
  if (!L->contains(BO->getParent()))
    // NOTE: This bailout should be an assert instead. However, asserting
    // the condition here exposes a case where LoopFusion is querying SCEV
    // with malformed loop information during the midst of the transform.
    // There doesn't appear to be an obvious fix, so for the moment bailout
    // until the caller issue can be fixed. PR49566 tracks the bug.
    return FullSet;

  // TODO: Extend to other opcodes such as mul, and div
  switch (BO->getOpcode()) {
  default:
    return FullSet;
  case Instruction::AShr:
  case Instruction::LShr:
  case Instruction::Shl:
    break;
  };

  if (BO->getOperand(0) != P)
    // TODO: Handle the power function forms some day.
    return FullSet;

  // A max trip count of zero means "unknown"; a count >= BitWidth means a
  // single full-width shift may already saturate, so give up.
  unsigned TC = getSmallConstantMaxTripCount(L);
  if (!TC || TC >= BitWidth)
    return FullSet;

  auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
  auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
  assert(KnownStart.getBitWidth() == BitWidth &&
         KnownStep.getBitWidth() == BitWidth);

  // Compute total shift amount, being careful of overflow and bitwidths.
  auto MaxShiftAmt = KnownStep.getMaxValue();
  // The recurrence applies the shift at most TC-1 times.
  APInt TCAP(BitWidth, TC-1);
  bool Overflow = false;
  auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
  if (Overflow)
    return FullSet;

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("filtered out above");
  case Instruction::AShr: {
    // For each ashr, three cases:
    //   shift = 0   => unchanged value
    //   saturation  => 0 or -1
    //   other       => a value closer to zero (of the same sign)
    // Thus, the end value is closer to zero than the start.
    auto KnownEnd = KnownBits::ashr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    if (KnownStart.isNonNegative())
      // Analogous to lshr (simply not yet canonicalized)
      return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                        KnownStart.getMaxValue() + 1);
    if (KnownStart.isNegative())
      // End >=u Start && End <=s Start
      return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
                                        KnownEnd.getMaxValue() + 1);
    break;
  }
  case Instruction::LShr: {
    // For each lshr, three cases:
    //   shift = 0   => unchanged value
    //   saturation  => 0
    //   other       => a smaller positive number
    // Thus, the low end of the unsigned range is the last value produced.
    auto KnownEnd = KnownBits::lshr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                      KnownStart.getMaxValue() + 1);
  }
  case Instruction::Shl: {
    // Iff no bits are shifted out, value increases on every shift.
    auto KnownEnd = KnownBits::shl(KnownStart,
                                   KnownBits::makeConstant(TotalShift));
    if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
      return ConstantRange(KnownStart.getMinValue(),
                           KnownEnd.getMaxValue() + 1);
    break;
  }
  };
  return FullSet;
}
6749
// The goal of this function is to check if recursively visiting the operands
// of this PHI might lead to an infinite loop. If we do see such a loop,
// there's no good way to break it, so we avoid analyzing such cases.
//
// getRangeRef previously used a visited set to avoid infinite loops, but this
// caused other issues: the result was dependent on the order of getRangeRef
// calls, and the interaction with createSCEVIter could cause a stack overflow
// in some cases (see issue #148253).
//
// FIXME: The way this is implemented is overly conservative; this checks
// for a few obviously safe patterns, but anything that doesn't lead to
// recursion is fine.
// NOTE(review): this block appears truncated by extraction: the function
// header (taking `DominatorTree &DT` and `PHINode *PHI`, as the callers
// `RangeRefPHIAllowedOperands(DT, P)` below show) and the condition that
// populates Cond/LHS/RHS and guards the first `return true` are missing.
// Restore them from the upstream sources before building.
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
    return true;

  // Safe when every PHI operand's definition dominates the PHI itself, i.e.
  // no operand can (transitively) depend on the PHI.
  if (all_of(PHI->operands(),
             [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
    return true;

  return false;
}
6773
6774const ConstantRange &
6775ScalarEvolution::getRangeRefIter(const SCEV *S,
6776 ScalarEvolution::RangeSignHint SignHint) {
6777 DenseMap<const SCEV *, ConstantRange> &Cache =
6778 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6779 : SignedRanges;
6780 SmallVector<SCEVUse> WorkList;
6781 SmallPtrSet<const SCEV *, 8> Seen;
6782
6783 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6784 // SCEVUnknown PHI node.
6785 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6786 if (!Seen.insert(Expr).second)
6787 return;
6788 if (Cache.contains(Expr))
6789 return;
6790 switch (Expr->getSCEVType()) {
6791 case scUnknown:
6792 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6793 break;
6794 [[fallthrough]];
6795 case scConstant:
6796 case scVScale:
6797 case scTruncate:
6798 case scZeroExtend:
6799 case scSignExtend:
6800 case scPtrToAddr:
6801 case scPtrToInt:
6802 case scAddExpr:
6803 case scMulExpr:
6804 case scUDivExpr:
6805 case scAddRecExpr:
6806 case scUMaxExpr:
6807 case scSMaxExpr:
6808 case scUMinExpr:
6809 case scSMinExpr:
6811 WorkList.push_back(Expr);
6812 break;
6813 case scCouldNotCompute:
6814 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6815 }
6816 };
6817 AddToWorklist(S);
6818
6819 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6820 for (unsigned I = 0; I != WorkList.size(); ++I) {
6821 const SCEV *P = WorkList[I];
6822 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6823 // If it is not a `SCEVUnknown`, just recurse into operands.
6824 if (!UnknownS) {
6825 for (const SCEV *Op : P->operands())
6826 AddToWorklist(Op);
6827 continue;
6828 }
6829 // `SCEVUnknown`'s require special treatment.
6830 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6831 if (!RangeRefPHIAllowedOperands(DT, P))
6832 continue;
6833 for (auto &Op : reverse(P->operands()))
6834 AddToWorklist(getSCEV(Op));
6835 }
6836 }
6837
6838 if (!WorkList.empty()) {
6839 // Use getRangeRef to compute ranges for items in the worklist in reverse
6840 // order. This will force ranges for earlier operands to be computed before
6841 // their users in most cases.
6842 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6843 getRangeRef(P, SignHint);
6844 }
6845 }
6846
6847 return getRangeRef(S, SignHint, 0);
6848}
6849
6850/// Determine the range for a particular SCEV. If SignHint is
6851/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6852/// with a "cleaner" unsigned (resp. signed) representation.
6853const ConstantRange &ScalarEvolution::getRangeRef(
6854 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6855 DenseMap<const SCEV *, ConstantRange> &Cache =
6856 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6857 : SignedRanges;
6859 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6861
6862 // See if we've computed this range already.
6864 if (I != Cache.end())
6865 return I->second;
6866
6867 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6868 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6869
6870 // Switch to iteratively computing the range for S, if it is part of a deeply
6871 // nested expression.
6873 return getRangeRefIter(S, SignHint);
6874
6875 unsigned BitWidth = getTypeSizeInBits(S->getType());
6876 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6877 using OBO = OverflowingBinaryOperator;
6878
6879 // If the value has known zeros, the maximum value will have those known zeros
6880 // as well.
6881 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6882 APInt Multiple = getNonZeroConstantMultiple(S);
6883 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6884 if (!Remainder.isZero())
6885 ConservativeResult =
6886 ConstantRange(APInt::getMinValue(BitWidth),
6887 APInt::getMaxValue(BitWidth) - Remainder + 1);
6888 }
6889 else {
6890 uint32_t TZ = getMinTrailingZeros(S);
6891 if (TZ != 0) {
6892 ConservativeResult = ConstantRange(
6894 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6895 }
6896 }
6897
6898 switch (S->getSCEVType()) {
6899 case scConstant:
6900 llvm_unreachable("Already handled above.");
6901 case scVScale:
6902 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6903 case scTruncate: {
6904 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6905 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6906 return setRange(
6907 Trunc, SignHint,
6908 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6909 }
6910 case scZeroExtend: {
6911 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6912 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6913 return setRange(
6914 ZExt, SignHint,
6915 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6916 }
6917 case scSignExtend: {
6918 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6919 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6920 return setRange(
6921 SExt, SignHint,
6922 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6923 }
6924 case scPtrToAddr:
6925 case scPtrToInt: {
6926 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6927 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6928 return setRange(Cast, SignHint, X);
6929 }
6930 case scAddExpr: {
6931 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6932 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6933 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6934 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6935 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6936 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6937 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6938 ConservativeResult =
6939 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6940 }
6941 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6942 unsigned WrapType = OBO::AnyWrap;
6943 if (Add->hasNoSignedWrap())
6944 WrapType |= OBO::NoSignedWrap;
6945 if (Add->hasNoUnsignedWrap())
6946 WrapType |= OBO::NoUnsignedWrap;
6947 for (const SCEV *Op : drop_begin(Add->operands()))
6948 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6949 RangeType);
6950 return setRange(Add, SignHint,
6951 ConservativeResult.intersectWith(X, RangeType));
6952 }
6953 case scMulExpr: {
6954 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6955 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6956 for (const SCEV *Op : drop_begin(Mul->operands()))
6957 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6958 return setRange(Mul, SignHint,
6959 ConservativeResult.intersectWith(X, RangeType));
6960 }
6961 case scUDivExpr: {
6962 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6963 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6964 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6965 return setRange(UDiv, SignHint,
6966 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6967 }
6968 case scAddRecExpr: {
6969 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6970 // If there's no unsigned wrap, the value will never be less than its
6971 // initial value.
6972 if (AddRec->hasNoUnsignedWrap()) {
6973 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6974 if (!UnsignedMinValue.isZero())
6975 ConservativeResult = ConservativeResult.intersectWith(
6976 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6977 }
6978
6979 // If there's no signed wrap, and all the operands except initial value have
6980 // the same sign or zero, the value won't ever be:
6981 // 1: smaller than initial value if operands are non negative,
6982 // 2: bigger than initial value if operands are non positive.
6983 // For both cases, value can not cross signed min/max boundary.
6984 if (AddRec->hasNoSignedWrap()) {
6985 bool AllNonNeg = true;
6986 bool AllNonPos = true;
6987 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6988 if (!isKnownNonNegative(AddRec->getOperand(i)))
6989 AllNonNeg = false;
6990 if (!isKnownNonPositive(AddRec->getOperand(i)))
6991 AllNonPos = false;
6992 }
6993 if (AllNonNeg)
6994 ConservativeResult = ConservativeResult.intersectWith(
6997 RangeType);
6998 else if (AllNonPos)
6999 ConservativeResult = ConservativeResult.intersectWith(
7001 getSignedRangeMax(AddRec->getStart()) +
7002 1),
7003 RangeType);
7004 }
7005
7006 // TODO: non-affine addrec
7007 if (AddRec->isAffine()) {
7008 const SCEV *MaxBEScev =
7010 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
7011 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
7012
7013 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
7014 // MaxBECount's active bits are all <= AddRec's bit width.
7015 if (MaxBECount.getBitWidth() > BitWidth &&
7016 MaxBECount.getActiveBits() <= BitWidth)
7017 MaxBECount = MaxBECount.trunc(BitWidth);
7018 else if (MaxBECount.getBitWidth() < BitWidth)
7019 MaxBECount = MaxBECount.zext(BitWidth);
7020
7021 if (MaxBECount.getBitWidth() == BitWidth) {
7022 auto RangeFromAffine = getRangeForAffineAR(
7023 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7024 ConservativeResult =
7025 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7026
7027 auto RangeFromFactoring = getRangeViaFactoring(
7028 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7029 ConservativeResult =
7030 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7031 }
7032 }
7033
7034 // Now try symbolic BE count and more powerful methods.
7036 const SCEV *SymbolicMaxBECount =
7038 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7039 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7040 AddRec->hasNoSelfWrap()) {
7041 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7042 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7043 ConservativeResult =
7044 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7045 }
7046 }
7047 }
7048
7049 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7050 }
7051 case scUMaxExpr:
7052 case scSMaxExpr:
7053 case scUMinExpr:
7054 case scSMinExpr:
7055 case scSequentialUMinExpr: {
7057 switch (S->getSCEVType()) {
7058 case scUMaxExpr:
7059 ID = Intrinsic::umax;
7060 break;
7061 case scSMaxExpr:
7062 ID = Intrinsic::smax;
7063 break;
7064 case scUMinExpr:
7066 ID = Intrinsic::umin;
7067 break;
7068 case scSMinExpr:
7069 ID = Intrinsic::smin;
7070 break;
7071 default:
7072 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7073 }
7074
7075 const auto *NAry = cast<SCEVNAryExpr>(S);
7076 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7077 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7078 X = X.intrinsic(
7079 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7080 return setRange(S, SignHint,
7081 ConservativeResult.intersectWith(X, RangeType));
7082 }
7083 case scUnknown: {
7084 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7085 Value *V = U->getValue();
7086
7087 // Check if the IR explicitly contains !range metadata.
7088 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7089 if (MDRange)
7090 ConservativeResult =
7091 ConservativeResult.intersectWith(*MDRange, RangeType);
7092
7093 // Use facts about recurrences in the underlying IR. Note that add
7094 // recurrences are AddRecExprs and thus don't hit this path. This
7095 // primarily handles shift recurrences.
7096 auto CR = getRangeForUnknownRecurrence(U);
7097 ConservativeResult = ConservativeResult.intersectWith(CR);
7098
7099 // See if ValueTracking can give us a useful range.
7100 const DataLayout &DL = getDataLayout();
7101 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7102 if (Known.getBitWidth() != BitWidth)
7103 Known = Known.zextOrTrunc(BitWidth);
7104
7105 // ValueTracking may be able to compute a tighter result for the number of
7106 // sign bits than for the value of those sign bits.
7107 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7108 if (U->getType()->isPointerTy()) {
7109 // If the pointer size is larger than the index size type, this can cause
7110 // NS to be larger than BitWidth. So compensate for this.
7111 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7112 int ptrIdxDiff = ptrSize - BitWidth;
7113 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7114 NS -= ptrIdxDiff;
7115 }
7116
7117 if (NS > 1) {
7118 // If we know any of the sign bits, we know all of the sign bits.
7119 if (!Known.Zero.getHiBits(NS).isZero())
7120 Known.Zero.setHighBits(NS);
7121 if (!Known.One.getHiBits(NS).isZero())
7122 Known.One.setHighBits(NS);
7123 }
7124
7125 if (Known.getMinValue() != Known.getMaxValue() + 1)
7126 ConservativeResult = ConservativeResult.intersectWith(
7127 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7128 RangeType);
7129 if (NS > 1)
7130 ConservativeResult = ConservativeResult.intersectWith(
7131 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7132 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7133 RangeType);
7134
7135 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7136 // Strengthen the range if the underlying IR value is a
7137 // global/alloca/heap allocation using the size of the object.
7138 bool CanBeNull, CanBeFreed;
7139 uint64_t DerefBytes =
7140 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7141 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7142 // The highest address the object can start is DerefBytes bytes before
7143 // the end (unsigned max value). If this value is not a multiple of the
7144 // alignment, the last possible start value is the next lowest multiple
7145 // of the alignment. Note: The computations below cannot overflow,
7146 // because if they would there's no possible start address for the
7147 // object.
7148 APInt MaxVal =
7149 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7150 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7151 uint64_t Rem = MaxVal.urem(Align);
7152 MaxVal -= APInt(BitWidth, Rem);
7153 APInt MinVal = APInt::getZero(BitWidth);
7154 if (llvm::isKnownNonZero(V, DL))
7155 MinVal = Align;
7156 ConservativeResult = ConservativeResult.intersectWith(
7157 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7158 }
7159 }
7160
7161 // A range of Phi is a subset of union of all ranges of its input.
7162 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7163 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7164 // AddRecs; return the range for the corresponding AddRec.
7165 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7166 return getRangeRef(AR, SignHint, Depth + 1);
7167
7168 // Make sure that we do not run over cycled Phis.
7169 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7170 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7171
7172 for (const auto &Op : Phi->operands()) {
7173 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7174 RangeFromOps = RangeFromOps.unionWith(OpRange);
7175 // No point to continue if we already have a full set.
7176 if (RangeFromOps.isFullSet())
7177 break;
7178 }
7179 ConservativeResult =
7180 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7181 }
7182 }
7183
7184 // vscale can't be equal to zero
7185 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7186 if (II->getIntrinsicID() == Intrinsic::vscale) {
7187 ConstantRange Disallowed = APInt::getZero(BitWidth);
7188 ConservativeResult = ConservativeResult.difference(Disallowed);
7189 }
7190
7191 return setRange(U, SignHint, std::move(ConservativeResult));
7192 }
7193 case scCouldNotCompute:
7194 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7195 }
7196
7197 return setRange(S, SignHint, std::move(ConservativeResult));
7198}
7199
7200// Given a StartRange, Step and MaxBECount for an expression compute a range of
7201// values that the expression can take. Initially, the expression has a value
7202// from StartRange and then is changed by Step up to MaxBECount times. Signed
7203// argument defines if we treat Step as signed or unsigned.
7205 const ConstantRange &StartRange,
7206 const APInt &MaxBECount,
7207 bool Signed) {
7208 unsigned BitWidth = Step.getBitWidth();
7209 assert(BitWidth == StartRange.getBitWidth() &&
7210 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7211 // If either Step or MaxBECount is 0, then the expression won't change, and we
7212 // just need to return the initial range.
7213 if (Step == 0 || MaxBECount == 0)
7214 return StartRange;
7215
7216 // If we don't know anything about the initial value (i.e. StartRange is
7217 // FullRange), then we don't know anything about the final range either.
7218 // Return FullRange.
7219 if (StartRange.isFullSet())
7220 return ConstantRange::getFull(BitWidth);
7221
7222 // If Step is signed and negative, then we use its absolute value, but we also
7223 // note that we're moving in the opposite direction.
7224 bool Descending = Signed && Step.isNegative();
7225
7226 if (Signed)
7227 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7228 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7229 // This equations hold true due to the well-defined wrap-around behavior of
7230 // APInt.
7231 Step = Step.abs();
7232
7233 // Check if Offset is more than full span of BitWidth. If it is, the
7234 // expression is guaranteed to overflow.
7235 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7236 return ConstantRange::getFull(BitWidth);
7237
7238 // Offset is by how much the expression can change. Checks above guarantee no
7239 // overflow here.
7240 APInt Offset = Step * MaxBECount;
7241
7242 // Minimum value of the final range will match the minimal value of StartRange
7243 // if the expression is increasing and will be decreased by Offset otherwise.
7244 // Maximum value of the final range will match the maximal value of StartRange
7245 // if the expression is decreasing and will be increased by Offset otherwise.
7246 APInt StartLower = StartRange.getLower();
7247 APInt StartUpper = StartRange.getUpper() - 1;
7248 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7249 : (StartUpper + std::move(Offset));
7250
7251 // It's possible that the new minimum/maximum value will fall into the initial
7252 // range (due to wrap around). This means that the expression can take any
7253 // value in this bitwidth, and we have to return full range.
7254 if (StartRange.contains(MovedBoundary))
7255 return ConstantRange::getFull(BitWidth);
7256
7257 APInt NewLower =
7258 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7259 APInt NewUpper =
7260 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7261 NewUpper += 1;
7262
7263 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7264 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7265}
7266
7267ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7268 const SCEV *Step,
7269 const APInt &MaxBECount) {
7270 assert(getTypeSizeInBits(Start->getType()) ==
7271 getTypeSizeInBits(Step->getType()) &&
7272 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7273 "mismatched bit widths");
7274
7275 // First, consider step signed.
7276 ConstantRange StartSRange = getSignedRange(Start);
7277 ConstantRange StepSRange = getSignedRange(Step);
7278
7279 // If Step can be both positive and negative, we need to find ranges for the
7280 // maximum absolute step values in both directions and union them.
7281 ConstantRange SR = getRangeForAffineARHelper(
7282 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7284 StartSRange, MaxBECount,
7285 /* Signed = */ true));
7286
7287 // Next, consider step unsigned.
7288 ConstantRange UR = getRangeForAffineARHelper(
7289 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7290 /* Signed = */ false);
7291
7292 // Finally, intersect signed and unsigned ranges.
7294}
7295
7296ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7297 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7298 ScalarEvolution::RangeSignHint SignHint) {
7299 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7300 assert(AddRec->hasNoSelfWrap() &&
7301 "This only works for non-self-wrapping AddRecs!");
7302 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7303 const SCEV *Step = AddRec->getStepRecurrence(*this);
7304 // Only deal with constant step to save compile time.
7305 if (!isa<SCEVConstant>(Step))
7306 return ConstantRange::getFull(BitWidth);
7307 // Let's make sure that we can prove that we do not self-wrap during
7308 // MaxBECount iterations. We need this because MaxBECount is a maximum
7309 // iteration count estimate, and we might infer nw from some exit for which we
7310 // do not know max exit count (or any other side reasoning).
7311 // TODO: Turn into assert at some point.
7312 if (getTypeSizeInBits(MaxBECount->getType()) >
7313 getTypeSizeInBits(AddRec->getType()))
7314 return ConstantRange::getFull(BitWidth);
7315 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7316 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7317 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7318 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7319 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7320 MaxItersWithoutWrap))
7321 return ConstantRange::getFull(BitWidth);
7322
7323 ICmpInst::Predicate LEPred =
7325 ICmpInst::Predicate GEPred =
7327 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7328
7329 // We know that there is no self-wrap. Let's take Start and End values and
7330 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7331 // the iteration. They either lie inside the range [Min(Start, End),
7332 // Max(Start, End)] or outside it:
7333 //
7334 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7335 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7336 //
7337 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7338 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7339 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7340 // Start <= End and step is positive, or Start >= End and step is negative.
7341 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7342 ConstantRange StartRange = getRangeRef(Start, SignHint);
7343 ConstantRange EndRange = getRangeRef(End, SignHint);
7344 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7345 // If they already cover full iteration space, we will know nothing useful
7346 // even if we prove what we want to prove.
7347 if (RangeBetween.isFullSet())
7348 return RangeBetween;
7349 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7350 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7351 : RangeBetween.isWrappedSet();
7352 if (IsWrappedSet)
7353 return ConstantRange::getFull(BitWidth);
7354
7355 if (isKnownPositive(Step) &&
7356 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7357 return RangeBetween;
7358 if (isKnownNegative(Step) &&
7359 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7360 return RangeBetween;
7361 return ConstantRange::getFull(BitWidth);
7362}
7363
7364ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7365 const SCEV *Step,
7366 const APInt &MaxBECount) {
7367 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7368 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7369
7370 unsigned BitWidth = MaxBECount.getBitWidth();
7371 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7372 getTypeSizeInBits(Step->getType()) == BitWidth &&
7373 "mismatched bit widths");
7374
7375 struct SelectPattern {
7376 Value *Condition = nullptr;
7377 APInt TrueValue;
7378 APInt FalseValue;
7379
7380 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7381 const SCEV *S) {
7382 std::optional<unsigned> CastOp;
7383 APInt Offset(BitWidth, 0);
7384
7386 "Should be!");
7387
7388 // Peel off a constant offset. In the future we could consider being
7389 // smarter here and handle {Start+Step,+,Step} too.
7390 const APInt *Off;
7391 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7392 Offset = *Off;
7393
7394 // Peel off a cast operation
7395 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7396 CastOp = SCast->getSCEVType();
7397 S = SCast->getOperand();
7398 }
7399
7400 using namespace llvm::PatternMatch;
7401
7402 auto *SU = dyn_cast<SCEVUnknown>(S);
7403 const APInt *TrueVal, *FalseVal;
7404 if (!SU ||
7405 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7406 m_APInt(FalseVal)))) {
7407 Condition = nullptr;
7408 return;
7409 }
7410
7411 TrueValue = *TrueVal;
7412 FalseValue = *FalseVal;
7413
7414 // Re-apply the cast we peeled off earlier
7415 if (CastOp)
7416 switch (*CastOp) {
7417 default:
7418 llvm_unreachable("Unknown SCEV cast type!");
7419
7420 case scTruncate:
7421 TrueValue = TrueValue.trunc(BitWidth);
7422 FalseValue = FalseValue.trunc(BitWidth);
7423 break;
7424 case scZeroExtend:
7425 TrueValue = TrueValue.zext(BitWidth);
7426 FalseValue = FalseValue.zext(BitWidth);
7427 break;
7428 case scSignExtend:
7429 TrueValue = TrueValue.sext(BitWidth);
7430 FalseValue = FalseValue.sext(BitWidth);
7431 break;
7432 }
7433
7434 // Re-apply the constant offset we peeled off earlier
7435 TrueValue += Offset;
7436 FalseValue += Offset;
7437 }
7438
7439 bool isRecognized() { return Condition != nullptr; }
7440 };
7441
7442 SelectPattern StartPattern(*this, BitWidth, Start);
7443 if (!StartPattern.isRecognized())
7444 return ConstantRange::getFull(BitWidth);
7445
7446 SelectPattern StepPattern(*this, BitWidth, Step);
7447 if (!StepPattern.isRecognized())
7448 return ConstantRange::getFull(BitWidth);
7449
7450 if (StartPattern.Condition != StepPattern.Condition) {
7451 // We don't handle this case today; but we could, by considering four
7452 // possibilities below instead of two. I'm not sure if there are cases where
7453 // that will help over what getRange already does, though.
7454 return ConstantRange::getFull(BitWidth);
7455 }
7456
7457 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7458 // construct arbitrary general SCEV expressions here. This function is called
7459 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7460 // say) can end up caching a suboptimal value.
7461
7462 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7463 // C2352 and C2512 (otherwise it isn't needed).
7464
7465 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7466 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7467 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7468 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7469
7470 ConstantRange TrueRange =
7471 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7472 ConstantRange FalseRange =
7473 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7474
7475 return TrueRange.unionWith(FalseRange);
7476}
7477
7478SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7479 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7480 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7481
7482 // Return early if there are no flags to propagate to the SCEV.
7484 if (BinOp->hasNoUnsignedWrap())
7486 if (BinOp->hasNoSignedWrap())
7488 if (Flags == SCEV::FlagAnyWrap)
7489 return SCEV::FlagAnyWrap;
7490
7491 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7492}
7493
7494const Instruction *
7495ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7496 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7497 return &*AddRec->getLoop()->getHeader()->begin();
7498 if (auto *U = dyn_cast<SCEVUnknown>(S))
7499 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7500 return I;
7501 return nullptr;
7502}
7503
7504const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7505 bool &Precise) {
7506 Precise = true;
7507 // Do a bounded search of the def relation of the requested SCEVs.
7508 SmallPtrSet<const SCEV *, 16> Visited;
7509 SmallVector<SCEVUse> Worklist;
7510 auto pushOp = [&](const SCEV *S) {
7511 if (!Visited.insert(S).second)
7512 return;
7513 // Threshold of 30 here is arbitrary.
7514 if (Visited.size() > 30) {
7515 Precise = false;
7516 return;
7517 }
7518 Worklist.push_back(S);
7519 };
7520
7521 for (SCEVUse S : Ops)
7522 pushOp(S);
7523
7524 const Instruction *Bound = nullptr;
7525 while (!Worklist.empty()) {
7526 SCEVUse S = Worklist.pop_back_val();
7527 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7528 if (!Bound || DT.dominates(Bound, DefI))
7529 Bound = DefI;
7530 } else {
7531 for (SCEVUse Op : S->operands())
7532 pushOp(Op);
7533 }
7534 }
7535 return Bound ? Bound : &*F.getEntryBlock().begin();
7536}
7537
7538const Instruction *
7539ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7540 bool Discard;
7541 return getDefiningScopeBound(Ops, Discard);
7542}
7543
7544bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7545 const Instruction *B) {
7546 if (A->getParent() == B->getParent() &&
7548 B->getIterator()))
7549 return true;
7550
7551 auto *BLoop = LI.getLoopFor(B->getParent());
7552 if (BLoop && BLoop->getHeader() == B->getParent() &&
7553 BLoop->getLoopPreheader() == A->getParent() &&
7555 A->getParent()->end()) &&
7556 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7557 B->getIterator()))
7558 return true;
7559 return false;
7560}
7561
7562bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7563 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7564 visitAll(Op, PC);
7565 return PC.MaybePoison.empty();
7566}
7567
7568bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7569 return !SCEVExprContains(Op, [this](const SCEV *S) {
7570 const SCEV *Op1;
7571 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7572 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7573 // is a non-zero constant, we have to assume the UDiv may be UB.
7574 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7575 });
7576}
7577
7578bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7579 // Only proceed if we can prove that I does not yield poison.
7581 return false;
7582
7583 // At this point we know that if I is executed, then it does not wrap
7584 // according to at least one of NSW or NUW. If I is not executed, then we do
7585 // not know if the calculation that I represents would wrap. Multiple
7586 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7587 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7588 // derived from other instructions that map to the same SCEV. We cannot make
7589 // that guarantee for cases where I is not executed. So we need to find a
7590 // upper bound on the defining scope for the SCEV, and prove that I is
7591 // executed every time we enter that scope. When the bounding scope is a
7592 // loop (the common case), this is equivalent to proving I executes on every
7593 // iteration of that loop.
7594 SmallVector<SCEVUse> SCEVOps;
7595 for (const Use &Op : I->operands()) {
7596 // I could be an extractvalue from a call to an overflow intrinsic.
7597 // TODO: We can do better here in some cases.
7598 if (isSCEVable(Op->getType()))
7599 SCEVOps.push_back(getSCEV(Op));
7600 }
7601 auto *DefI = getDefiningScopeBound(SCEVOps);
7602 return isGuaranteedToTransferExecutionTo(DefI, I);
7603}
7604
7605bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7606 // If we know that \c I can never be poison period, then that's enough.
7607 if (isSCEVExprNeverPoison(I))
7608 return true;
7609
7610 // If the loop only has one exit, then we know that, if the loop is entered,
7611 // any instruction dominating that exit will be executed. If any such
7612 // instruction would result in UB, the addrec cannot be poison.
7613 //
7614 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7615 // also handles uses outside the loop header (they just need to dominate the
7616 // single exit).
7617
7618 auto *ExitingBB = L->getExitingBlock();
7619 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7620 return false;
7621
7622 SmallPtrSet<const Value *, 16> KnownPoison;
7624
7625 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7626 // things that are known to be poison under that assumption go on the
7627 // Worklist.
7628 KnownPoison.insert(I);
7629 Worklist.push_back(I);
7630
7631 while (!Worklist.empty()) {
7632 const Instruction *Poison = Worklist.pop_back_val();
7633
7634 for (const Use &U : Poison->uses()) {
7635 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7636 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7637 DT.dominates(PoisonUser->getParent(), ExitingBB))
7638 return true;
7639
7640 if (propagatesPoison(U) && L->contains(PoisonUser))
7641 if (KnownPoison.insert(PoisonUser).second)
7642 Worklist.push_back(PoisonUser);
7643 }
7644 }
7645
7646 return false;
7647}
7648
7649ScalarEvolution::LoopProperties
7650ScalarEvolution::getLoopProperties(const Loop *L) {
7651 using LoopProperties = ScalarEvolution::LoopProperties;
7652
7653 auto Itr = LoopPropertiesCache.find(L);
7654 if (Itr == LoopPropertiesCache.end()) {
7655 auto HasSideEffects = [](Instruction *I) {
7656 if (auto *SI = dyn_cast<StoreInst>(I))
7657 return !SI->isSimple();
7658
7659 if (I->mayThrow())
7660 return true;
7661
7662 // Non-volatile memset / memcpy do not count as side-effect for forward
7663 // progress.
7664 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7665 return false;
7666
7667 return I->mayWriteToMemory();
7668 };
7669
7670 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7671 /*HasNoSideEffects*/ true};
7672
7673 for (auto *BB : L->getBlocks())
7674 for (auto &I : *BB) {
7676 LP.HasNoAbnormalExits = false;
7677 if (HasSideEffects(&I))
7678 LP.HasNoSideEffects = false;
7679 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7680 break; // We're already as pessimistic as we can get.
7681 }
7682
7683 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7684 assert(InsertPair.second && "We just checked!");
7685 Itr = InsertPair.first;
7686 }
7687
7688 return Itr->second;
7689}
7690
7692 // A mustprogress loop without side effects must be finite.
7693 // TODO: The check used here is very conservative. It's only *specific*
7694 // side effects which are well defined in infinite loops.
7695 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7696}
7697
7698const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7699 // Worklist item with a Value and a bool indicating whether all operands have
7700 // been visited already.
7703
7704 Stack.emplace_back(V, true);
7705 Stack.emplace_back(V, false);
7706 while (!Stack.empty()) {
7707 auto E = Stack.pop_back_val();
7708 Value *CurV = E.getPointer();
7709
7710 if (getExistingSCEV(CurV))
7711 continue;
7712
7714 const SCEV *CreatedSCEV = nullptr;
7715 // If all operands have been visited already, create the SCEV.
7716 if (E.getInt()) {
7717 CreatedSCEV = createSCEV(CurV);
7718 } else {
7719 // Otherwise get the operands we need to create SCEV's for before creating
7720 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7721 // just use it.
7722 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7723 }
7724
7725 if (CreatedSCEV) {
7726 insertValueToMap(CurV, CreatedSCEV);
7727 } else {
7728 // Queue CurV for SCEV creation, followed by its's operands which need to
7729 // be constructed first.
7730 Stack.emplace_back(CurV, true);
7731 for (Value *Op : Ops)
7732 Stack.emplace_back(Op, false);
7733 }
7734 }
7735
7736 return getExistingSCEV(V);
7737}
7738
7739const SCEV *
7740ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7741 if (!isSCEVable(V->getType()))
7742 return getUnknown(V);
7743
7744 if (Instruction *I = dyn_cast<Instruction>(V)) {
7745 // Don't attempt to analyze instructions in blocks that aren't
7746 // reachable. Such instructions don't matter, and they aren't required
7747 // to obey basic rules for definitions dominating uses which this
7748 // analysis depends on.
7749 if (!DT.isReachableFromEntry(I->getParent()))
7750 return getUnknown(PoisonValue::get(V->getType()));
7751 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7752 return getConstant(CI);
7753 else if (isa<GlobalAlias>(V))
7754 return getUnknown(V);
7755 else if (!isa<ConstantExpr>(V))
7756 return getUnknown(V);
7757
7759 if (auto BO =
7761 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7762 switch (BO->Opcode) {
7763 case Instruction::Add:
7764 case Instruction::Mul: {
7765 // For additions and multiplications, traverse add/mul chains for which we
7766 // can potentially create a single SCEV, to reduce the number of
7767 // get{Add,Mul}Expr calls.
7768 do {
7769 if (BO->Op) {
7770 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7771 Ops.push_back(BO->Op);
7772 break;
7773 }
7774 }
7775 Ops.push_back(BO->RHS);
7776 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7778 if (!NewBO ||
7779 (BO->Opcode == Instruction::Add &&
7780 (NewBO->Opcode != Instruction::Add &&
7781 NewBO->Opcode != Instruction::Sub)) ||
7782 (BO->Opcode == Instruction::Mul &&
7783 NewBO->Opcode != Instruction::Mul)) {
7784 Ops.push_back(BO->LHS);
7785 break;
7786 }
7787 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7788 // requires a SCEV for the LHS.
7789 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7790 auto *I = dyn_cast<Instruction>(BO->Op);
7791 if (I && programUndefinedIfPoison(I)) {
7792 Ops.push_back(BO->LHS);
7793 break;
7794 }
7795 }
7796 BO = NewBO;
7797 } while (true);
7798 return nullptr;
7799 }
7800 case Instruction::Sub:
7801 case Instruction::UDiv:
7802 case Instruction::URem:
7803 break;
7804 case Instruction::AShr:
7805 case Instruction::Shl:
7806 case Instruction::Xor:
7807 if (!IsConstArg)
7808 return nullptr;
7809 break;
7810 case Instruction::And:
7811 case Instruction::Or:
7812 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7813 return nullptr;
7814 break;
7815 case Instruction::LShr:
7816 return getUnknown(V);
7817 default:
7818 llvm_unreachable("Unhandled binop");
7819 break;
7820 }
7821
7822 Ops.push_back(BO->LHS);
7823 Ops.push_back(BO->RHS);
7824 return nullptr;
7825 }
7826
7827 switch (U->getOpcode()) {
7828 case Instruction::Trunc:
7829 case Instruction::ZExt:
7830 case Instruction::SExt:
7831 case Instruction::PtrToAddr:
7832 case Instruction::PtrToInt:
7833 Ops.push_back(U->getOperand(0));
7834 return nullptr;
7835
7836 case Instruction::BitCast:
7837 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7838 Ops.push_back(U->getOperand(0));
7839 return nullptr;
7840 }
7841 return getUnknown(V);
7842
7843 case Instruction::SDiv:
7844 case Instruction::SRem:
7845 Ops.push_back(U->getOperand(0));
7846 Ops.push_back(U->getOperand(1));
7847 return nullptr;
7848
7849 case Instruction::GetElementPtr:
7850 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7851 "GEP source element type must be sized");
7852 llvm::append_range(Ops, U->operands());
7853 return nullptr;
7854
7855 case Instruction::IntToPtr:
7856 return getUnknown(V);
7857
7858 case Instruction::PHI:
7859 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7860 // relevant nodes for each of them.
7861 //
7862 // The first is just to call simplifyInstruction, and get something back
7863 // that isn't a PHI.
7864 if (Value *V = simplifyInstruction(
7865 cast<PHINode>(U),
7866 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7867 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7868 assert(V);
7869 Ops.push_back(V);
7870 return nullptr;
7871 }
7872 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7873 // operands which all perform the same operation, but haven't been
7874 // CSE'ed for whatever reason.
7875 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7876 assert(BO);
7877 Ops.push_back(BO);
7878 return nullptr;
7879 }
7880 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7881 // is equivalent to a select, and analyzes it like a select.
7882 {
7883 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7885 assert(Cond);
7886 assert(LHS);
7887 assert(RHS);
7888 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7889 Ops.push_back(CondICmp->getOperand(0));
7890 Ops.push_back(CondICmp->getOperand(1));
7891 }
7892 Ops.push_back(Cond);
7893 Ops.push_back(LHS);
7894 Ops.push_back(RHS);
7895 return nullptr;
7896 }
7897 }
7898 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7899 // so just construct it recursively.
7900 //
7901 // In addition to getNodeForPHI, also construct nodes which might be needed
7902 // by getRangeRef.
7904 for (Value *V : cast<PHINode>(U)->operands())
7905 Ops.push_back(V);
7906 return nullptr;
7907 }
7908 return nullptr;
7909
7910 case Instruction::Select: {
7911 // Check if U is a select that can be simplified to a SCEVUnknown.
7912 auto CanSimplifyToUnknown = [this, U]() {
7913 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7914 return false;
7915
7916 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7917 if (!ICI)
7918 return false;
7919 Value *LHS = ICI->getOperand(0);
7920 Value *RHS = ICI->getOperand(1);
7921 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7922 ICI->getPredicate() == CmpInst::ICMP_NE) {
7924 return true;
7925 } else if (getTypeSizeInBits(LHS->getType()) >
7926 getTypeSizeInBits(U->getType()))
7927 return true;
7928 return false;
7929 };
7930 if (CanSimplifyToUnknown())
7931 return getUnknown(U);
7932
7933 llvm::append_range(Ops, U->operands());
7934 return nullptr;
7935 break;
7936 }
7937 case Instruction::Call:
7938 case Instruction::Invoke:
7939 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7940 Ops.push_back(RV);
7941 return nullptr;
7942 }
7943
7944 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7945 switch (II->getIntrinsicID()) {
7946 case Intrinsic::abs:
7947 Ops.push_back(II->getArgOperand(0));
7948 return nullptr;
7949 case Intrinsic::umax:
7950 case Intrinsic::umin:
7951 case Intrinsic::smax:
7952 case Intrinsic::smin:
7953 case Intrinsic::usub_sat:
7954 case Intrinsic::uadd_sat:
7955 Ops.push_back(II->getArgOperand(0));
7956 Ops.push_back(II->getArgOperand(1));
7957 return nullptr;
7958 case Intrinsic::start_loop_iterations:
7959 case Intrinsic::annotation:
7960 case Intrinsic::ptr_annotation:
7961 Ops.push_back(II->getArgOperand(0));
7962 return nullptr;
7963 default:
7964 break;
7965 }
7966 }
7967 break;
7968 }
7969
7970 return nullptr;
7971}
7972
7973const SCEV *ScalarEvolution::createSCEV(Value *V) {
7974 if (!isSCEVable(V->getType()))
7975 return getUnknown(V);
7976
7977 if (Instruction *I = dyn_cast<Instruction>(V)) {
7978 // Don't attempt to analyze instructions in blocks that aren't
7979 // reachable. Such instructions don't matter, and they aren't required
7980 // to obey basic rules for definitions dominating uses which this
7981 // analysis depends on.
7982 if (!DT.isReachableFromEntry(I->getParent()))
7983 return getUnknown(PoisonValue::get(V->getType()));
7984 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7985 return getConstant(CI);
7986 else if (isa<GlobalAlias>(V))
7987 return getUnknown(V);
7988 else if (!isa<ConstantExpr>(V))
7989 return getUnknown(V);
7990
7991 const SCEV *LHS;
7992 const SCEV *RHS;
7993
7995 if (auto BO =
7997 switch (BO->Opcode) {
7998 case Instruction::Add: {
7999 // The simple thing to do would be to just call getSCEV on both operands
8000 // and call getAddExpr with the result. However if we're looking at a
8001 // bunch of things all added together, this can be quite inefficient,
8002 // because it leads to N-1 getAddExpr calls for N ultimate operands.
8003 // Instead, gather up all the operands and make a single getAddExpr call.
8004 // LLVM IR canonical form means we need only traverse the left operands.
8006 do {
8007 if (BO->Op) {
8008 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8009 AddOps.push_back(OpSCEV);
8010 break;
8011 }
8012
8013 // If a NUW or NSW flag can be applied to the SCEV for this
8014 // addition, then compute the SCEV for this addition by itself
8015 // with a separate call to getAddExpr. We need to do that
8016 // instead of pushing the operands of the addition onto AddOps,
8017 // since the flags are only known to apply to this particular
8018 // addition - they may not apply to other additions that can be
8019 // formed with operands from AddOps.
8020 const SCEV *RHS = getSCEV(BO->RHS);
8021 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8022 if (Flags != SCEV::FlagAnyWrap) {
8023 const SCEV *LHS = getSCEV(BO->LHS);
8024 if (BO->Opcode == Instruction::Sub)
8025 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8026 else
8027 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8028 break;
8029 }
8030 }
8031
8032 if (BO->Opcode == Instruction::Sub)
8033 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8034 else
8035 AddOps.push_back(getSCEV(BO->RHS));
8036
8037 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8039 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8040 NewBO->Opcode != Instruction::Sub)) {
8041 AddOps.push_back(getSCEV(BO->LHS));
8042 break;
8043 }
8044 BO = NewBO;
8045 } while (true);
8046
8047 return getAddExpr(AddOps);
8048 }
8049
8050 case Instruction::Mul: {
8052 do {
8053 if (BO->Op) {
8054 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8055 MulOps.push_back(OpSCEV);
8056 break;
8057 }
8058
8059 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8060 if (Flags != SCEV::FlagAnyWrap) {
8061 LHS = getSCEV(BO->LHS);
8062 RHS = getSCEV(BO->RHS);
8063 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8064 break;
8065 }
8066 }
8067
8068 MulOps.push_back(getSCEV(BO->RHS));
8069 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8071 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8072 MulOps.push_back(getSCEV(BO->LHS));
8073 break;
8074 }
8075 BO = NewBO;
8076 } while (true);
8077
8078 return getMulExpr(MulOps);
8079 }
8080 case Instruction::UDiv:
8081 LHS = getSCEV(BO->LHS);
8082 RHS = getSCEV(BO->RHS);
8083 return getUDivExpr(LHS, RHS);
8084 case Instruction::URem:
8085 LHS = getSCEV(BO->LHS);
8086 RHS = getSCEV(BO->RHS);
8087 return getURemExpr(LHS, RHS);
8088 case Instruction::Sub: {
8090 if (BO->Op)
8091 Flags = getNoWrapFlagsFromUB(BO->Op);
8092 LHS = getSCEV(BO->LHS);
8093 RHS = getSCEV(BO->RHS);
8094 return getMinusSCEV(LHS, RHS, Flags);
8095 }
8096 case Instruction::And:
8097 // For an expression like x&255 that merely masks off the high bits,
8098 // use zext(trunc(x)) as the SCEV expression.
8099 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8100 if (CI->isZero())
8101 return getSCEV(BO->RHS);
8102 if (CI->isMinusOne())
8103 return getSCEV(BO->LHS);
8104 const APInt &A = CI->getValue();
8105
8106 // Instcombine's ShrinkDemandedConstant may strip bits out of
8107 // constants, obscuring what would otherwise be a low-bits mask.
8108 // Use computeKnownBits to compute what ShrinkDemandedConstant
8109 // knew about to reconstruct a low-bits mask value.
8110 unsigned LZ = A.countl_zero();
8111 unsigned TZ = A.countr_zero();
8112 unsigned BitWidth = A.getBitWidth();
8113 KnownBits Known(BitWidth);
8114 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8115
8116 APInt EffectiveMask =
8117 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8118 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8119 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8120 const SCEV *LHS = getSCEV(BO->LHS);
8121 const SCEV *ShiftedLHS = nullptr;
8122 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8123 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8124 // For an expression like (x * 8) & 8, simplify the multiply.
8125 unsigned MulZeros = OpC->getAPInt().countr_zero();
8126 unsigned GCD = std::min(MulZeros, TZ);
8127 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8129 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8130 append_range(MulOps, LHSMul->operands().drop_front());
8131 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8132 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8133 }
8134 }
8135 if (!ShiftedLHS)
8136 ShiftedLHS = getUDivExpr(LHS, MulCount);
8137 return getMulExpr(
8139 getTruncateExpr(ShiftedLHS,
8140 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8141 BO->LHS->getType()),
8142 MulCount);
8143 }
8144 }
8145 // Binary `and` is a bit-wise `umin`.
8146 if (BO->LHS->getType()->isIntegerTy(1)) {
8147 LHS = getSCEV(BO->LHS);
8148 RHS = getSCEV(BO->RHS);
8149 return getUMinExpr(LHS, RHS);
8150 }
8151 break;
8152
8153 case Instruction::Or:
8154 // Binary `or` is a bit-wise `umax`.
8155 if (BO->LHS->getType()->isIntegerTy(1)) {
8156 LHS = getSCEV(BO->LHS);
8157 RHS = getSCEV(BO->RHS);
8158 return getUMaxExpr(LHS, RHS);
8159 }
8160 break;
8161
8162 case Instruction::Xor:
8163 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8164 // If the RHS of xor is -1, then this is a not operation.
8165 if (CI->isMinusOne())
8166 return getNotSCEV(getSCEV(BO->LHS));
8167
8168 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8169 // This is a variant of the check for xor with -1, and it handles
8170 // the case where instcombine has trimmed non-demanded bits out
8171 // of an xor with -1.
8172 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8173 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8174 if (LBO->getOpcode() == Instruction::And &&
8175 LCI->getValue() == CI->getValue())
8176 if (const SCEVZeroExtendExpr *Z =
8178 Type *UTy = BO->LHS->getType();
8179 const SCEV *Z0 = Z->getOperand();
8180 Type *Z0Ty = Z0->getType();
8181 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8182
8183 // If C is a low-bits mask, the zero extend is serving to
8184 // mask off the high bits. Complement the operand and
8185 // re-apply the zext.
8186 if (CI->getValue().isMask(Z0TySize))
8187 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8188
8189 // If C is a single bit, it may be in the sign-bit position
8190 // before the zero-extend. In this case, represent the xor
8191 // using an add, which is equivalent, and re-apply the zext.
8192 APInt Trunc = CI->getValue().trunc(Z0TySize);
8193 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8194 Trunc.isSignMask())
8195 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8196 UTy);
8197 }
8198 }
8199 break;
8200
8201 case Instruction::Shl:
8202 // Turn shift left of a constant amount into a multiply.
8203 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8204 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8205
8206 // If the shift count is not less than the bitwidth, the result of
8207 // the shift is undefined. Don't try to analyze it, because the
8208 // resolution chosen here may differ from the resolution chosen in
8209 // other parts of the compiler.
8210 if (SA->getValue().uge(BitWidth))
8211 break;
8212
8213 // We can safely preserve the nuw flag in all cases. It's also safe to
8214 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8215 // requires special handling. It can be preserved as long as we're not
8216 // left shifting by bitwidth - 1.
8217 auto Flags = SCEV::FlagAnyWrap;
8218 if (BO->Op) {
8219 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8220 if ((MulFlags & SCEV::FlagNSW) &&
8221 ((MulFlags & SCEV::FlagNUW) || SA->getValue().ult(BitWidth - 1)))
8223 if (MulFlags & SCEV::FlagNUW)
8225 }
8226
8227 ConstantInt *X = ConstantInt::get(
8228 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8229 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8230 }
8231 break;
8232
8233 case Instruction::AShr:
8234 // AShr X, C, where C is a constant.
8235 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8236 if (!CI)
8237 break;
8238
8239 Type *OuterTy = BO->LHS->getType();
8240 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8241 // If the shift count is not less than the bitwidth, the result of
8242 // the shift is undefined. Don't try to analyze it, because the
8243 // resolution chosen here may differ from the resolution chosen in
8244 // other parts of the compiler.
8245 if (CI->getValue().uge(BitWidth))
8246 break;
8247
8248 if (CI->isZero())
8249 return getSCEV(BO->LHS); // shift by zero --> noop
8250
8251 uint64_t AShrAmt = CI->getZExtValue();
8252 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8253
8254 Operator *L = dyn_cast<Operator>(BO->LHS);
8255 const SCEV *AddTruncateExpr = nullptr;
8256 ConstantInt *ShlAmtCI = nullptr;
8257 const SCEV *AddConstant = nullptr;
8258
8259 if (L && L->getOpcode() == Instruction::Add) {
8260 // X = Shl A, n
8261 // Y = Add X, c
8262 // Z = AShr Y, m
8263 // n, c and m are constants.
8264
8265 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8266 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8267 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8268 if (AddOperandCI) {
8269 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8270 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8271 // since we truncate to TruncTy, the AddConstant should be of the
8272 // same type, so create a new Constant with type same as TruncTy.
8273 // Also, the Add constant should be shifted right by AShr amount.
8274 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8275 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8276 // we model the expression as sext(add(trunc(A), c << n)), since the
8277 // sext(trunc) part is already handled below, we create a
8278 // AddExpr(TruncExp) which will be used later.
8279 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8280 }
8281 }
8282 } else if (L && L->getOpcode() == Instruction::Shl) {
8283 // X = Shl A, n
8284 // Y = AShr X, m
8285 // Both n and m are constant.
8286
8287 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8288 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8289 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8290 }
8291
8292 if (AddTruncateExpr && ShlAmtCI) {
8293 // We can merge the two given cases into a single SCEV statement,
8294 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8295 // a simpler case. The following code handles the two cases:
8296 //
8297 // 1) For a two-shift sext-inreg, i.e. n = m,
8298 // use sext(trunc(x)) as the SCEV expression.
8299 //
8300 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8301 // expression. We already checked that ShlAmt < BitWidth, so
8302 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8303 // ShlAmt - AShrAmt < Amt.
8304 const APInt &ShlAmt = ShlAmtCI->getValue();
8305 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8306 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8307 ShlAmtCI->getZExtValue() - AShrAmt);
8308 const SCEV *CompositeExpr =
8309 getMulExpr(AddTruncateExpr, getConstant(Mul));
8310 if (L->getOpcode() != Instruction::Shl)
8311 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8312
8313 return getSignExtendExpr(CompositeExpr, OuterTy);
8314 }
8315 }
8316 break;
8317 }
8318 }
8319
8320 switch (U->getOpcode()) {
8321 case Instruction::Trunc:
8322 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8323
8324 case Instruction::ZExt:
8325 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8326
8327 case Instruction::SExt:
8328 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8330 // The NSW flag of a subtract does not always survive the conversion to
8331 // A + (-1)*B. By pushing sign extension onto its operands we are much
8332 // more likely to preserve NSW and allow later AddRec optimisations.
8333 //
8334 // NOTE: This is effectively duplicating this logic from getSignExtend:
8335 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8336 // but by that point the NSW information has potentially been lost.
8337 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8338 Type *Ty = U->getType();
8339 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8340 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8341 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8342 }
8343 }
8344 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8345
8346 case Instruction::BitCast:
8347 // BitCasts are no-op casts so we just eliminate the cast.
8348 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8349 return getSCEV(U->getOperand(0));
8350 break;
8351
8352 case Instruction::PtrToAddr: {
8353 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8354 if (isa<SCEVCouldNotCompute>(IntOp))
8355 return getUnknown(V);
8356 return IntOp;
8357 }
8358
8359 case Instruction::PtrToInt: {
8360 // Pointer to integer cast is straight-forward, so do model it.
8361 const SCEV *Op = getSCEV(U->getOperand(0));
8362 Type *DstIntTy = U->getType();
8363 // But only if effective SCEV (integer) type is wide enough to represent
8364 // all possible pointer values.
8365 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8366 if (isa<SCEVCouldNotCompute>(IntOp))
8367 return getUnknown(V);
8368 return IntOp;
8369 }
8370 case Instruction::IntToPtr:
8371 // Just don't deal with inttoptr casts.
8372 return getUnknown(V);
8373
8374 case Instruction::SDiv:
8375 // If both operands are non-negative, this is just an udiv.
8376 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8377 isKnownNonNegative(getSCEV(U->getOperand(1))))
8378 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8379 break;
8380
8381 case Instruction::SRem:
8382 // If both operands are non-negative, this is just an urem.
8383 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8384 isKnownNonNegative(getSCEV(U->getOperand(1))))
8385 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8386 break;
8387
8388 case Instruction::GetElementPtr:
8389 return createNodeForGEP(cast<GEPOperator>(U));
8390
8391 case Instruction::PHI:
8392 return createNodeForPHI(cast<PHINode>(U));
8393
8394 case Instruction::Select:
8395 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8396 U->getOperand(2));
8397
8398 case Instruction::Call:
8399 case Instruction::Invoke:
8400 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8401 return getSCEV(RV);
8402
8403 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8404 switch (II->getIntrinsicID()) {
8405 case Intrinsic::abs:
8406 return getAbsExpr(
8407 getSCEV(II->getArgOperand(0)),
8408 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8409 case Intrinsic::umax:
8410 LHS = getSCEV(II->getArgOperand(0));
8411 RHS = getSCEV(II->getArgOperand(1));
8412 return getUMaxExpr(LHS, RHS);
8413 case Intrinsic::umin:
8414 LHS = getSCEV(II->getArgOperand(0));
8415 RHS = getSCEV(II->getArgOperand(1));
8416 return getUMinExpr(LHS, RHS);
8417 case Intrinsic::smax:
8418 LHS = getSCEV(II->getArgOperand(0));
8419 RHS = getSCEV(II->getArgOperand(1));
8420 return getSMaxExpr(LHS, RHS);
8421 case Intrinsic::smin:
8422 LHS = getSCEV(II->getArgOperand(0));
8423 RHS = getSCEV(II->getArgOperand(1));
8424 return getSMinExpr(LHS, RHS);
8425 case Intrinsic::usub_sat: {
8426 const SCEV *X = getSCEV(II->getArgOperand(0));
8427 const SCEV *Y = getSCEV(II->getArgOperand(1));
8428 const SCEV *ClampedY = getUMinExpr(X, Y);
8429 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8430 }
8431 case Intrinsic::uadd_sat: {
8432 const SCEV *X = getSCEV(II->getArgOperand(0));
8433 const SCEV *Y = getSCEV(II->getArgOperand(1));
8434 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8435 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8436 }
8437 case Intrinsic::start_loop_iterations:
8438 case Intrinsic::annotation:
8439 case Intrinsic::ptr_annotation:
8440 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8441 // just eqivalent to the first operand for SCEV purposes.
8442 return getSCEV(II->getArgOperand(0));
8443 case Intrinsic::vscale:
8444 return getVScale(II->getType());
8445 default:
8446 break;
8447 }
8448 }
8449 break;
8450 }
8451
8452 return getUnknown(V);
8453}
8454
8455//===----------------------------------------------------------------------===//
8456// Iteration Count Computation Code
8457//
8458
8460 if (isa<SCEVCouldNotCompute>(ExitCount))
8461 return getCouldNotCompute();
8462
8463 auto *ExitCountType = ExitCount->getType();
8464 assert(ExitCountType->isIntegerTy());
8465 auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
8466 1 + ExitCountType->getScalarSizeInBits());
8467 return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
8468}
8469
8471 Type *EvalTy,
8472 const Loop *L) {
8473 if (isa<SCEVCouldNotCompute>(ExitCount))
8474 return getCouldNotCompute();
8475
8476 unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
8477 unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();
8478
8479 auto CanAddOneWithoutOverflow = [&]() {
8480 ConstantRange ExitCountRange =
8481 getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
8482 if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
8483 return true;
8484
8485 return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
8486 getMinusOne(ExitCount->getType()));
8487 };
8488
8489 // If we need to zero extend the backedge count, check if we can add one to
8490 // it prior to zero extending without overflow. Provided this is safe, it
8491 // allows better simplification of the +1.
8492 if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
8493 return getZeroExtendExpr(
8494 getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);
8495
8496 // Get the total trip count from the count by adding 1. This may wrap.
8497 return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
8498}
8499
8500static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8501 if (!ExitCount)
8502 return 0;
8503
8504 ConstantInt *ExitConst = ExitCount->getValue();
8505
8506 // Guard against huge trip counts.
8507 if (ExitConst->getValue().getActiveBits() > 32)
8508 return 0;
8509
8510 // In case of integer overflow, this returns 0, which is correct.
8511 return ((unsigned)ExitConst->getZExtValue()) + 1;
8512}
8513
8515 auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
8516 return getConstantTripCount(ExitCount);
8517}
8518
8519unsigned
8521 const BasicBlock *ExitingBlock) {
8522 assert(ExitingBlock && "Must pass a non-null exiting block!");
8523 assert(L->isLoopExiting(ExitingBlock) &&
8524 "Exiting block must actually branch out of the loop!");
8525 const SCEVConstant *ExitCount =
8526 dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
8527 return getConstantTripCount(ExitCount);
8528}
8529
8531 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8532
8533 const auto *MaxExitCount =
8534 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8536 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8537}
8538
8540 SmallVector<BasicBlock *, 8> ExitingBlocks;
8541 L->getExitingBlocks(ExitingBlocks);
8542
8543 std::optional<unsigned> Res;
8544 for (auto *ExitingBB : ExitingBlocks) {
8545 unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
8546 if (!Res)
8547 Res = Multiple;
8548 Res = std::gcd(*Res, Multiple);
8549 }
8550 return Res.value_or(1);
8551}
8552
8554 const SCEV *ExitCount) {
8555 if (isa<SCEVCouldNotCompute>(ExitCount))
8556 return 1;
8557
8558 // Get the trip count
8559 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8560
8561 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8562 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8563 // the greatest power of 2 divisor less than 2^32.
8564 return Multiple.getActiveBits() > 32
8565 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8566 : (unsigned)Multiple.getZExtValue();
8567}
8568
8569/// Returns the largest constant divisor of the trip count of this loop as a
8570/// normal unsigned value, if possible. This means that the actual trip count is
8571/// always a multiple of the returned value (don't forget the trip count could
8572/// very well be zero as well!).
8573///
8574/// Returns 1 if the trip count is unknown or not guaranteed to be the
8575/// multiple of a constant (which is also the case if the trip count is simply
8576/// constant, use getSmallConstantTripCount for that case), Will also return 1
8577/// if the trip count is very large (>= 2^32).
8578///
8579/// As explained in the comments for getSmallConstantTripCount, this assumes
8580/// that control exits the loop via ExitingBlock.
8581unsigned
8583 const BasicBlock *ExitingBlock) {
8584 assert(ExitingBlock && "Must pass a non-null exiting block!");
8585 assert(L->isLoopExiting(ExitingBlock) &&
8586 "Exiting block must actually branch out of the loop!");
8587 const SCEV *ExitCount = getExitCount(L, ExitingBlock);
8588 return getSmallConstantTripMultiple(L, ExitCount);
8589}
8590
8592 const BasicBlock *ExitingBlock,
8593 ExitCountKind Kind) {
8594 switch (Kind) {
8595 case Exact:
8596 return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
8597 case SymbolicMaximum:
8598 return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
8599 case ConstantMaximum:
8600 return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
8601 };
8602 llvm_unreachable("Invalid ExitCountKind!");
8603}
8604
8606 const Loop *L, const BasicBlock *ExitingBlock,
8608 switch (Kind) {
8609 case Exact:
8610 return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
8611 Predicates);
8612 case SymbolicMaximum:
8613 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
8614 Predicates);
8615 case ConstantMaximum:
8616 return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
8617 Predicates);
8618 };
8619 llvm_unreachable("Invalid ExitCountKind!");
8620}
8621
8624 return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
8625}
8626
8628 ExitCountKind Kind) {
8629 switch (Kind) {
8630 case Exact:
8631 return getBackedgeTakenInfo(L).getExact(L, this);
8632 case ConstantMaximum:
8633 return getBackedgeTakenInfo(L).getConstantMax(this);
8634 case SymbolicMaximum:
8635 return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
8636 };
8637 llvm_unreachable("Invalid ExitCountKind!");
8638}
8639
8642 return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
8643}
8644
8647 return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
8648}
8649
8651 return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
8652}
8653
8654/// Push PHI nodes in the header of the given loop onto the given Worklist.
8655static void PushLoopPHIs(const Loop *L,
8658 BasicBlock *Header = L->getHeader();
8659
8660 // Push all Loop-header PHIs onto the Worklist stack.
8661 for (PHINode &PN : Header->phis())
8662 if (Visited.insert(&PN).second)
8663 Worklist.push_back(&PN);
8664}
8665
8666ScalarEvolution::BackedgeTakenInfo &
8667ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8668 auto &BTI = getBackedgeTakenInfo(L);
8669 if (BTI.hasFullInfo())
8670 return BTI;
8671
8672 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8673
8674 if (!Pair.second)
8675 return Pair.first->second;
8676
8677 BackedgeTakenInfo Result =
8678 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8679
8680 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8681}
8682
8683ScalarEvolution::BackedgeTakenInfo &
8684ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8685 // Initially insert an invalid entry for this loop. If the insertion
8686 // succeeds, proceed to actually compute a backedge-taken count and
8687 // update the value. The temporary CouldNotCompute value tells SCEV
8688 // code elsewhere that it shouldn't attempt to request a new
8689 // backedge-taken count, which could result in infinite recursion.
8690 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8691 BackedgeTakenCounts.try_emplace(L);
8692 if (!Pair.second)
8693 return Pair.first->second;
8694
8695 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8696 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8697 // must be cleared in this scope.
8698 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8699
8700 // Now that we know more about the trip count for this loop, forget any
8701 // existing SCEV values for PHI nodes in this loop since they are only
8702 // conservative estimates made without the benefit of trip count
8703 // information. This invalidation is not necessary for correctness, and is
8704 // only done to produce more precise results.
8705 if (Result.hasAnyInfo()) {
8706 // Invalidate any expression using an addrec in this loop.
8707 SmallVector<SCEVUse, 8> ToForget;
8708 auto LoopUsersIt = LoopUsers.find(L);
8709 if (LoopUsersIt != LoopUsers.end())
8710 append_range(ToForget, LoopUsersIt->second);
8711 forgetMemoizedResults(ToForget);
8712
8713 // Invalidate constant-evolved loop header phis.
8714 for (PHINode &PN : L->getHeader()->phis())
8715 ConstantEvolutionLoopExitValue.erase(&PN);
8716 }
8717
8718 // Re-lookup the insert position, since the call to
8719 // computeBackedgeTakenCount above could result in a
8720 // recusive call to getBackedgeTakenInfo (on a different
8721 // loop), which would invalidate the iterator computed
8722 // earlier.
8723 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8724}
8725
8727 // This method is intended to forget all info about loops. It should
8728 // invalidate caches as if the following happened:
8729 // - The trip counts of all loops have changed arbitrarily
8730 // - Every llvm::Value has been updated in place to produce a different
8731 // result.
8732 BackedgeTakenCounts.clear();
8733 PredicatedBackedgeTakenCounts.clear();
8734 BECountUsers.clear();
8735 LoopPropertiesCache.clear();
8736 ConstantEvolutionLoopExitValue.clear();
8737 ValueExprMap.clear();
8738 ValuesAtScopes.clear();
8739 ValuesAtScopesUsers.clear();
8740 LoopDispositions.clear();
8741 BlockDispositions.clear();
8742 UnsignedRanges.clear();
8743 SignedRanges.clear();
8744 ExprValueMap.clear();
8745 HasRecMap.clear();
8746 ConstantMultipleCache.clear();
8747 PredicatedSCEVRewrites.clear();
8748 FoldCache.clear();
8749 FoldCacheUser.clear();
8750}
8751void ScalarEvolution::visitAndClearUsers(
8754 SmallVectorImpl<SCEVUse> &ToForget) {
8755 while (!Worklist.empty()) {
8756 Instruction *I = Worklist.pop_back_val();
8757 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8758 continue;
8759
8761 ValueExprMap.find_as(static_cast<Value *>(I));
8762 if (It != ValueExprMap.end()) {
8763 eraseValueFromMap(It->first);
8764 ToForget.push_back(It->second);
8765 if (PHINode *PN = dyn_cast<PHINode>(I))
8766 ConstantEvolutionLoopExitValue.erase(PN);
8767 }
8768
8769 PushDefUseChildren(I, Worklist, Visited);
8770 }
8771}
8772
8774 SmallVector<const Loop *, 16> LoopWorklist(1, L);
8777 SmallVector<SCEVUse, 16> ToForget;
8778
8779 // Iterate over all the loops and sub-loops to drop SCEV information.
8780 while (!LoopWorklist.empty()) {
8781 auto *CurrL = LoopWorklist.pop_back_val();
8782
8783 // Drop any stored trip count value.
8784 forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
8785 forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);
8786
8787 // Drop information about predicated SCEV rewrites for this loop.
8788 for (auto I = PredicatedSCEVRewrites.begin();
8789 I != PredicatedSCEVRewrites.end();) {
8790 std::pair<const SCEV *, const Loop *> Entry = I->first;
8791 if (Entry.second == CurrL)
8792 PredicatedSCEVRewrites.erase(I++);
8793 else
8794 ++I;
8795 }
8796
8797 auto LoopUsersItr = LoopUsers.find(CurrL);
8798 if (LoopUsersItr != LoopUsers.end())
8799 llvm::append_range(ToForget, LoopUsersItr->second);
8800
8801 // Drop information about expressions based on loop-header PHIs.
8802 PushLoopPHIs(CurrL, Worklist, Visited);
8803 visitAndClearUsers(Worklist, Visited, ToForget);
8804
8805 LoopPropertiesCache.erase(CurrL);
8806 // Forget all contained loops too, to avoid dangling entries in the
8807 // ValuesAtScopes map.
8808 LoopWorklist.append(CurrL->begin(), CurrL->end());
8809 }
8810 forgetMemoizedResults(ToForget);
8811}
8812
8814 forgetLoop(L->getOutermostLoop());
8815}
8816
8819 if (!I) return;
8820
8821 // Drop information about expressions based on loop-header PHIs.
8824 SmallVector<SCEVUse, 8> ToForget;
8825 Worklist.push_back(I);
8826 Visited.insert(I);
8827 visitAndClearUsers(Worklist, Visited, ToForget);
8828
8829 forgetMemoizedResults(ToForget);
8830}
8831
8833 if (!isSCEVable(V->getType()))
8834 return;
8835
8836 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8837 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8838 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8839 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8840 if (const SCEV *S = getExistingSCEV(V)) {
8841 struct InvalidationRootCollector {
8842 Loop *L;
8844
8845 InvalidationRootCollector(Loop *L) : L(L) {}
8846
8847 bool follow(const SCEV *S) {
8848 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8849 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8850 if (L->contains(I))
8851 Roots.push_back(S);
8852 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8853 if (L->contains(AddRec->getLoop()))
8854 Roots.push_back(S);
8855 }
8856 return true;
8857 }
8858 bool isDone() const { return false; }
8859 };
8860
8861 InvalidationRootCollector C(L);
8862 visitAll(S, C);
8863 forgetMemoizedResults(C.Roots);
8864 }
8865
8866 // Also perform the normal invalidation.
8867 forgetValue(V);
8868}
8869
8870void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8871
8873 // Unless a specific value is passed to invalidation, completely clear both
8874 // caches.
8875 if (!V) {
8876 BlockDispositions.clear();
8877 LoopDispositions.clear();
8878 return;
8879 }
8880
8881 if (!isSCEVable(V->getType()))
8882 return;
8883
8884 const SCEV *S = getExistingSCEV(V);
8885 if (!S)
8886 return;
8887
8888 // Invalidate the block and loop dispositions cached for S. Dispositions of
8889 // S's users may change if S's disposition changes (i.e. a user may change to
8890 // loop-invariant, if S changes to loop invariant), so also invalidate
8891 // dispositions of S's users recursively.
8892 SmallVector<SCEVUse, 8> Worklist = {S};
8894 while (!Worklist.empty()) {
8895 const SCEV *Curr = Worklist.pop_back_val();
8896 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8897 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8898 if (!LoopDispoRemoved && !BlockDispoRemoved)
8899 continue;
8900 auto Users = SCEVUsers.find(Curr);
8901 if (Users != SCEVUsers.end())
8902 for (const auto *User : Users->second)
8903 if (Seen.insert(User).second)
8904 Worklist.push_back(User);
8905 }
8906}
8907
8908/// Get the exact loop backedge taken count considering all loop exits. A
8909/// computable result can only be returned for loops with all exiting blocks
8910/// dominating the latch. howFarToZero assumes that the limit of each loop test
8911/// is never skipped. This is a valid assumption as long as the loop exits via
8912/// that test. For precise results, it is the caller's responsibility to specify
8913/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8914const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8915 const Loop *L, ScalarEvolution *SE,
8917 // If any exits were not computable, the loop is not computable.
8918 if (!isComplete() || ExitNotTaken.empty())
8919 return SE->getCouldNotCompute();
8920
8921 const BasicBlock *Latch = L->getLoopLatch();
8922 // All exiting blocks we have collected must dominate the only backedge.
8923 if (!Latch)
8924 return SE->getCouldNotCompute();
8925
8926 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8927 // count is simply a minimum out of all these calculated exit counts.
8929 for (const auto &ENT : ExitNotTaken) {
8930 const SCEV *BECount = ENT.ExactNotTaken;
8931 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8932 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8933 "We should only have known counts for exiting blocks that dominate "
8934 "latch!");
8935
8936 Ops.push_back(BECount);
8937
8938 if (Preds)
8939 append_range(*Preds, ENT.Predicates);
8940
8941 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8942 "Predicate should be always true!");
8943 }
8944
8945 // If an earlier exit exits on the first iteration (exit count zero), then
8946 // a later poison exit count should not propagate into the result. This are
8947 // exactly the semantics provided by umin_seq.
8948 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8949}
8950
8951const ScalarEvolution::ExitNotTakenInfo *
8952ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8953 const BasicBlock *ExitingBlock,
8954 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8955 for (const auto &ENT : ExitNotTaken)
8956 if (ENT.ExitingBlock == ExitingBlock) {
8957 if (ENT.hasAlwaysTruePredicate())
8958 return &ENT;
8959 else if (Predicates) {
8960 append_range(*Predicates, ENT.Predicates);
8961 return &ENT;
8962 }
8963 }
8964
8965 return nullptr;
8966}
8967
8968/// getConstantMax - Get the constant max backedge taken count for the loop.
8969const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8970 ScalarEvolution *SE,
8971 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8972 if (!getConstantMax())
8973 return SE->getCouldNotCompute();
8974
8975 for (const auto &ENT : ExitNotTaken)
8976 if (!ENT.hasAlwaysTruePredicate()) {
8977 if (!Predicates)
8978 return SE->getCouldNotCompute();
8979 append_range(*Predicates, ENT.Predicates);
8980 }
8981
8982 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8983 isa<SCEVConstant>(getConstantMax())) &&
8984 "No point in having a non-constant max backedge taken count!");
8985 return getConstantMax();
8986}
8987
8988const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8989 const Loop *L, ScalarEvolution *SE,
8990 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8991 if (!SymbolicMax) {
8992 // Form an expression for the maximum exit count possible for this loop. We
8993 // merge the max and exact information to approximate a version of
8994 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8995 // constants.
8996 SmallVector<SCEVUse, 4> ExitCounts;
8997
8998 for (const auto &ENT : ExitNotTaken) {
8999 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
9000 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
9001 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
9002 "We should only have known counts for exiting blocks that "
9003 "dominate latch!");
9004 ExitCounts.push_back(ExitCount);
9005 if (Predicates)
9006 append_range(*Predicates, ENT.Predicates);
9007
9008 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
9009 "Predicate should be always true!");
9010 }
9011 }
9012 if (ExitCounts.empty())
9013 SymbolicMax = SE->getCouldNotCompute();
9014 else
9015 SymbolicMax =
9016 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
9017 }
9018 return SymbolicMax;
9019}
9020
9021bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9022 ScalarEvolution *SE) const {
9023 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9024 return !ENT.hasAlwaysTruePredicate();
9025 };
9026 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9027}
9028
9031
9033 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9034 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9038 // If we prove the max count is zero, so is the symbolic bound. This happens
9039 // in practice due to differences in a) how context sensitive we've chosen
9040 // to be and b) how we reason about bounds implied by UB.
9041 if (ConstantMaxNotTaken->isZero()) {
9042 this->ExactNotTaken = E = ConstantMaxNotTaken;
9043 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9044 }
9045
9048 "Exact is not allowed to be less precise than Constant Max");
9051 "Exact is not allowed to be less precise than Symbolic Max");
9054 "Symbolic Max is not allowed to be less precise than Constant Max");
9057 "No point in having a non-constant max backedge taken count!");
9059 for (const auto PredList : PredLists)
9060 for (const auto *P : PredList) {
9061 if (SeenPreds.contains(P))
9062 continue;
9063 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9064 SeenPreds.insert(P);
9065 Predicates.push_back(P);
9066 }
9067 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9068 "Backedge count should be int");
9070 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9071 "Max backedge count should be int");
9072}
9073
9081
9082/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9083/// computable exit into a persistent ExitNotTakenInfo array.
9084ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9086 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9087 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9088 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9089
9090 ExitNotTaken.reserve(ExitCounts.size());
9091 std::transform(ExitCounts.begin(), ExitCounts.end(),
9092 std::back_inserter(ExitNotTaken),
9093 [&](const EdgeExitInfo &EEI) {
9094 BasicBlock *ExitBB = EEI.first;
9095 const ExitLimit &EL = EEI.second;
9096 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9097 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9098 EL.Predicates);
9099 });
9100 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9101 isa<SCEVConstant>(ConstantMax)) &&
9102 "No point in having a non-constant max backedge taken count!");
9103}
9104
9105/// Compute the number of times the backedge of the specified loop will execute.
9106ScalarEvolution::BackedgeTakenInfo
9107ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9108 bool AllowPredicates) {
9109 SmallVector<BasicBlock *, 8> ExitingBlocks;
9110 L->getExitingBlocks(ExitingBlocks);
9111
9112 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9113
9115 bool CouldComputeBECount = true;
9116 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9117 const SCEV *MustExitMaxBECount = nullptr;
9118 const SCEV *MayExitMaxBECount = nullptr;
9119 bool MustExitMaxOrZero = false;
9120 bool IsOnlyExit = ExitingBlocks.size() == 1;
9121
9122 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9123 // and compute maxBECount.
9124 // Do a union of all the predicates here.
9125 for (BasicBlock *ExitBB : ExitingBlocks) {
9126 // We canonicalize untaken exits to br (constant), ignore them so that
9127 // proving an exit untaken doesn't negatively impact our ability to reason
9128 // about the loop as whole.
9129 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9130 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9131 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9132 if (ExitIfTrue == CI->isZero())
9133 continue;
9134 }
9135
9136 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9137
9138 assert((AllowPredicates || EL.Predicates.empty()) &&
9139 "Predicated exit limit when predicates are not allowed!");
9140
9141 // 1. For each exit that can be computed, add an entry to ExitCounts.
9142 // CouldComputeBECount is true only if all exits can be computed.
9143 if (EL.ExactNotTaken != getCouldNotCompute())
9144 ++NumExitCountsComputed;
9145 else
9146 // We couldn't compute an exact value for this exit, so
9147 // we won't be able to compute an exact value for the loop.
9148 CouldComputeBECount = false;
9149 // Remember exit count if either exact or symbolic is known. Because
9150 // Exact always implies symbolic, only check symbolic.
9151 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9152 ExitCounts.emplace_back(ExitBB, EL);
9153 else {
9154 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9155 "Exact is known but symbolic isn't?");
9156 ++NumExitCountsNotComputed;
9157 }
9158
9159 // 2. Derive the loop's MaxBECount from each exit's max number of
9160 // non-exiting iterations. Partition the loop exits into two kinds:
9161 // LoopMustExits and LoopMayExits.
9162 //
9163 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9164 // is a LoopMayExit. If any computable LoopMustExit is found, then
9165 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9166 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9167 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9168 // any
9169 // computable EL.ConstantMaxNotTaken.
9170 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9171 DT.dominates(ExitBB, Latch)) {
9172 if (!MustExitMaxBECount) {
9173 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9174 MustExitMaxOrZero = EL.MaxOrZero;
9175 } else {
9176 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9177 EL.ConstantMaxNotTaken);
9178 }
9179 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9180 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9181 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9182 else {
9183 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9184 EL.ConstantMaxNotTaken);
9185 }
9186 }
9187 }
9188 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9189 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9190 // The loop backedge will be taken the maximum or zero times if there's
9191 // a single exit that must be taken the maximum or zero times.
9192 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9193
9194 // Remember which SCEVs are used in exit limits for invalidation purposes.
9195 // We only care about non-constant SCEVs here, so we can ignore
9196 // EL.ConstantMaxNotTaken
9197 // and MaxBECount, which must be SCEVConstant.
9198 for (const auto &Pair : ExitCounts) {
9199 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9200 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9201 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9202 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9203 {L, AllowPredicates});
9204 }
9205 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9206 MaxBECount, MaxOrZero);
9207}
9208
9209ScalarEvolution::ExitLimit
9210ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9211 bool IsOnlyExit, bool AllowPredicates) {
9212 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9213 // If our exiting block does not dominate the latch, then its connection with
9214 // loop's exit limit may be far from trivial.
9215 const BasicBlock *Latch = L->getLoopLatch();
9216 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9217 return getCouldNotCompute();
9218
9219 Instruction *Term = ExitingBlock->getTerminator();
9220 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9221 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9222 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9223 "It should have one successor in loop and one exit block!");
9224 // Proceed to the next level to examine the exit condition expression.
9225 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9226 /*ControlsOnlyExit=*/IsOnlyExit,
9227 AllowPredicates);
9228 }
9229
9230 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9231 // For switch, make sure that there is a single exit from the loop.
9232 BasicBlock *Exit = nullptr;
9233 for (auto *SBB : successors(ExitingBlock))
9234 if (!L->contains(SBB)) {
9235 if (Exit) // Multiple exit successors.
9236 return getCouldNotCompute();
9237 Exit = SBB;
9238 }
9239 assert(Exit && "Exiting block must have at least one exit");
9240 return computeExitLimitFromSingleExitSwitch(
9241 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9242 }
9243
9244 return getCouldNotCompute();
9245}
9246
9248 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9249 bool AllowPredicates) {
9250 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9251 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9252 ControlsOnlyExit, AllowPredicates);
9253}
9254
9255std::optional<ScalarEvolution::ExitLimit>
9256ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9257 bool ExitIfTrue, bool ControlsOnlyExit,
9258 bool AllowPredicates) {
9259 (void)this->L;
9260 (void)this->ExitIfTrue;
9261 (void)this->AllowPredicates;
9262
9263 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9264 this->AllowPredicates == AllowPredicates &&
9265 "Variance in assumed invariant key components!");
9266 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9267 if (Itr == TripCountMap.end())
9268 return std::nullopt;
9269 return Itr->second;
9270}
9271
9272void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9273 bool ExitIfTrue,
9274 bool ControlsOnlyExit,
9275 bool AllowPredicates,
9276 const ExitLimit &EL) {
9277 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9278 this->AllowPredicates == AllowPredicates &&
9279 "Variance in assumed invariant key components!");
9280
9281 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9282 assert(InsertResult.second && "Expected successful insertion!");
9283 (void)InsertResult;
9284 (void)ExitIfTrue;
9285}
9286
9287ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9288 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9289 bool ControlsOnlyExit, bool AllowPredicates) {
9290
9291 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9292 AllowPredicates))
9293 return *MaybeEL;
9294
9295 ExitLimit EL = computeExitLimitFromCondImpl(
9296 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9297 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9298 return EL;
9299}
9300
9301ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9302 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9303 bool ControlsOnlyExit, bool AllowPredicates) {
9304 // Handle BinOp conditions (And, Or).
9305 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9306 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9307 return *LimitFromBinOp;
9308
9309 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9310 // Proceed to the next level to examine the icmp.
9311 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9312 ExitLimit EL =
9313 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9314 if (EL.hasFullInfo() || !AllowPredicates)
9315 return EL;
9316
9317 // Try again, but use SCEV predicates this time.
9318 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9319 ControlsOnlyExit,
9320 /*AllowPredicates=*/true);
9321 }
9322
9323 // Check for a constant condition. These are normally stripped out by
9324 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9325 // preserve the CFG and is temporarily leaving constant conditions
9326 // in place.
9327 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9328 if (ExitIfTrue == !CI->getZExtValue())
9329 // The backedge is always taken.
9330 return getCouldNotCompute();
9331 // The backedge is never taken.
9332 return getZero(CI->getType());
9333 }
9334
9335 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9336 // with a constant step, we can form an equivalent icmp predicate and figure
9337 // out how many iterations will be taken before we exit.
9338 const WithOverflowInst *WO;
9339 const APInt *C;
9340 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9341 match(WO->getRHS(), m_APInt(C))) {
9342 ConstantRange NWR =
9344 WO->getNoWrapKind());
9345 CmpInst::Predicate Pred;
9346 APInt NewRHSC, Offset;
9347 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9348 if (!ExitIfTrue)
9349 Pred = ICmpInst::getInversePredicate(Pred);
9350 auto *LHS = getSCEV(WO->getLHS());
9351 if (Offset != 0)
9353 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9354 ControlsOnlyExit, AllowPredicates);
9355 if (EL.hasAnyInfo())
9356 return EL;
9357 }
9358
9359 // If it's not an integer or pointer comparison then compute it the hard way.
9360 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9361}
9362
/// Compute an exit limit for an exit condition that is a logical And/Or of
/// two sub-conditions, combining the limits computed for each operand.
/// Returns std::nullopt when \p ExitCond is not a logical And/Or, letting
/// the caller try other strategies.
std::optional<ScalarEvolution::ExitLimit>
ScalarEvolution::computeExitLimitFromCondFromBinOp(
    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // Check if the controlling expression for this loop is an And or Or.
  Value *Op0, *Op1;
  bool IsAnd = false;
  if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
    IsAnd = true;
  else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
    IsAnd = false;
  else
    return std::nullopt;

  // EitherMayExit is true in these two cases:
  //   br (and Op0 Op1), loop, exit
  //   br (or  Op0 Op1), exit, loop
  // i.e. either operand becoming false (resp. true) ends the loop, so the
  // operands no longer individually control the only exit.
  bool EitherMayExit = IsAnd ^ ExitIfTrue;
  ExitLimit EL0 = computeExitLimitFromCondCached(
      Cache, L, Op0, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
      AllowPredicates);
  ExitLimit EL1 = computeExitLimitFromCondCached(
      Cache, L, Op1, ExitIfTrue, ControlsOnlyExit && !EitherMayExit,
      AllowPredicates);

  // Be robust against unsimplified IR for the form "op i1 X, NeutralElement":
  // drop the operand that cannot affect the result.
  const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
  if (isa<ConstantInt>(Op1))
    return Op1 == NeutralElement ? EL0 : EL1;
  if (isa<ConstantInt>(Op0))
    return Op0 == NeutralElement ? EL1 : EL0;

  const SCEV *BECount = getCouldNotCompute();
  const SCEV *ConstantMaxBECount = getCouldNotCompute();
  const SCEV *SymbolicMaxBECount = getCouldNotCompute();
  if (EitherMayExit) {
    // A select-form (non-bitwise) condition short-circuits, so a sequential
    // umin is needed to keep a poison second operand from propagating.
    bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
    // Both conditions must be same for the loop to continue executing.
    // Choose the less conservative count.
    if (EL0.ExactNotTaken != getCouldNotCompute() &&
        EL1.ExactNotTaken != getCouldNotCompute()) {
      BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
                                           UseSequentialUMin);
    }
    if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
      ConstantMaxBECount = EL1.ConstantMaxNotTaken;
    else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
      ConstantMaxBECount = EL0.ConstantMaxNotTaken;
    else
      ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
                                                      EL1.ConstantMaxNotTaken);
    if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
      SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
    else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
      SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
    else
      SymbolicMaxBECount = getUMinFromMismatchedTypes(
          EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
  } else {
    // Both conditions must be same at the same time for the loop to exit.
    // For now, be conservative.
    if (EL0.ExactNotTaken == EL1.ExactNotTaken)
      BECount = EL0.ExactNotTaken;
  }

  // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
  // to be more aggressive when computing BECount than when computing
  // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
  // and EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
  // EL1.ConstantMaxNotTaken to not.
  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
  if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
    SymbolicMaxBECount =
        isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
}
9443
9444ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9445 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9446 bool AllowPredicates) {
9447 // If the condition was exit on true, convert the condition to exit on false
9448 CmpPredicate Pred;
9449 if (!ExitIfTrue)
9450 Pred = ExitCond->getCmpPredicate();
9451 else
9452 Pred = ExitCond->getInverseCmpPredicate();
9453 const ICmpInst::Predicate OriginalPred = Pred;
9454
9455 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9456 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9457
9458 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9459 AllowPredicates);
9460 if (EL.hasAnyInfo())
9461 return EL;
9462
9463 auto *ExhaustiveCount =
9464 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9465
9466 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9467 return ExhaustiveCount;
9468
9469 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9470 ExitCond->getOperand(1), L, OriginalPred);
9471}
9472ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9473 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9474 bool ControlsOnlyExit, bool AllowPredicates) {
9475
9476 // Try to evaluate any dependencies out of the loop.
9477 LHS = getSCEVAtScope(LHS, L);
9478 RHS = getSCEVAtScope(RHS, L);
9479
9480 // At this point, we would like to compute how many iterations of the
9481 // loop the predicate will return true for these inputs.
9482 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9483 // If there is a loop-invariant, force it into the RHS.
9484 std::swap(LHS, RHS);
9486 }
9487
9488 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9490 // Simplify the operands before analyzing them.
9491 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9492
9493 // If we have a comparison of a chrec against a constant, try to use value
9494 // ranges to answer this query.
9495 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9496 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9497 if (AddRec->getLoop() == L) {
9498 // Form the constant range.
9499 ConstantRange CompRange =
9500 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9501
9502 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9503 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9504 }
9505
9506 // If this loop must exit based on this condition (or execute undefined
9507 // behaviour), see if we can improve wrap flags. This is essentially
9508 // a must execute style proof.
9509 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9510 // If we can prove the test sequence produced must repeat the same values
9511 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9512 // because if it did, we'd have an infinite (undefined) loop.
9513 // TODO: We can peel off any functions which are invertible *in L*. Loop
9514 // invariant terms are effectively constants for our purposes here.
9515 SCEVUse InnerLHS = LHS;
9516 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9517 InnerLHS = ZExt->getOperand();
9518 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9519 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9520 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9521 /*OrNegative=*/true)) {
9522 auto Flags = AR->getNoWrapFlags();
9523 Flags = setFlags(Flags, SCEV::FlagNW);
9524 SmallVector<SCEVUse> Operands{AR->operands()};
9525 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9526 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9527 }
9528
9529 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9530 // From no-self-wrap, this follows trivially from the fact that every
9531 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9532 // last value before (un)signed wrap. Since we know that last value
9533 // didn't exit, nor will any smaller one.
9534 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9535 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9536 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9537 AR && AR->getLoop() == L && AR->isAffine() &&
9538 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9539 isKnownPositive(AR->getStepRecurrence(*this))) {
9540 auto Flags = AR->getNoWrapFlags();
9541 Flags = setFlags(Flags, WrapType);
9542 SmallVector<SCEVUse> Operands{AR->operands()};
9543 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9544 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9545 }
9546 }
9547 }
9548
9549 switch (Pred) {
9550 case ICmpInst::ICMP_NE: { // while (X != Y)
9551 // Convert to: while (X-Y != 0)
9552 if (LHS->getType()->isPointerTy()) {
9555 return LHS;
9556 }
9557 if (RHS->getType()->isPointerTy()) {
9560 return RHS;
9561 }
9562 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9563 AllowPredicates);
9564 if (EL.hasAnyInfo())
9565 return EL;
9566 break;
9567 }
9568 case ICmpInst::ICMP_EQ: { // while (X == Y)
9569 // Convert to: while (X-Y == 0)
9570 if (LHS->getType()->isPointerTy()) {
9573 return LHS;
9574 }
9575 if (RHS->getType()->isPointerTy()) {
9578 return RHS;
9579 }
9580 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9581 if (EL.hasAnyInfo()) return EL;
9582 break;
9583 }
9584 case ICmpInst::ICMP_SLE:
9585 case ICmpInst::ICMP_ULE:
9586 // Since the loop is finite, an invariant RHS cannot include the boundary
9587 // value, otherwise it would loop forever.
9588 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9589 !isLoopInvariant(RHS, L)) {
9590 // Otherwise, perform the addition in a wider type, to avoid overflow.
9591 // If the LHS is an addrec with the appropriate nowrap flag, the
9592 // extension will be sunk into it and the exit count can be analyzed.
9593 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9594 if (!OldType)
9595 break;
9596 // Prefer doubling the bitwidth over adding a single bit to make it more
9597 // likely that we use a legal type.
9598 auto *NewType =
9599 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9600 if (ICmpInst::isSigned(Pred)) {
9601 LHS = getSignExtendExpr(LHS, NewType);
9602 RHS = getSignExtendExpr(RHS, NewType);
9603 } else {
9604 LHS = getZeroExtendExpr(LHS, NewType);
9605 RHS = getZeroExtendExpr(RHS, NewType);
9606 }
9607 }
9609 [[fallthrough]];
9610 case ICmpInst::ICMP_SLT:
9611 case ICmpInst::ICMP_ULT: { // while (X < Y)
9612 bool IsSigned = ICmpInst::isSigned(Pred);
9613 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9614 AllowPredicates);
9615 if (EL.hasAnyInfo())
9616 return EL;
9617 break;
9618 }
9619 case ICmpInst::ICMP_SGE:
9620 case ICmpInst::ICMP_UGE:
9621 // Since the loop is finite, an invariant RHS cannot include the boundary
9622 // value, otherwise it would loop forever.
9623 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9624 !isLoopInvariant(RHS, L))
9625 break;
9627 [[fallthrough]];
9628 case ICmpInst::ICMP_SGT:
9629 case ICmpInst::ICMP_UGT: { // while (X > Y)
9630 bool IsSigned = ICmpInst::isSigned(Pred);
9631 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9632 AllowPredicates);
9633 if (EL.hasAnyInfo())
9634 return EL;
9635 break;
9636 }
9637 default:
9638 break;
9639 }
9640
9641 return getCouldNotCompute();
9642}
9643
9644ScalarEvolution::ExitLimit
9645ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9646 SwitchInst *Switch,
9647 BasicBlock *ExitingBlock,
9648 bool ControlsOnlyExit) {
9649 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9650
9651 // Give up if the exit is the default dest of a switch.
9652 if (Switch->getDefaultDest() == ExitingBlock)
9653 return getCouldNotCompute();
9654
9655 assert(L->contains(Switch->getDefaultDest()) &&
9656 "Default case must not exit the loop!");
9657 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9658 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9659
9660 // while (X != Y) --> while (X-Y != 0)
9661 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9662 if (EL.hasAnyInfo())
9663 return EL;
9664
9665 return getCouldNotCompute();
9666}
9667
9668static ConstantInt *
9670 ScalarEvolution &SE) {
9671 const SCEV *InVal = SE.getConstant(C);
9672 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9674 "Evaluation of SCEV at constant didn't fold correctly?");
9675 return cast<SCEVConstant>(Val)->getValue();
9676}
9677
9678ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9679 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9680 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9681 if (!RHS)
9682 return getCouldNotCompute();
9683
9684 const BasicBlock *Latch = L->getLoopLatch();
9685 if (!Latch)
9686 return getCouldNotCompute();
9687
9688 const BasicBlock *Predecessor = L->getLoopPredecessor();
9689 if (!Predecessor)
9690 return getCouldNotCompute();
9691
9692 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9693 // Return LHS in OutLHS and shift_opt in OutOpCode.
9694 auto MatchPositiveShift =
9695 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9696
9697 using namespace PatternMatch;
9698
9699 ConstantInt *ShiftAmt;
9700 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9701 OutOpCode = Instruction::LShr;
9702 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9703 OutOpCode = Instruction::AShr;
9704 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9705 OutOpCode = Instruction::Shl;
9706 else
9707 return false;
9708
9709 return ShiftAmt->getValue().isStrictlyPositive();
9710 };
9711
9712 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9713 //
9714 // loop:
9715 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9716 // %iv.shifted = lshr i32 %iv, <positive constant>
9717 //
9718 // Return true on a successful match. Return the corresponding PHI node (%iv
9719 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9720 auto MatchShiftRecurrence =
9721 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9722 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9723
9724 {
9726 Value *V;
9727
9728 // If we encounter a shift instruction, "peel off" the shift operation,
9729 // and remember that we did so. Later when we inspect %iv's backedge
9730 // value, we will make sure that the backedge value uses the same
9731 // operation.
9732 //
9733 // Note: the peeled shift operation does not have to be the same
9734 // instruction as the one feeding into the PHI's backedge value. We only
9735 // really care about it being the same *kind* of shift instruction --
9736 // that's all that is required for our later inferences to hold.
9737 if (MatchPositiveShift(LHS, V, OpC)) {
9738 PostShiftOpCode = OpC;
9739 LHS = V;
9740 }
9741 }
9742
9743 PNOut = dyn_cast<PHINode>(LHS);
9744 if (!PNOut || PNOut->getParent() != L->getHeader())
9745 return false;
9746
9747 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9748 Value *OpLHS;
9749
9750 return
9751 // The backedge value for the PHI node must be a shift by a positive
9752 // amount
9753 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9754
9755 // of the PHI node itself
9756 OpLHS == PNOut &&
9757
9758 // and the kind of shift should be match the kind of shift we peeled
9759 // off, if any.
9760 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9761 };
9762
9763 PHINode *PN;
9765 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9766 return getCouldNotCompute();
9767
9768 const DataLayout &DL = getDataLayout();
9769
9770 // The key rationale for this optimization is that for some kinds of shift
9771 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9772 // within a finite number of iterations. If the condition guarding the
9773 // backedge (in the sense that the backedge is taken if the condition is true)
9774 // is false for the value the shift recurrence stabilizes to, then we know
9775 // that the backedge is taken only a finite number of times.
9776
9777 ConstantInt *StableValue = nullptr;
9778 switch (OpCode) {
9779 default:
9780 llvm_unreachable("Impossible case!");
9781
9782 case Instruction::AShr: {
9783 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9784 // bitwidth(K) iterations.
9785 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9786 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9787 Predecessor->getTerminator(), &DT);
9788 auto *Ty = cast<IntegerType>(RHS->getType());
9789 if (Known.isNonNegative())
9790 StableValue = ConstantInt::get(Ty, 0);
9791 else if (Known.isNegative())
9792 StableValue = ConstantInt::get(Ty, -1, true);
9793 else
9794 return getCouldNotCompute();
9795
9796 break;
9797 }
9798 case Instruction::LShr:
9799 case Instruction::Shl:
9800 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9801 // stabilize to 0 in at most bitwidth(K) iterations.
9802 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9803 break;
9804 }
9805
9806 auto *Result =
9807 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9808 assert(Result->getType()->isIntegerTy(1) &&
9809 "Otherwise cannot be an operand to a branch instruction");
9810
9811 if (Result->isNullValue()) {
9812 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9813 const SCEV *UpperBound =
9815 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9816 }
9817
9818 return getCouldNotCompute();
9819}
9820
9821/// Return true if we can constant fold an instruction of the specified type,
9822/// assuming that all operands were constants.
9823static bool CanConstantFold(const Instruction *I) {
// NOTE(review): original lines 9824-9826 are missing from this extraction.
// They presumably hold the isa<...> whitelist of instruction kinds that
// guards this early 'return true' — restore from upstream before compiling.
9827 return true;
9828
// A call is foldable only if its callee is known and is one of the
// intrinsics/library functions that canConstantFoldCallTo accepts.
9829 if (const CallInst *CI = dyn_cast<CallInst>(I))
9830 if (const Function *F = CI->getCalledFunction())
9831 return canConstantFoldCallTo(CI, F);
// Anything else cannot be folded even with all-constant operands.
9832 return false;
9833}
9834
9835/// Determine whether this instruction can constant evolve within this loop
9836/// assuming its operands can all constant evolve.
9837static bool canConstantEvolve(Instruction *I, const Loop *L) {
9838 // An instruction outside of the loop can't be derived from a loop PHI.
9839 if (!L->contains(I)) return false;
9840
9841 if (isa<PHINode>(I)) {
9842 // We don't currently keep track of the control flow needed to evaluate
9843 // PHIs, so we cannot handle PHIs inside of loops.
9844 return L->getHeader() == I->getParent();
9845 }
9846
9847 // If we won't be able to constant fold this expression even if the operands
9848 // are constants, bail early.
9849 return CanConstantFold(I);
9850}
9851
9852/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
9853/// recursing through each instruction operand until reaching a loop header phi.
9854static PHINode *
// NOTE(review): original lines 9855-9856 (the remaining parameters —
// presumably `Instruction *UseInst, const Loop *L` and the PHIMap cache)
// are missing from this extraction; restore from upstream.
9857 unsigned Depth) {
// NOTE(review): original line 9858 is missing — presumably a recursion
// depth cutoff (e.g. Depth > MaxConstantEvolvingDepth) guarding this return.
9859 return nullptr;
9860
9861 // Otherwise, we can evaluate this instruction if all of its operands are
9862 // constant or derived from a PHI node themselves.
9863 PHINode *PHI = nullptr;
9864 for (Value *Op : UseInst->operands()) {
// Constant operands impose no constraint; skip them.
9865 if (isa<Constant>(Op)) continue;
9866
// NOTE(review): original line 9867 is missing — it presumably defines
// OpInst as dyn_cast<Instruction>(Op); confirm against upstream.
9868 if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;
9869
9870 PHINode *P = dyn_cast<PHINode>(OpInst);
9871 if (!P)
9872 // If this operand is already visited, reuse the prior result.
9873 // We may have P != PHI if this is the deepest point at which the
9874 // inconsistent paths meet.
9875 P = PHIMap.lookup(OpInst);
9876 if (!P) {
9877 // Recurse and memoize the results, whether a phi is found or not.
9878 // This recursive call invalidates pointers into PHIMap.
9879 P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
9880 PHIMap[OpInst] = P;
9881 }
9882 if (!P)
9883 return nullptr; // Not evolving from PHI
9884 if (PHI && PHI != P)
9885 return nullptr; // Evolving from multiple different PHIs.
9886 PHI = P;
9887 }
9888 // This is a expression evolving from a constant PHI!
9889 return PHI;
9890}
9891
9892/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
9893/// in the loop that V is derived from. We allow arbitrary operations along the
9894/// way, but the operands of an operation must either be constants or a value
9895/// derived from a constant PHI. If this expression does not fit with these
9896/// constraints, return null.
// NOTE(review): original lines 9897-9898 are missing from this extraction —
// presumably the function signature (taking Value *V and const Loop *L) and
// the dyn_cast of V to Instruction *I; restore from upstream.
9899 if (!I || !canConstantEvolve(I, L)) return nullptr;
9900
// A header PHI is itself the evolving PHI we are looking for.
9901 if (PHINode *PN = dyn_cast<PHINode>(I))
9902 return PN;
9903
9904 // Record non-constant instructions contained by the loop.
// NOTE(review): original line 9905 is missing — presumably the declaration
// of the PHIMap memoization map passed to the recursive helper.
9906 return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
9907}
9908
9909/// EvaluateExpression - Given an expression that passes the
9910/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
9911/// in the loop has the value PHIVal. If we can't fold this expression for some
9912/// reason, return null.
// NOTE(review): original lines 9913-9914 are missing — presumably the start
// of the signature (Value *V, const Loop *L, and the Vals cache parameter);
// restore from upstream.
9915 const DataLayout &DL,
9916 const TargetLibraryInfo *TLI) {
9917 // Convenient constant check, but redundant for recursive calls.
9918 if (Constant *C = dyn_cast<Constant>(V)) return C;
// NOTE(review): original line 9919 is missing — presumably defines I as
// dyn_cast<Instruction>(V); confirm against upstream.
9920 if (!I) return nullptr;
9921
// Reuse a value already computed (or seeded) for this iteration.
9922 if (Constant *C = Vals.lookup(I)) return C;
9923
9924 // An instruction inside the loop depends on a value outside the loop that we
9925 // weren't given a mapping for, or a value such as a call inside the loop.
9926 if (!canConstantEvolve(I, L)) return nullptr;
9927
9928 // An unmapped PHI can be due to a branch or another loop inside this loop,
9929 // or due to this not being the initial iteration through a loop where we
9930 // couldn't compute the evolution of this particular PHI last time.
9931 if (isa<PHINode>(I)) return nullptr;
9932
9933 std::vector<Constant*> Operands(I->getNumOperands());
9934
// Recursively resolve every operand to a Constant, memoizing instruction
// results in Vals; fail if any operand cannot be made constant.
9935 for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
9936 Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
9937 if (!Operand) {
9938 Operands[i] = dyn_cast<Constant>(I->getOperand(i));
9939 if (!Operands[i]) return nullptr;
9940 continue;
9941 }
9942 Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
// Cache even null results so failed subtrees are not re-evaluated.
9943 Vals[Operand] = C;
9944 if (!C) return nullptr;
9945 Operands[i] = C;
9946 }
9947
// Fold the instruction itself; non-deterministic folds are disallowed so
// repeated evaluations agree across iterations.
9948 return ConstantFoldInstOperands(I, Operands, DL, TLI,
9949 /*AllowNonDeterministic=*/false);
9950}
9951
9952
9953// If every incoming value to PN except the one for BB is a specific Constant,
9954// return that, else return nullptr.
// NOTE(review): original line 9955 is missing from this extraction —
// presumably the signature `static Constant *getOtherIncomingValue(PHINode
// *PN, BasicBlock *BB) {`; restore from upstream.
9956 Constant *IncomingVal = nullptr;
9957
9958 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
// Ignore the edge coming from BB itself.
9959 if (PN->getIncomingBlock(i) == BB)
9960 continue;
9961
9962 auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
9963 if (!CurrentVal)
9964 return nullptr;
9965
// All non-BB edges must agree on one constant; a second distinct
// constant makes the start value ambiguous.
9966 if (IncomingVal != CurrentVal) {
9967 if (IncomingVal)
9968 return nullptr;
9969 IncomingVal = CurrentVal;
9970 }
9971 }
9972
9973 return IncomingVal;
9974}
9975
9976/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
9977/// in the header of its containing loop, we know the loop executes a
9978/// constant number of times, and the PHI node is just a recurrence
9979/// involving constants, fold it.
9980Constant *
9981ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
9982 const APInt &BEs,
9983 const Loop *L) {
// Memoize per-PHI: a prior answer (including a cached nullptr) is reused.
9984 auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
9985 if (!Inserted)
9986 return I->second;
9987
// NOTE(review): original line 9988 is missing from this extraction —
// presumably a budget check on BEs (e.g. BEs.ugt(MaxBruteForceIterations))
// guarding this bail-out; restore from upstream.
9989 return nullptr; // Not going to evaluate it.
9990
9991 Constant *&RetVal = I->second;
9992
9993 DenseMap<Instruction *, Constant *> CurrentIterVals;
9994 BasicBlock *Header = L->getHeader();
9995 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
9996
9997 BasicBlock *Latch = L->getLoopLatch();
9998 if (!Latch)
9999 return nullptr;
10000
// Seed iteration 0 with each header PHI whose non-latch incoming value is
// a single known constant.
10001 for (PHINode &PHI : Header->phis()) {
10002 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10003 CurrentIterVals[&PHI] = StartCST;
10004 }
10005 if (!CurrentIterVals.count(PN))
10006 return RetVal = nullptr;
10007
10008 Value *BEValue = PN->getIncomingValueForBlock(Latch);
10009
10010 // Execute the loop symbolically to determine the exit value.
10011 assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
10012 "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
10013
10014 unsigned NumIterations = BEs.getZExtValue(); // must be in range
10015 unsigned IterationNum = 0;
10016 const DataLayout &DL = getDataLayout();
10017 for (; ; ++IterationNum) {
10018 if (IterationNum == NumIterations)
10019 return RetVal = CurrentIterVals[PN]; // Got exit value!
10020
10021 // Compute the value of the PHIs for the next iteration.
10022 // EvaluateExpression adds non-phi values to the CurrentIterVals map.
10023 DenseMap<Instruction *, Constant *> NextIterVals;
10024 Constant *NextPHI =
10025 EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10026 if (!NextPHI)
10027 return nullptr; // Couldn't evaluate!
10028 NextIterVals[PN] = NextPHI;
10029
10030 bool StoppedEvolving = NextPHI == CurrentIterVals[PN];
10031
10032 // Also evaluate the other PHI nodes. However, we don't get to stop if we
10033 // cease to be able to evaluate one of them or if they stop evolving,
10034 // because that doesn't necessarily prevent us from computing PN.
// NOTE(review): original line 10035 is missing — presumably the declaration
// of PHIsToCompute (a SmallVector of PHI/Constant pairs); restore upstream.
10036 for (const auto &I : CurrentIterVals) {
10037 PHINode *PHI = dyn_cast<PHINode>(I.first);
10038 if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
10039 PHIsToCompute.emplace_back(PHI, I.second);
10040 }
10041 // We use two distinct loops because EvaluateExpression may invalidate any
10042 // iterators into CurrentIterVals.
10043 for (const auto &I : PHIsToCompute) {
10044 PHINode *PHI = I.first;
10045 Constant *&NextPHI = NextIterVals[PHI];
10046 if (!NextPHI) { // Not already computed.
10047 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10048 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10049 }
10050 if (NextPHI != I.second)
10051 StoppedEvolving = false;
10052 }
10053
10054 // If all entries in CurrentIterVals == NextIterVals then we can stop
10055 // iterating, the loop can't continue to change.
10056 if (StoppedEvolving)
10057 return RetVal = CurrentIterVals[PN];
10058
10059 CurrentIterVals.swap(NextIterVals);
10060 }
10061}
10062
10063const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10064 Value *Cond,
10065 bool ExitWhen) {
10066 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10067 if (!PN) return getCouldNotCompute();
10068
10069 // If the loop is canonicalized, the PHI will have exactly two entries.
10070 // That's the only form we support here.
10071 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10072
10073 DenseMap<Instruction *, Constant *> CurrentIterVals;
10074 BasicBlock *Header = L->getHeader();
10075 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10076
10077 BasicBlock *Latch = L->getLoopLatch();
10078 assert(Latch && "Should follow from NumIncomingValues == 2!");
10079
10080 for (PHINode &PHI : Header->phis()) {
10081 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10082 CurrentIterVals[&PHI] = StartCST;
10083 }
10084 if (!CurrentIterVals.count(PN))
10085 return getCouldNotCompute();
10086
10087 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10088 // the loop symbolically to determine when the condition gets a value of
10089 // "ExitWhen".
10090 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10091 const DataLayout &DL = getDataLayout();
10092 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10093 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10094 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10095
10096 // Couldn't symbolically evaluate.
10097 if (!CondVal) return getCouldNotCompute();
10098
10099 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10100 ++NumBruteForceTripCountsComputed;
10101 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10102 }
10103
10104 // Update all the PHI nodes for the next iteration.
10105 DenseMap<Instruction *, Constant *> NextIterVals;
10106
10107 // Create a list of which PHIs we need to compute. We want to do this before
10108 // calling EvaluateExpression on them because that may invalidate iterators
10109 // into CurrentIterVals.
10110 SmallVector<PHINode *, 8> PHIsToCompute;
10111 for (const auto &I : CurrentIterVals) {
10112 PHINode *PHI = dyn_cast<PHINode>(I.first);
10113 if (!PHI || PHI->getParent() != Header) continue;
10114 PHIsToCompute.push_back(PHI);
10115 }
10116 for (PHINode *PHI : PHIsToCompute) {
10117 Constant *&NextPHI = NextIterVals[PHI];
10118 if (NextPHI) continue; // Already computed!
10119
10120 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10121 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10122 }
10123 CurrentIterVals.swap(NextIterVals);
10124 }
10125
10126 // Too many iterations were needed to evaluate.
10127 return getCouldNotCompute();
10128}
10129
10130const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
// NOTE(review): original line 10131 is missing from this extraction —
// presumably the declaration binding `Values` as a reference into the
// ValuesAtScopes cache entry for V; restore from upstream.
10132 ValuesAtScopes[V];
10133 // Check to see if we've folded this expression at this loop before.
// A cached null second means "no simplification at this scope"; return V.
10134 for (auto &LS : Values)
10135 if (LS.first == L)
10136 return LS.second ? LS.second : V;
10137
// Insert a placeholder before recursing so re-entrant queries terminate.
10138 Values.emplace_back(L, nullptr);
10139
10140 // Otherwise compute it.
10141 const SCEV *C = computeSCEVAtScope(V, L);
// Re-find the placeholder (computeSCEVAtScope may have grown the vector)
// and record the result; track the reverse mapping for invalidation.
10142 for (auto &LS : reverse(ValuesAtScopes[V]))
10143 if (LS.first == L) {
10144 LS.second = C;
10145 if (!isa<SCEVConstant>(C))
10146 ValuesAtScopesUsers[C].push_back({L, V});
10147 break;
10148 }
10149 return C;
10150}
10151
10152/// This builds up a Constant using the ConstantExpr interface. That way, we
10153/// will return Constants for objects which aren't represented by a
10154/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
10155/// Returns NULL if the SCEV isn't representable as a Constant.
// NOTE(review): original line 10156 is missing from this extraction —
// presumably the signature `static Constant *BuildConstantFromSCEV(const
// SCEV *V) {`; restore from upstream.
10157 switch (V->getSCEVType()) {
10158 case scCouldNotCompute:
10159 case scAddRecExpr:
10160 case scVScale:
10161 return nullptr;
10162 case scConstant:
10163 return cast<SCEVConstant>(V)->getValue();
10164 case scUnknown:
10165 return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
10166 case scPtrToAddr: {
// NOTE(review): original line 10167 is missing — presumably the cast of V
// to the ptrtoaddr expression bound to P2I; restore from upstream.
10168 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10169 return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());
10170
10171 return nullptr;
10172 }
10173 case scPtrToInt: {
// NOTE(review): original line 10174 is missing — presumably the cast of V
// to the ptrtoint expression bound to P2I; restore from upstream.
10175 if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
10176 return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
10177
10178 return nullptr;
10179 }
10180 case scTruncate: {
// NOTE(review): original line 10181 is missing — presumably the cast of V
// to the truncate expression bound to ST; restore from upstream.
10182 if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
10183 return ConstantExpr::getTrunc(CastOp, ST->getType());
10184 return nullptr;
10185 }
10186 case scAddExpr: {
10187 const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
10188 Constant *C = nullptr;
10189 for (const SCEV *Op : SA->operands()) {
// NOTE(review): original line 10190 is missing — presumably the recursive
// BuildConstantFromSCEV(Op) call bound to OpC; restore from upstream.
10191 if (!OpC)
10192 return nullptr;
10193 if (!C) {
10194 C = OpC;
10195 continue;
10196 }
10197 assert(!C->getType()->isPointerTy() &&
10198 "Can only have one pointer, and it must be last");
10199 if (OpC->getType()->isPointerTy()) {
10200 // The offsets have been converted to bytes. We can add bytes using
10201 // an i8 GEP.
10202 C = ConstantExpr::getPtrAdd(OpC, C);
10203 } else {
10204 C = ConstantExpr::getAdd(C, OpC);
10205 }
10206 }
10207 return C;
10208 }
10209 case scMulExpr:
10210 case scSignExtend:
10211 case scZeroExtend:
10212 case scUDivExpr:
10213 case scSMaxExpr:
10214 case scUMaxExpr:
10215 case scSMinExpr:
10216 case scUMinExpr:
// NOTE(review): original line 10217 is missing — presumably the
// `case scSequentialUMinExpr:` label completing this fall-through group.
10218 return nullptr;
10219 }
10220 llvm_unreachable("Unknown SCEV kind!");
10221}
10222
// Rebuild expression S with the given replacement operands, preserving the
// expression kind and (where applicable) its no-wrap flags.
10223const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
10224 SmallVectorImpl<SCEVUse> &NewOps) {
10225 switch (S->getSCEVType()) {
10226 case scTruncate:
10227 case scZeroExtend:
10228 case scSignExtend:
10229 case scPtrToAddr:
10230 case scPtrToInt:
10231 return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
10232 case scAddRecExpr: {
10233 auto *AddRec = cast<SCEVAddRecExpr>(S);
10234 return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
10235 }
10236 case scAddExpr:
10237 return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
10238 case scMulExpr:
10239 return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
10240 case scUDivExpr:
10241 return getUDivExpr(NewOps[0], NewOps[1]);
10242 case scUMaxExpr:
10243 case scSMaxExpr:
10244 case scUMinExpr:
10245 case scSMinExpr:
10246 return getMinMaxExpr(S->getSCEVType(), NewOps);
// NOTE(review): original line 10247 is missing from this extraction —
// presumably the `case scSequentialUMinExpr:` label for this return.
10248 return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
// Leaf expressions carry no operands; return them unchanged.
10249 case scConstant:
10250 case scVScale:
10251 case scUnknown:
10252 return S;
10253 case scCouldNotCompute:
10254 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10255 }
10256 llvm_unreachable("Unknown SCEV kind!");
10257}
10258
// Compute the value of V when evaluated at scope L (nullptr = whole
// program): addrecs for loops not containing L are replaced by their exit
// values, operands are recursively simplified, and SCEVUnknowns are
// constant-folded where possible. Results are cached by getSCEVAtScope.
10259const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
10260 switch (V->getSCEVType()) {
10261 case scConstant:
10262 case scVScale:
10263 return V;
10264 case scAddRecExpr: {
10265 // If this is a loop recurrence for a loop that does not contain L, then we
10266 // are dealing with the final value computed by the loop.
10267 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
10268 // First, attempt to evaluate each operand.
10269 // Avoid performing the look-up in the common case where the specified
10270 // expression has no loop-variant portions.
10271 for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
10272 const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
10273 if (OpAtScope == AddRec->getOperand(i))
10274 continue;
10275
10276 // Okay, at least one of these operands is loop variant but might be
10277 // foldable. Build a new instance of the folded commutative expression.
// NOTE(review): original line 10278 is missing from this extraction —
// presumably the declaration of the NewOps SmallVector; restore upstream.
10279 NewOps.reserve(AddRec->getNumOperands());
10280 append_range(NewOps, AddRec->operands().take_front(i));
10281 NewOps.push_back(OpAtScope);
10282 for (++i; i != e; ++i)
10283 NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
10284
10285 const SCEV *FoldedRec = getAddRecExpr(
10286 NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
10287 AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
10288 // The addrec may be folded to a nonrecurrence, for example, if the
10289 // induction variable is multiplied by zero after constant folding. Go
10290 // ahead and return the folded value.
10291 if (!AddRec)
10292 return FoldedRec;
10293 break;
10294 }
10295
10296 // If the scope is outside the addrec's loop, evaluate it by using the
10297 // loop exit value of the addrec.
10298 if (!AddRec->getLoop()->contains(L)) {
10299 // To evaluate this recurrence, we need to know how many times the AddRec
10300 // loop iterates. Compute this now.
10301 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
10302 if (BackedgeTakenCount == getCouldNotCompute())
10303 return AddRec;
10304
10305 // Then, evaluate the AddRec.
10306 return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
10307 }
10308
10309 return AddRec;
10310 }
10311 case scTruncate:
10312 case scZeroExtend:
10313 case scSignExtend:
10314 case scPtrToAddr:
10315 case scPtrToInt:
10316 case scAddExpr:
10317 case scMulExpr:
10318 case scUDivExpr:
10319 case scUMaxExpr:
10320 case scSMaxExpr:
10321 case scUMinExpr:
10322 case scSMinExpr:
10323 case scSequentialUMinExpr: {
10324 ArrayRef<SCEVUse> Ops = V->operands();
10325 // Avoid performing the look-up in the common case where the specified
10326 // expression has no loop-variant portions.
10327 for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
10328 const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10329 if (OpAtScope != Ops[i].getPointer()) {
10330 // Okay, at least one of these operands is loop variant but might be
10331 // foldable. Build a new instance of the folded commutative expression.
// NOTE(review): original line 10332 is missing from this extraction —
// presumably the declaration of the NewOps SmallVector; restore upstream.
10333 NewOps.reserve(Ops.size());
10334 append_range(NewOps, Ops.take_front(i));
10335 NewOps.push_back(OpAtScope);
10336
10337 for (++i; i != e; ++i) {
10338 OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
10339 NewOps.push_back(OpAtScope);
10340 }
10341
10342 return getWithOperands(V, NewOps);
10343 }
10344 }
10345 // If we got here, all operands are loop invariant.
10346 return V;
10347 }
10348 case scUnknown: {
10349 // If this instruction is evolved from a constant-evolving PHI, compute the
10350 // exit value from the loop without using SCEVs.
10351 const SCEVUnknown *SU = cast<SCEVUnknown>(V);
// NOTE(review): original line 10352 is missing — presumably defines I as
// dyn_cast<Instruction>(SU->getValue()); confirm against upstream.
10353 if (!I)
10354 return V; // This is some other type of SCEVUnknown, just return it.
10355
10356 if (PHINode *PN = dyn_cast<PHINode>(I)) {
10357 const Loop *CurrLoop = this->LI[I->getParent()];
10358 // Looking for loop exit value.
10359 if (CurrLoop && CurrLoop->getParentLoop() == L &&
10360 PN->getParent() == CurrLoop->getHeader()) {
10361 // Okay, there is no closed form solution for the PHI node. Check
10362 // to see if the loop that contains it has a known backedge-taken
10363 // count. If so, we may be able to force computation of the exit
10364 // value.
10365 const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
10366 // This trivial case can show up in some degenerate cases where
10367 // the incoming IR has not yet been fully simplified.
10368 if (BackedgeTakenCount->isZero()) {
// Zero backedge count: the PHI's value is simply its (unique)
// out-of-loop incoming value, if all such edges agree.
10369 Value *InitValue = nullptr;
10370 bool MultipleInitValues = false;
10371 for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
10372 if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
10373 if (!InitValue)
10374 InitValue = PN->getIncomingValue(i);
10375 else if (InitValue != PN->getIncomingValue(i)) {
10376 MultipleInitValues = true;
10377 break;
10378 }
10379 }
10380 }
10381 if (!MultipleInitValues && InitValue)
10382 return getSCEV(InitValue);
10383 }
10384 // Do we have a loop invariant value flowing around the backedge
10385 // for a loop which must execute the backedge?
10386 if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
10387 isKnownNonZero(BackedgeTakenCount) &&
10388 PN->getNumIncomingValues() == 2) {
10389
10390 unsigned InLoopPred =
10391 CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
10392 Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
10393 if (CurrLoop->isLoopInvariant(BackedgeVal))
10394 return getSCEV(BackedgeVal);
10395 }
10396 if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
10397 // Okay, we know how many times the containing loop executes. If
10398 // this is a constant evolving PHI node, get the final value at
10399 // the specified iteration number.
10400 Constant *RV =
10401 getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
10402 if (RV)
10403 return getSCEV(RV);
10404 }
10405 }
10406 }
10407
10408 // Okay, this is an expression that we cannot symbolically evaluate
10409 // into a SCEV. Check to see if it's possible to symbolically evaluate
10410 // the arguments into constants, and if so, try to constant propagate the
10411 // result. This is particularly useful for computing loop exit values.
10412 if (!CanConstantFold(I))
10413 return V; // This is some other type of SCEVUnknown, just return it.
10414
10415 SmallVector<Constant *, 4> Operands;
10416 Operands.reserve(I->getNumOperands());
10417 bool MadeImprovement = false;
10418 for (Value *Op : I->operands()) {
10419 if (Constant *C = dyn_cast<Constant>(Op)) {
10420 Operands.push_back(C);
10421 continue;
10422 }
10423
10424 // If any of the operands is non-constant and if they are
10425 // non-integer and non-pointer, don't even try to analyze them
10426 // with scev techniques.
10427 if (!isSCEVable(Op->getType()))
10428 return V;
10429
10430 const SCEV *OrigV = getSCEV(Op);
10431 const SCEV *OpV = getSCEVAtScope(OrigV, L);
10432 MadeImprovement |= OrigV != OpV;
10433
// NOTE(review): original line 10434 is missing — presumably converts OpV
// back to a Constant C (e.g. via BuildConstantFromSCEV); restore upstream.
10435 if (!C)
10436 return V;
10437 assert(C->getType() == Op->getType() && "Type mismatch");
10438 Operands.push_back(C);
10439 }
10440
10441 // Check to see if getSCEVAtScope actually made an improvement.
10442 if (!MadeImprovement)
10443 return V; // This is some other type of SCEVUnknown, just return it.
10444
10445 Constant *C = nullptr;
10446 const DataLayout &DL = getDataLayout();
10447 C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
10448 /*AllowNonDeterministic=*/false);
10449 if (!C)
10450 return V;
10451 return getSCEV(C);
10452 }
10453 case scCouldNotCompute:
10454 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
10455 }
10456 llvm_unreachable("Unknown SCEV type!");
10457}
10458
// Convenience overload: map the IR value to its SCEV, then evaluate at L.
// NOTE(review): original line 10459 (the signature — presumably
// `const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {`)
// is missing from this extraction; restore from upstream.
10460 return getSCEVAtScope(getSCEV(V), L);
10461}
10462
// Strip off injective wrappers (zero/sign extensions) so the underlying
// expression can be compared or solved directly.
10463const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
// NOTE(review): original line 10464 is missing from this extraction —
// presumably the zext dyn_cast binding ZExt that guards this recursion.
10465 return stripInjectiveFunctions(ZExt->getOperand());
// NOTE(review): original line 10466 is missing — presumably the matching
// sext dyn_cast binding SExt; restore both from upstream.
10467 return stripInjectiveFunctions(SExt->getOperand());
10468 return S;
10469}
10470
10471/// Finds the minimum unsigned root of the following equation:
10472///
10473/// A * X = B (mod N)
10474///
10475/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
10476/// A and B isn't important.
10477///
10478/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
10479static const SCEV *
// NOTE(review): original lines 10480-10481 are missing from this extraction —
// presumably the rest of the signature (the name SolveLinEquationWithOverflow
// and parameters A, B, and a Predicates out-list); restore from upstream.
10482 ScalarEvolution &SE, const Loop *L) {
10483 uint32_t BW = A.getBitWidth();
10484 assert(BW == SE.getTypeSizeInBits(B->getType()));
10485 assert(A != 0 && "A must be non-zero.");
10486
10487 // 1. D = gcd(A, N)
10488 //
10489 // The gcd of A and N may have only one prime factor: 2. The number of
10490 // trailing zeros in A is its multiplicity
10491 uint32_t Mult2 = A.countr_zero();
10492 // D = 2^Mult2
10493
10494 // 2. Check if B is divisible by D.
10495 //
10496 // B is divisible by D if and only if the multiplicity of prime factor 2 for B
10497 // is not less than multiplicity of this prime factor for D.
10498 unsigned MinTZ = SE.getMinTrailingZeros(B);
10499 // Try again with the terminator of the loop predecessor for context-specific
10500 // result, if MinTZ s too small.
10501 if (MinTZ < Mult2 && L->getLoopPredecessor())
10502 MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
10503 if (MinTZ < Mult2) {
10504 // Check if we can prove there's no remainder using URem.
10505 const SCEV *URem =
10506 SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
10507 const SCEV *Zero = SE.getZero(B->getType());
10508 if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
10509 // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
10510 if (!Predicates)
10511 return SE.getCouldNotCompute();
10512
10513 // Avoid adding a predicate that is known to be false.
10514 if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
10515 return SE.getCouldNotCompute();
10516 Predicates->push_back(SE.getEqualPredicate(URem, Zero));
10517 }
10518 }
10519
10520 // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
10521 // modulo (N / D).
10522 //
10523 // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
10524 // (N / D) in general. The inverse itself always fits into BW bits, though,
10525 // so we immediately truncate it.
10526 APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
10527 APInt I = AD.multiplicativeInverse().zext(BW);
10528
10529 // 4. Compute the minimum unsigned root of the equation:
10530 // I * (B / D) mod (N / D)
10531 // To simplify the computation, we factor out the divide by D:
10532 // (I * B mod N) / D
10533 const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
10534 return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
10535}
10536
10537/// For a given quadratic addrec, generate coefficients of the corresponding
10538/// quadratic equation, multiplied by a common value to ensure that they are
10539/// integers.
10540/// The returned value is a tuple { A, B, C, M, BitWidth }, where
10541/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
10542/// were multiplied by, and BitWidth is the bit width of the original addrec
10543/// coefficients.
10544/// This function returns std::nullopt if the addrec coefficients are not
10545/// compile- time constants.
10546static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
// NOTE(review): original line 10547 is missing from this extraction —
// presumably `GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {`.
10548 assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
10549 const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
10550 const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
10551 const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
10552 LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
10553 << *AddRec << '\n');
10554
10555 // We currently can only solve this if the coefficients are constants.
10556 if (!LC || !MC || !NC) {
10557 LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
10558 return std::nullopt;
10559 }
10560
10561 APInt L = LC->getAPInt();
10562 APInt M = MC->getAPInt();
10563 APInt N = NC->getAPInt();
10564 assert(!N.isZero() && "This is not a quadratic addrec");
10565
10566 unsigned BitWidth = LC->getAPInt().getBitWidth();
10567 unsigned NewWidth = BitWidth + 1;
10568 LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
10569 << BitWidth << '\n');
10570 // The sign-extension (as opposed to a zero-extension) here matches the
10571 // extension used in SolveQuadraticEquationWrap (with the same motivation).
10572 N = N.sext(NewWidth);
10573 M = M.sext(NewWidth);
10574 L = L.sext(NewWidth);
10575
10576 // The increments are M, M+N, M+2N, ..., so the accumulated values are
10577 // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
10578 // L+M, L+2M+N, L+3M+3N, ...
10579 // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
10580 //
10581 // The equation Acc = 0 is then
10582 // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
10583 // In a quadratic form it becomes:
10584 // N n^2 + (2M-N) n + 2L = 0.
10585
// Coefficients were scaled by T = 2 to clear the n(n-1)/2 fraction.
10586 APInt A = N;
10587 APInt B = 2 * M - A;
10588 APInt C = 2 * L;
10589 APInt T = APInt(NewWidth, 2);
10590 LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
10591 << "x + " << C << ", coeff bw: " << NewWidth
10592 << ", multiplied by " << T << '\n');
10593 return std::make_tuple(A, B, C, T, BitWidth);
10594}
10595
10596/// Helper function to compare optional APInts:
10597/// (a) if X and Y both exist, return min(X, Y),
10598/// (b) if neither X nor Y exist, return std::nullopt,
10599/// (c) if exactly one of X and Y exists, return that value.
10600static std::optional<APInt> MinOptional(std::optional<APInt> X,
10601 std::optional<APInt> Y) {
10602 if (X && Y) {
10603 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10604 APInt XW = X->sext(W);
10605 APInt YW = Y->sext(W);
10606 return XW.slt(YW) ? *X : *Y;
10607 }
10608 if (!X && !Y)
10609 return std::nullopt;
10610 return X ? *X : *Y;
10611}
10612
10613/// Helper function to truncate an optional APInt to a given BitWidth.
10614/// When solving addrec-related equations, it is preferable to return a value
10615/// that has the same bit width as the original addrec's coefficients. If the
10616/// solution fits in the original bit width, truncate it (except for i1).
10617/// Returning a value of a different bit width may inhibit some optimizations.
10618///
10619/// In general, a solution to a quadratic equation generated from an addrec
10620/// may require BW+1 bits, where BW is the bit width of the addrec's
10621/// coefficients. The reason is that the coefficients of the quadratic
10622/// equation are BW+1 bits wide (to avoid truncation when converting from
10623/// the addrec to the equation).
10624static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
10625 unsigned BitWidth) {
10626 if (!X)
10627 return std::nullopt;
10628 unsigned W = X->getBitWidth();
// NOTE(review): original line 10629 is missing from this extraction —
// presumably the guard deciding when truncation is safe (per the comment
// above: BitWidth > 1 and the value fits); restore from upstream.
10630 return X->trunc(BitWidth);
10631 return X;
10632}
10633
10634/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
10635/// iterations. The values L, M, N are assumed to be signed, and they
10636/// should all have the same bit widths.
10637/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
10638/// where BW is the bit width of the addrec's coefficients.
10639/// If the calculated value is a BW-bit integer (for BW > 1), it will be
10640/// returned as such, otherwise the bit width of the returned value may
10641/// be greater than BW.
10642///
10643/// This function returns std::nullopt if
10644/// (a) the addrec coefficients are not constant, or
10645/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
10646/// like x^2 = 5, no integer solutions exist, in other cases an integer
10647/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
10648static std::optional<APInt>
// NOTE(review): original line 10649 is missing from this extraction —
// presumably the function name and parameters (the quadratic AddRec and a
// ScalarEvolution &SE); restore from upstream.
10650 APInt A, B, C, M;
10651 unsigned BitWidth;
10652 auto T = GetQuadraticEquation(AddRec);
10653 if (!T)
10654 return std::nullopt;
10655
10656 std::tie(A, B, C, M, BitWidth) = *T;
10657 LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
10658 std::optional<APInt> X =
// NOTE(review): original line 10659 is missing — presumably the call to
// APIntOps::SolveQuadraticEquationWrap producing X; confirm upstream.
10660 if (!X)
10661 return std::nullopt;
10662
// Verify the candidate root actually zeroes the chrec (the solver works
// modulo 2^(BW+1) and may produce a spurious root for the BW-bit chrec).
10663 ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
10664 ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
10665 if (!V->isZero())
10666 return std::nullopt;
10667
10668 return TruncIfPossible(X, BitWidth);
10669}
10670
10671/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
10672/// iterations. The values M, N are assumed to be signed, and they
10673/// should all have the same bit widths.
10674/// Find the least n such that c(n) does not belong to the given range,
10675/// while c(n-1) does.
10676///
10677/// This function returns std::nullopt if
10678/// (a) the addrec coefficients are not constant, or
10679/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
10680/// bounds of the range.
10681static std::optional<APInt>
10683 const ConstantRange &Range, ScalarEvolution &SE) {
10684 assert(AddRec->getOperand(0)->isZero() &&
10685 "Starting value of addrec should be 0");
10686 LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
10687 << Range << ", addrec " << *AddRec << '\n');
10688 // This case is handled in getNumIterationsInRange. Here we can assume that
10689 // we start in the range.
10690 assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
10691 "Addrec's initial value should be in range");
10692
10693 APInt A, B, C, M;
10694 unsigned BitWidth;
10695 auto T = GetQuadraticEquation(AddRec);
10696 if (!T)
10697 return std::nullopt;
10698
10699 // Be careful about the return value: there can be two reasons for not
10700 // returning an actual number. First, if no solutions to the equations
10701 // were found, and second, if the solutions don't leave the given range.
10702 // The first case means that the actual solution is "unknown", the second
10703 // means that it's known, but not valid. If the solution is unknown, we
10704 // cannot make any conclusions.
10705 // Return a pair: the optional solution and a flag indicating if the
10706 // solution was found.
10707 auto SolveForBoundary =
10708 [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
10709 // Solve for signed overflow and unsigned overflow, pick the lower
10710 // solution.
10711 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
10712 << Bound << " (before multiplying by " << M << ")\n");
10713 Bound *= M; // The quadratic equation multiplier.
10714
10715 std::optional<APInt> SO;
10716 if (BitWidth > 1) {
10717 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10718 "signed overflow\n");
10720 }
10721 LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
10722 "unsigned overflow\n");
10723 std::optional<APInt> UO =
10725
10726 auto LeavesRange = [&] (const APInt &X) {
10727 ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
10728 ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
10729 if (Range.contains(V0->getValue()))
10730 return false;
10731 // X should be at least 1, so X-1 is non-negative.
10732 ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
10733 ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
10734 if (Range.contains(V1->getValue()))
10735 return true;
10736 return false;
10737 };
10738
10739 // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
10740 // can be a solution, but the function failed to find it. We cannot treat it
10741 // as "no solution".
10742 if (!SO || !UO)
10743 return {std::nullopt, false};
10744
10745 // Check the smaller value first to see if it leaves the range.
10746 // At this point, both SO and UO must have values.
10747 std::optional<APInt> Min = MinOptional(SO, UO);
10748 if (LeavesRange(*Min))
10749 return { Min, true };
10750 std::optional<APInt> Max = Min == SO ? UO : SO;
10751 if (LeavesRange(*Max))
10752 return { Max, true };
10753
10754 // Solutions were found, but were eliminated, hence the "true".
10755 return {std::nullopt, true};
10756 };
10757
10758 std::tie(A, B, C, M, BitWidth) = *T;
10759 // Lower bound is inclusive, subtract 1 to represent the exiting value.
10760 APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
10761 APInt Upper = Range.getUpper().sext(A.getBitWidth());
10762 auto SL = SolveForBoundary(Lower);
10763 auto SU = SolveForBoundary(Upper);
10764 // If any of the solutions was unknown, no meaninigful conclusions can
10765 // be made.
10766 if (!SL.second || !SU.second)
10767 return std::nullopt;
10768
10769 // Claim: The correct solution is not some value between Min and Max.
10770 //
10771 // Justification: Assuming that Min and Max are different values, one of
10772 // them is when the first signed overflow happens, the other is when the
10773 // first unsigned overflow happens. Crossing the range boundary is only
10774 // possible via an overflow (treating 0 as a special case of it, modeling
10775 // an overflow as crossing k*2^W for some k).
10776 //
10777 // The interesting case here is when Min was eliminated as an invalid
10778 // solution, but Max was not. The argument is that if there was another
10779 // overflow between Min and Max, it would also have been eliminated if
10780 // it was considered.
10781 //
10782 // For a given boundary, it is possible to have two overflows of the same
10783 // type (signed/unsigned) without having the other type in between: this
10784 // can happen when the vertex of the parabola is between the iterations
10785 // corresponding to the overflows. This is only possible when the two
10786 // overflows cross k*2^W for the same k. In such case, if the second one
10787 // left the range (and was the first one to do so), the first overflow
10788 // would have to enter the range, which would mean that either we had left
10789 // the range before or that we started outside of it. Both of these cases
10790 // are contradictions.
10791 //
10792 // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
10793 // solution is not some value between the Max for this boundary and the
10794 // Min of the other boundary.
10795 //
10796 // Justification: Assume that we had such Max_A and Min_B corresponding
10797 // to range boundaries A and B and such that Max_A < Min_B. If there was
10798 // a solution between Max_A and Min_B, it would have to be caused by an
10799 // overflow corresponding to either A or B. It cannot correspond to B,
10800 // since Min_B is the first occurrence of such an overflow. If it
10801 // corresponded to A, it would have to be either a signed or an unsigned
10802 // overflow that is larger than both eliminated overflows for A. But
10803 // between the eliminated overflows and this overflow, the values would
10804 // cover the entire value space, thus crossing the other boundary, which
10805 // is a contradiction.
10806
10807 return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
10808}
10809
10810ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
10811 const Loop *L,
10812 bool ControlsOnlyExit,
10813 bool AllowPredicates) {
10814
10815 // This is only used for loops with a "x != y" exit test. The exit condition
10816 // is now expressed as a single expression, V = x-y. So the exit test is
10817 // effectively V != 0. We know and take advantage of the fact that this
10818 // expression only being used in a comparison by zero context.
10819
10821 // If the value is a constant
10822 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10823 // If the value is already zero, the branch will execute zero times.
10824 if (C->getValue()->isZero()) return C;
10825 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10826 }
10827
10828 const SCEVAddRecExpr *AddRec =
10829 dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
10830
10831 if (!AddRec && AllowPredicates)
10832 // Try to make this an AddRec using runtime tests, in the first X
10833 // iterations of this loop, where X is the SCEV expression found by the
10834 // algorithm below.
10835 AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);
10836
10837 if (!AddRec || AddRec->getLoop() != L)
10838 return getCouldNotCompute();
10839
10840 // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
10841 // the quadratic equation to solve it.
10842 if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
10843 // We can only use this value if the chrec ends up with an exact zero
10844 // value at this index. When solving for "X*X != 5", for example, we
10845 // should not accept a root of 2.
10846 if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
10847 const auto *R = cast<SCEVConstant>(getConstant(*S));
10848 return ExitLimit(R, R, R, false, Predicates);
10849 }
10850 return getCouldNotCompute();
10851 }
10852
10853 // Otherwise we can only handle this if it is affine.
10854 if (!AddRec->isAffine())
10855 return getCouldNotCompute();
10856
10857 // If this is an affine expression, the execution count of this branch is
10858 // the minimum unsigned root of the following equation:
10859 //
10860 // Start + Step*N = 0 (mod 2^BW)
10861 //
10862 // equivalent to:
10863 //
10864 // Step*N = -Start (mod 2^BW)
10865 //
10866 // where BW is the common bit width of Start and Step.
10867
10868 // Get the initial value for the loop.
10869 const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
10870 const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
10871
10872 if (!isLoopInvariant(Step, L))
10873 return getCouldNotCompute();
10874
10875 LoopGuards Guards = LoopGuards::collect(L, *this);
10876 // Specialize step for this loop so we get context sensitive facts below.
10877 const SCEV *StepWLG = applyLoopGuards(Step, Guards);
10878
10879 // For positive steps (counting up until unsigned overflow):
10880 // N = -Start/Step (as unsigned)
10881 // For negative steps (counting down to zero):
10882 // N = Start/-Step
10883 // First compute the unsigned distance from zero in the direction of Step.
10884 bool CountDown = isKnownNegative(StepWLG);
10885 if (!CountDown && !isKnownNonNegative(StepWLG))
10886 return getCouldNotCompute();
10887
10888 const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
10889 // Handle unitary steps, which cannot wraparound.
10890 // 1*N = -Start; -1*N = Start (mod 2^BW), so:
10891 // N = Distance (as unsigned)
10892
10893 if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
10894 APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
10895 MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
10896
10897 // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
10898 // we end up with a loop whose backedge-taken count is n - 1. Detect this
10899 // case, and see if we can improve the bound.
10900 //
10901 // Explicitly handling this here is necessary because getUnsignedRange
10902 // isn't context-sensitive; it doesn't know that we only care about the
10903 // range inside the loop.
10904 const SCEV *Zero = getZero(Distance->getType());
10905 const SCEV *One = getOne(Distance->getType());
10906 const SCEV *DistancePlusOne = getAddExpr(Distance, One);
10907 if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
10908 // If Distance + 1 doesn't overflow, we can compute the maximum distance
10909 // as "unsigned_max(Distance + 1) - 1".
10910 ConstantRange CR = getUnsignedRange(DistancePlusOne);
10911 MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
10912 }
10913 return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
10914 Predicates);
10915 }
10916
10917 // If the condition controls loop exit (the loop exits only if the expression
10918 // is true) and the addition is no-wrap we can use unsigned divide to
10919 // compute the backedge count. In this case, the step may not divide the
10920 // distance, but we don't care because if the condition is "missed" the loop
10921 // will have undefined behavior due to wrapping.
10922 if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
10923 loopHasNoAbnormalExits(AddRec->getLoop())) {
10924
10925 // If the stride is zero and the start is non-zero, the loop must be
10926 // infinite. In C++, most loops are finite by assumption, in which case the
10927 // step being zero implies UB must execute if the loop is entered.
10928 if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
10929 !isKnownNonZero(StepWLG))
10930 return getCouldNotCompute();
10931
10932 const SCEV *Exact =
10933 getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
10934 const SCEV *ConstantMax = getCouldNotCompute();
10935 if (Exact != getCouldNotCompute()) {
10936 APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
10937 ConstantMax =
10939 }
10940 const SCEV *SymbolicMax =
10941 isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
10942 return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
10943 }
10944
10945 // Solve the general equation.
10946 const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
10947 if (!StepC || StepC->getValue()->isZero())
10948 return getCouldNotCompute();
10949 const SCEV *E = SolveLinEquationWithOverflow(
10950 StepC->getAPInt(), getNegativeSCEV(Start),
10951 AllowPredicates ? &Predicates : nullptr, *this, L);
10952
10953 const SCEV *M = E;
10954 if (E != getCouldNotCompute()) {
10955 APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
10956 M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
10957 }
10958 auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
10959 return ExitLimit(E, M, S, false, Predicates);
10960}
10961
10962ScalarEvolution::ExitLimit
10963ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10964 // Loops that look like: while (X == 0) are very strange indeed. We don't
10965 // handle them yet except for the trivial case. This could be expanded in the
10966 // future as needed.
10967
10968 // If the value is a constant, check to see if it is known to be non-zero
10969 // already. If so, the backedge will execute zero times.
10970 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10971 if (!C->getValue()->isZero())
10972 return getZero(C->getType());
10973 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10974 }
10975
10976 // We could implement others, but I really doubt anyone writes loops like
10977 // this, and if they did, they would already be constant folded.
10978 return getCouldNotCompute();
10979}
10980
10981std::pair<const BasicBlock *, const BasicBlock *>
10982ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10983 const {
10984 // If the block has a unique predecessor, then there is no path from the
10985 // predecessor to the block that does not go through the direct edge
10986 // from the predecessor to the block.
10987 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10988 return {Pred, BB};
10989
10990 // A loop's header is defined to be a block that dominates the loop.
10991 // If the header has a unique predecessor outside the loop, it must be
10992 // a block that has exactly one successor that can reach the loop.
10993 if (const Loop *L = LI.getLoopFor(BB))
10994 return {L->getLoopPredecessor(), L->getHeader()};
10995
10996 return {nullptr, BB};
10997}
10998
10999/// SCEV structural equivalence is usually sufficient for testing whether two
11000/// expressions are equal, however for the purposes of looking for a condition
11001/// guarding a loop, it can be useful to be a little more general, since a
11002/// front-end may have replicated the controlling expression.
11003static bool HasSameValue(const SCEV *A, const SCEV *B) {
11004 // Quick check to see if they are the same SCEV.
11005 if (A == B) return true;
11006
11007 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
11008 // Not all instructions that are "identical" compute the same value. For
11009 // instance, two distinct alloca instructions allocating the same type are
11010 // identical and do not read memory; but compute distinct values.
11011 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
11012 };
11013
11014 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
11015 // two different instructions with the same value. Check for this case.
11016 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
11017 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
11018 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11019 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11020 if (ComputesEqualValues(AI, BI))
11021 return true;
11022
11023 // Otherwise assume they may have a different value.
11024 return false;
11025}
11026
11027static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11028 const SCEV *Op0, *Op1;
11029 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11030 return false;
11031 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11032 LHS = Op1;
11033 return true;
11034 }
11035 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11036 LHS = Op0;
11037 return true;
11038 }
11039 return false;
11040}
11041
11043 SCEVUse &RHS, unsigned Depth) {
11044 bool Changed = false;
11045 // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
11046 // '0 != 0'.
11047 auto TrivialCase = [&](bool TriviallyTrue) {
11049 Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
11050 return true;
11051 };
11052 // If we hit the max recursion limit bail out.
11053 if (Depth >= 3)
11054 return false;
11055
11056 const SCEV *NewLHS, *NewRHS;
11057 if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
11058 match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
11059 const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
11060 const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);
11061
11062 // (X * vscale) pred (Y * vscale) ==> X pred Y
11063 // when both multiples are NSW.
11064 // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
11065 // when both multiples are NUW.
11066 if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
11067 (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
11068 !ICmpInst::isSigned(Pred))) {
11069 LHS = NewLHS;
11070 RHS = NewRHS;
11071 Changed = true;
11072 }
11073 }
11074
11075 // Canonicalize a constant to the right side.
11076 if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
11077 // Check for both operands constant.
11078 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
11079 if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
11080 return TrivialCase(false);
11081 return TrivialCase(true);
11082 }
11083 // Otherwise swap the operands to put the constant on the right.
11084 std::swap(LHS, RHS);
11086 Changed = true;
11087 }
11088
11089 // If we're comparing an addrec with a value which is loop-invariant in the
11090 // addrec's loop, put the addrec on the left. Also make a dominance check,
11091 // as both operands could be addrecs loop-invariant in each other's loop.
11092 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
11093 const Loop *L = AR->getLoop();
11094 if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
11095 std::swap(LHS, RHS);
11097 Changed = true;
11098 }
11099 }
11100
11101 // If there's a constant operand, canonicalize comparisons with boundary
11102 // cases, and canonicalize *-or-equal comparisons to regular comparisons.
11103 if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
11104 const APInt &RA = RC->getAPInt();
11105
11106 bool SimplifiedByConstantRange = false;
11107
11108 if (!ICmpInst::isEquality(Pred)) {
11110 if (ExactCR.isFullSet())
11111 return TrivialCase(true);
11112 if (ExactCR.isEmptySet())
11113 return TrivialCase(false);
11114
11115 APInt NewRHS;
11116 CmpInst::Predicate NewPred;
11117 if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
11118 ICmpInst::isEquality(NewPred)) {
11119 // We were able to convert an inequality to an equality.
11120 Pred = NewPred;
11121 RHS = getConstant(NewRHS);
11122 Changed = SimplifiedByConstantRange = true;
11123 }
11124 }
11125
11126 if (!SimplifiedByConstantRange) {
11127 switch (Pred) {
11128 default:
11129 break;
11130 case ICmpInst::ICMP_EQ:
11131 case ICmpInst::ICMP_NE:
11132 // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
11133 if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
11134 Changed = true;
11135 break;
11136
11137 // The "Should have been caught earlier!" messages refer to the fact
11138 // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
11139 // should have fired on the corresponding cases, and canonicalized the
11140 // check to trivial case.
11141
11142 case ICmpInst::ICMP_UGE:
11143 assert(!RA.isMinValue() && "Should have been caught earlier!");
11144 Pred = ICmpInst::ICMP_UGT;
11145 RHS = getConstant(RA - 1);
11146 Changed = true;
11147 break;
11148 case ICmpInst::ICMP_ULE:
11149 assert(!RA.isMaxValue() && "Should have been caught earlier!");
11150 Pred = ICmpInst::ICMP_ULT;
11151 RHS = getConstant(RA + 1);
11152 Changed = true;
11153 break;
11154 case ICmpInst::ICMP_SGE:
11155 assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
11156 Pred = ICmpInst::ICMP_SGT;
11157 RHS = getConstant(RA - 1);
11158 Changed = true;
11159 break;
11160 case ICmpInst::ICMP_SLE:
11161 assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
11162 Pred = ICmpInst::ICMP_SLT;
11163 RHS = getConstant(RA + 1);
11164 Changed = true;
11165 break;
11166 }
11167 }
11168 }
11169
11170 // Check for obvious equality.
11171 if (HasSameValue(LHS, RHS)) {
11172 if (ICmpInst::isTrueWhenEqual(Pred))
11173 return TrivialCase(true);
11175 return TrivialCase(false);
11176 }
11177
11178 // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
11179 // adding or subtracting 1 from one of the operands.
11180 switch (Pred) {
11181 case ICmpInst::ICMP_SLE:
11182 if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
11183 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11185 Pred = ICmpInst::ICMP_SLT;
11186 Changed = true;
11187 } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
11188 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
11190 Pred = ICmpInst::ICMP_SLT;
11191 Changed = true;
11192 }
11193 break;
11194 case ICmpInst::ICMP_SGE:
11195 if (!getSignedRangeMin(RHS).isMinSignedValue()) {
11196 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
11198 Pred = ICmpInst::ICMP_SGT;
11199 Changed = true;
11200 } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
11201 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11203 Pred = ICmpInst::ICMP_SGT;
11204 Changed = true;
11205 }
11206 break;
11207 case ICmpInst::ICMP_ULE:
11208 if (!getUnsignedRangeMax(RHS).isMaxValue()) {
11209 RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
11211 Pred = ICmpInst::ICMP_ULT;
11212 Changed = true;
11213 } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
11214 LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
11215 Pred = ICmpInst::ICMP_ULT;
11216 Changed = true;
11217 }
11218 break;
11219 case ICmpInst::ICMP_UGE:
11220 // If RHS is an op we can fold the -1, try that first.
11221 // Otherwise prefer LHS to preserve the nuw flag.
11222 if ((isa<SCEVConstant>(RHS) ||
11224 isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
11225 !getUnsignedRangeMin(RHS).isMinValue()) {
11226 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11227 Pred = ICmpInst::ICMP_UGT;
11228 Changed = true;
11229 } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
11230 LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
11232 Pred = ICmpInst::ICMP_UGT;
11233 Changed = true;
11234 } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
11235 RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
11236 Pred = ICmpInst::ICMP_UGT;
11237 Changed = true;
11238 }
11239 break;
11240 default:
11241 break;
11242 }
11243
11244 // TODO: More simplifications are possible here.
11245
11246 // Recursively simplify until we either hit a recursion limit or nothing
11247 // changes.
11248 if (Changed)
11249 (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);
11250
11251 return Changed;
11252}
11253
11255 return getSignedRangeMax(S).isNegative();
11256}
11257
11261
11263 return !getSignedRangeMin(S).isNegative();
11264}
11265
11269
11271 // Query push down for cases where the unsigned range is
11272 // less than sufficient.
11273 if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
11274 return isKnownNonZero(SExt->getOperand(0));
11275 return getUnsignedRangeMin(S) != 0;
11276}
11277
11279 bool OrNegative) {
11280 auto NonRecursive = [OrNegative](const SCEV *S) {
11281 if (auto *C = dyn_cast<SCEVConstant>(S))
11282 return C->getAPInt().isPowerOf2() ||
11283 (OrNegative && C->getAPInt().isNegatedPowerOf2());
11284
11285 // vscale is a power-of-two.
11286 return isa<SCEVVScale>(S);
11287 };
11288
11289 if (NonRecursive(S))
11290 return true;
11291
11292 auto *Mul = dyn_cast<SCEVMulExpr>(S);
11293 if (!Mul)
11294 return false;
11295 return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
11296}
11297
11299 const SCEV *S, uint64_t M,
11301 if (M == 0)
11302 return false;
11303 if (M == 1)
11304 return true;
11305
11306 // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
11307 // starts with a multiple of M and at every iteration step S only adds
11308 // multiples of M.
11309 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
11310 return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
11311 isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);
11312
11313 // For a constant, check that "S % M == 0".
11314 if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
11315 APInt C = Cst->getAPInt();
11316 return C.urem(M) == 0;
11317 }
11318
11319 // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.
11320
11321 // Basic tests have failed.
11322 // Check "S % M == 0" at compile time and record runtime Assumptions.
11323 auto *STy = dyn_cast<IntegerType>(S->getType());
11324 const SCEV *SmodM =
11325 getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
11326 const SCEV *Zero = getZero(STy);
11327
11328 // Check whether "S % M == 0" is known at compile time.
11329 if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
11330 return true;
11331
11332 // Check whether "S % M != 0" is known at compile time.
11333 if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
11334 return false;
11335
11337
11338 // Detect redundant predicates.
11339 for (auto *A : Assumptions)
11340 if (A->implies(P, *this))
11341 return true;
11342
11343 // Only record non-redundant predicates.
11344 Assumptions.push_back(P);
11345 return true;
11346}
11347
11349 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11351}
11352
11353std::pair<const SCEV *, const SCEV *>
11355 // Compute SCEV on entry of loop L.
11356 const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
11357 if (Start == getCouldNotCompute())
11358 return { Start, Start };
11359 // Compute post increment SCEV for loop L.
11360 const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
11361 assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
11362 return { Start, PostInc };
11363}
11364
11366 SCEVUse RHS) {
11367 // First collect all loops.
11369 getUsedLoops(LHS, LoopsUsed);
11370 getUsedLoops(RHS, LoopsUsed);
11371
11372 if (LoopsUsed.empty())
11373 return false;
11374
11375 // Domination relationship must be a linear order on collected loops.
11376#ifndef NDEBUG
11377 for (const auto *L1 : LoopsUsed)
11378 for (const auto *L2 : LoopsUsed)
11379 assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
11380 DT.dominates(L2->getHeader(), L1->getHeader())) &&
11381 "Domination relationship is not a linear order");
11382#endif
11383
11384 const Loop *MDL =
11385 *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
11386 return DT.properlyDominates(L1->getHeader(), L2->getHeader());
11387 });
11388
11389 // Get init and post increment value for LHS.
11390 auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
11391 // if LHS contains unknown non-invariant SCEV then bail out.
11392 if (SplitLHS.first == getCouldNotCompute())
11393 return false;
11394 assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
11395 // Get init and post increment value for RHS.
11396 auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
11397 // if RHS contains unknown non-invariant SCEV then bail out.
11398 if (SplitRHS.first == getCouldNotCompute())
11399 return false;
11400 assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
11401 // It is possible that init SCEV contains an invariant load but it does
11402 // not dominate MDL and is not available at MDL loop entry, so we should
11403 // check it here.
11404 if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
11405 !isAvailableAtLoopEntry(SplitRHS.first, MDL))
11406 return false;
11407
11408 // It seems backedge guard check is faster than entry one so in some cases
11409 // it can speed up whole estimation by short circuit
11410 return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
11411 SplitRHS.second) &&
11412 isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
11413}
11414
11416 SCEVUse RHS) {
11417 // Canonicalize the inputs first.
11418 (void)SimplifyICmpOperands(Pred, LHS, RHS);
11419
11420 if (isKnownViaInduction(Pred, LHS, RHS))
11421 return true;
11422
11423 if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
11424 return true;
11425
11426 // Otherwise see what can be done with some simple reasoning.
11427 return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
11428}
11429
11431 const SCEV *LHS,
11432 const SCEV *RHS) {
11433 if (isKnownPredicate(Pred, LHS, RHS))
11434 return true;
11436 return false;
11437 return std::nullopt;
11438}
11439
11441 const SCEV *RHS,
11442 const Instruction *CtxI) {
11443 // TODO: Analyze guards and assumes from Context's block.
11444 return isKnownPredicate(Pred, LHS, RHS) ||
11445 isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
11446}
11447
11448std::optional<bool>
11450 const SCEV *RHS, const Instruction *CtxI) {
11451 std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
11452 if (KnownWithoutContext)
11453 return KnownWithoutContext;
11454
11455 if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
11456 return true;
11458 CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
11459 return false;
11460 return std::nullopt;
11461}
11462
11464 const SCEVAddRecExpr *LHS,
11465 const SCEV *RHS) {
11466 const Loop *L = LHS->getLoop();
11467 return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
11468 isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
11469}
11470
11471std::optional<ScalarEvolution::MonotonicPredicateType>
11473 ICmpInst::Predicate Pred) {
11474 auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
11475
11476#ifndef NDEBUG
11477 // Verify an invariant: inverting the predicate should turn a monotonically
11478 // increasing change to a monotonically decreasing one, and vice versa.
11479 if (Result) {
11480 auto ResultSwapped =
11481 getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
11482
11483 assert(*ResultSwapped != *Result &&
11484 "monotonicity should flip as we flip the predicate");
11485 }
11486#endif
11487
11488 return Result;
11489}
11490
11491std::optional<ScalarEvolution::MonotonicPredicateType>
11492ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
11493 ICmpInst::Predicate Pred) {
11494 // A zero step value for LHS means the induction variable is essentially a
11495 // loop invariant value. We don't really depend on the predicate actually
11496 // flipping from false to true (for increasing predicates, and the other way
11497 // around for decreasing predicates), all we care about is that *if* the
11498 // predicate changes then it only changes from false to true.
11499 //
11500 // A zero step value in itself is not very useful, but there may be places
11501 // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
11502 // as general as possible.
11503
11504 // Only handle LE/LT/GE/GT predicates.
11505 if (!ICmpInst::isRelational(Pred))
11506 return std::nullopt;
11507
11508 bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
11509 assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
11510 "Should be greater or less!");
11511
11512 // Check that AR does not wrap.
11513 if (ICmpInst::isUnsigned(Pred)) {
11514 if (!LHS->hasNoUnsignedWrap())
11515 return std::nullopt;
11517 }
11518 assert(ICmpInst::isSigned(Pred) &&
11519 "Relational predicate is either signed or unsigned!");
11520 if (!LHS->hasNoSignedWrap())
11521 return std::nullopt;
11522
11523 const SCEV *Step = LHS->getStepRecurrence(*this);
11524
11525 if (isKnownNonNegative(Step))
11527
11528 if (isKnownNonPositive(Step))
11530
11531 return std::nullopt;
11532}
11533
11534std::optional<ScalarEvolution::LoopInvariantPredicate>
11536 const SCEV *RHS, const Loop *L,
11537 const Instruction *CtxI) {
11538 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11539 if (!isLoopInvariant(RHS, L)) {
11540 if (!isLoopInvariant(LHS, L))
11541 return std::nullopt;
11542
11543 std::swap(LHS, RHS);
11545 }
11546
11547 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11548 if (!ArLHS || ArLHS->getLoop() != L)
11549 return std::nullopt;
11550
11551 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11552 if (!MonotonicType)
11553 return std::nullopt;
11554 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11555 // true as the loop iterates, and the backedge is control dependent on
11556 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11557 //
11558 // * if the predicate was false in the first iteration then the predicate
11559 // is never evaluated again, since the loop exits without taking the
11560 // backedge.
11561 // * if the predicate was true in the first iteration then it will
11562 // continue to be true for all future iterations since it is
11563 // monotonically increasing.
11564 //
11565 // For both the above possibilities, we can replace the loop varying
11566 // predicate with its value on the first iteration of the loop (which is
11567 // loop invariant).
11568 //
11569 // A similar reasoning applies for a monotonically decreasing predicate, by
11570 // replacing true with false and false with true in the above two bullets.
11572 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11573
11574 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
11576 RHS);
11577
11578 if (!CtxI)
11579 return std::nullopt;
11580 // Try to prove via context.
11581 // TODO: Support other cases.
11582 switch (Pred) {
11583 default:
11584 break;
11585 case ICmpInst::ICMP_ULE:
11586 case ICmpInst::ICMP_ULT: {
11587 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11588 // Given preconditions
11589 // (1) ArLHS does not cross the border of positive and negative parts of
11590 // range because of:
11591 // - Positive step; (TODO: lift this limitation)
11592 // - nuw - does not cross zero boundary;
11593 // - nsw - does not cross SINT_MAX boundary;
11594 // (2) ArLHS <s RHS
11595 // (3) RHS >=s 0
11596 // we can replace the loop variant ArLHS <u RHS condition with loop
11597 // invariant Start(ArLHS) <u RHS.
11598 //
11599 // Because of (1) there are two options:
11600 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11601 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11602 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11603 // Because of (2) ArLHS <u RHS is trivially true.
11604 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11605 // We can strengthen this to Start(ArLHS) <u RHS.
11606 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11607 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11608 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11609 isKnownNonNegative(RHS) &&
11610 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
11612 RHS);
11613 }
11614 }
11615
11616 return std::nullopt;
11617}
11618
11619std::optional<ScalarEvolution::LoopInvariantPredicate>
11621 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11622 const Instruction *CtxI, const SCEV *MaxIter) {
11624 Pred, LHS, RHS, L, CtxI, MaxIter))
11625 return LIP;
11626 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11627 // Number of iterations expressed as UMIN isn't always great for expressing
11628 // the value on the last iteration. If the straightforward approach didn't
11629 // work, try the following trick: if the a predicate is invariant for X, it
11630 // is also invariant for umin(X, ...). So try to find something that works
11631 // among subexpressions of MaxIter expressed as umin.
11632 for (SCEVUse Op : UMin->operands())
11634 Pred, LHS, RHS, L, CtxI, Op))
11635 return LIP;
11636 return std::nullopt;
11637}
11638
11639std::optional<ScalarEvolution::LoopInvariantPredicate>
11641 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11642 const Instruction *CtxI, const SCEV *MaxIter) {
11643 // Try to prove the following set of facts:
11644 // - The predicate is monotonic in the iteration space.
11645 // - If the check does not fail on the 1st iteration:
11646 // - No overflow will happen during first MaxIter iterations;
11647 // - It will not fail on the MaxIter'th iteration.
11648 // If the check does fail on the 1st iteration, we leave the loop and no
11649 // other checks matter.
11650
11651 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11652 if (!isLoopInvariant(RHS, L)) {
11653 if (!isLoopInvariant(LHS, L))
11654 return std::nullopt;
11655
11656 std::swap(LHS, RHS);
11658 }
11659
11660 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11661 if (!AR || AR->getLoop() != L)
11662 return std::nullopt;
11663
11664 // Even if both are valid, we need to consistently chose the unsigned or the
11665 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11666 // predicate.
11667 Pred = Pred.dropSameSign();
11668
11669 // The predicate must be relational (i.e. <, <=, >=, >).
11670 if (!ICmpInst::isRelational(Pred))
11671 return std::nullopt;
11672
11673 // TODO: Support steps other than +/- 1.
11674 const SCEV *Step = AR->getStepRecurrence(*this);
11675 auto *One = getOne(Step->getType());
11676 auto *MinusOne = getNegativeSCEV(One);
11677 if (Step != One && Step != MinusOne)
11678 return std::nullopt;
11679
11680 // Type mismatch here means that MaxIter is potentially larger than max
11681 // unsigned value in start type, which mean we cannot prove no wrap for the
11682 // indvar.
11683 if (AR->getType() != MaxIter->getType())
11684 return std::nullopt;
11685
11686 // Value of IV on suggested last iteration.
11687 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11688 // Does it still meet the requirement?
11689 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11690 return std::nullopt;
11691 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11692 // not exceed max unsigned value of this type), this effectively proves
11693 // that there is no wrap during the iteration. To prove that there is no
11694 // signed/unsigned wrap, we need to check that
11695 // Start <= Last for step = 1 or Start >= Last for step = -1.
11696 ICmpInst::Predicate NoOverflowPred =
11698 if (Step == MinusOne)
11699 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11700 const SCEV *Start = AR->getStart();
11701 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11702 return std::nullopt;
11703
11704 // Everything is fine.
11705 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11706}
11707
11708bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11709 SCEVUse LHS,
11710 SCEVUse RHS) {
11711 if (HasSameValue(LHS, RHS))
11712 return ICmpInst::isTrueWhenEqual(Pred);
11713
11714 auto CheckRange = [&](bool IsSigned) {
11715 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11716 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11717 return RangeLHS.icmp(Pred, RangeRHS);
11718 };
11719
11720 // The check at the top of the function catches the case where the values are
11721 // known to be equal.
11722 if (Pred == CmpInst::ICMP_EQ)
11723 return false;
11724
11725 if (Pred == CmpInst::ICMP_NE) {
11726 if (CheckRange(true) || CheckRange(false))
11727 return true;
11728 auto *Diff = getMinusSCEV(LHS, RHS);
11729 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11730 }
11731
11732 return CheckRange(CmpInst::isSigned(Pred));
11733}
11734
11735bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11737 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11738 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11739 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11740 // OutC1 and OutC2.
11741 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11742 APInt &OutC2,
11743 SCEV::NoWrapFlags ExpectedFlags) {
11744 SCEVUse XNonConstOp, XConstOp;
11745 SCEVUse YNonConstOp, YConstOp;
11746 SCEV::NoWrapFlags XFlagsPresent;
11747 SCEV::NoWrapFlags YFlagsPresent;
11748
11749 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11750 XConstOp = getZero(X->getType());
11751 XNonConstOp = X;
11752 XFlagsPresent = ExpectedFlags;
11753 }
11754 if (!isa<SCEVConstant>(XConstOp))
11755 return false;
11756
11757 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11758 YConstOp = getZero(Y->getType());
11759 YNonConstOp = Y;
11760 YFlagsPresent = ExpectedFlags;
11761 }
11762
11763 if (YNonConstOp != XNonConstOp)
11764 return false;
11765
11766 if (!isa<SCEVConstant>(YConstOp))
11767 return false;
11768
11769 // When matching ADDs with NUW flags (and unsigned predicates), only the
11770 // second ADD (with the larger constant) requires NUW.
11771 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11772 return false;
11773 if (ExpectedFlags != SCEV::FlagNUW &&
11774 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11775 return false;
11776 }
11777
11778 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11779 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11780
11781 return true;
11782 };
11783
11784 APInt C1;
11785 APInt C2;
11786
11787 switch (Pred) {
11788 default:
11789 break;
11790
11791 case ICmpInst::ICMP_SGE:
11792 std::swap(LHS, RHS);
11793 [[fallthrough]];
11794 case ICmpInst::ICMP_SLE:
11795 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11796 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11797 return true;
11798
11799 break;
11800
11801 case ICmpInst::ICMP_SGT:
11802 std::swap(LHS, RHS);
11803 [[fallthrough]];
11804 case ICmpInst::ICMP_SLT:
11805 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11806 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11807 return true;
11808
11809 break;
11810
11811 case ICmpInst::ICMP_UGE:
11812 std::swap(LHS, RHS);
11813 [[fallthrough]];
11814 case ICmpInst::ICMP_ULE:
11815 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11816 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11817 return true;
11818
11819 break;
11820
11821 case ICmpInst::ICMP_UGT:
11822 std::swap(LHS, RHS);
11823 [[fallthrough]];
11824 case ICmpInst::ICMP_ULT:
11825 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11826 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11827 return true;
11828 break;
11829 }
11830
11831 return false;
11832}
11833
11834bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11836 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11837 return false;
11838
11839 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11840 // the stack can result in exponential time complexity.
11841 SaveAndRestore Restore(ProvingSplitPredicate, true);
11842
11843 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11844 //
11845 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11846 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11847 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11848 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11849 // use isKnownPredicate later if needed.
11850 return isKnownNonNegative(RHS) &&
11853}
11854
11855bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11856 const SCEV *LHS, const SCEV *RHS) {
11857 // No need to even try if we know the module has no guards.
11858 if (!HasGuards)
11859 return false;
11860
11861 return any_of(*BB, [&](const Instruction &I) {
11862 using namespace llvm::PatternMatch;
11863
11864 Value *Condition;
11866 m_Value(Condition))) &&
11867 isImpliedCond(Pred, LHS, RHS, Condition, false);
11868 });
11869}
11870
11871/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11872/// protected by a conditional between LHS and RHS. This is used to
11873/// to eliminate casts.
11875 CmpPredicate Pred,
11876 const SCEV *LHS,
11877 const SCEV *RHS) {
11878 // Interpret a null as meaning no loop, where there is obviously no guard
11879 // (interprocedural conditions notwithstanding). Do not bother about
11880 // unreachable loops.
11881 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11882 return true;
11883
11884 if (VerifyIR)
11885 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11886 "This cannot be done on broken IR!");
11887
11888
11889 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11890 return true;
11891
11892 BasicBlock *Latch = L->getLoopLatch();
11893 if (!Latch)
11894 return false;
11895
11896 CondBrInst *LoopContinuePredicate =
11898 if (LoopContinuePredicate &&
11899 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11900 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11901 return true;
11902
11903 // We don't want more than one activation of the following loops on the stack
11904 // -- that can lead to O(n!) time complexity.
11905 if (WalkingBEDominatingConds)
11906 return false;
11907
11908 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11909
11910 // See if we can exploit a trip count to prove the predicate.
11911 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11912 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11913 if (LatchBECount != getCouldNotCompute()) {
11914 // We know that Latch branches back to the loop header exactly
11915 // LatchBECount times. This means the backdege condition at Latch is
11916 // equivalent to "{0,+,1} u< LatchBECount".
11917 Type *Ty = LatchBECount->getType();
11918 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11919 const SCEV *LoopCounter =
11920 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11921 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11922 LatchBECount))
11923 return true;
11924 }
11925
11926 // Check conditions due to any @llvm.assume intrinsics.
11927 for (auto &AssumeVH : AC.assumptions()) {
11928 if (!AssumeVH)
11929 continue;
11930 auto *CI = cast<CallInst>(AssumeVH);
11931 if (!DT.dominates(CI, Latch->getTerminator()))
11932 continue;
11933
11934 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11935 return true;
11936 }
11937
11938 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11939 return true;
11940
11941 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11942 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11943 assert(DTN && "should reach the loop header before reaching the root!");
11944
11945 BasicBlock *BB = DTN->getBlock();
11946 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11947 return true;
11948
11949 BasicBlock *PBB = BB->getSinglePredecessor();
11950 if (!PBB)
11951 continue;
11952
11954 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11955 continue;
11956
11957 // If we have an edge `E` within the loop body that dominates the only
11958 // latch, the condition guarding `E` also guards the backedge. This
11959 // reasoning works only for loops with a single latch.
11960 // We're constructively (and conservatively) enumerating edges within the
11961 // loop body that dominate the latch. The dominator tree better agree
11962 // with us on this:
11963 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11964 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11965 BB != ContBr->getSuccessor(0)))
11966 return true;
11967 }
11968
11969 return false;
11970}
11971
11973 CmpPredicate Pred,
11974 const SCEV *LHS,
11975 const SCEV *RHS) {
11976 // Do not bother proving facts for unreachable code.
11977 if (!DT.isReachableFromEntry(BB))
11978 return true;
11979 if (VerifyIR)
11980 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11981 "This cannot be done on broken IR!");
11982
11983 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11984 // the facts (a >= b && a != b) separately. A typical situation is when the
11985 // non-strict comparison is known from ranges and non-equality is known from
11986 // dominating predicates. If we are proving strict comparison, we always try
11987 // to prove non-equality and non-strict comparison separately.
11988 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11989 const bool ProvingStrictComparison =
11990 Pred != NonStrictPredicate.dropSameSign();
11991 bool ProvedNonStrictComparison = false;
11992 bool ProvedNonEquality = false;
11993
11994 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11995 if (!ProvedNonStrictComparison)
11996 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11997 if (!ProvedNonEquality)
11998 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11999 if (ProvedNonStrictComparison && ProvedNonEquality)
12000 return true;
12001 return false;
12002 };
12003
12004 if (ProvingStrictComparison) {
12005 auto ProofFn = [&](CmpPredicate P) {
12006 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
12007 };
12008 if (SplitAndProve(ProofFn))
12009 return true;
12010 }
12011
12012 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
12013 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
12014 const Instruction *CtxI = &BB->front();
12015 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
12016 return true;
12017 if (ProvingStrictComparison) {
12018 auto ProofFn = [&](CmpPredicate P) {
12019 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12020 };
12021 if (SplitAndProve(ProofFn))
12022 return true;
12023 }
12024 return false;
12025 };
12026
12027 // Starting at the block's predecessor, climb up the predecessor chain, as long
12028 // as there are predecessors that can be found that have unique successors
12029 // leading to the original block.
12030 const Loop *ContainingLoop = LI.getLoopFor(BB);
12031 const BasicBlock *PredBB;
12032 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12033 PredBB = ContainingLoop->getLoopPredecessor();
12034 else
12035 PredBB = BB->getSinglePredecessor();
12036 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12037 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12038 const CondBrInst *BlockEntryPredicate =
12039 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12040 if (!BlockEntryPredicate)
12041 continue;
12042
12043 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12044 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12045 return true;
12046 }
12047
12048 // Check conditions due to any @llvm.assume intrinsics.
12049 for (auto &AssumeVH : AC.assumptions()) {
12050 if (!AssumeVH)
12051 continue;
12052 auto *CI = cast<CallInst>(AssumeVH);
12053 if (!DT.dominates(CI, BB))
12054 continue;
12055
12056 if (ProveViaCond(CI->getArgOperand(0), false))
12057 return true;
12058 }
12059
12060 // Check conditions due to any @llvm.experimental.guard intrinsics.
12061 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12062 F.getParent(), Intrinsic::experimental_guard);
12063 if (GuardDecl)
12064 for (const auto *GU : GuardDecl->users())
12065 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12066 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12067 if (ProveViaCond(Guard->getArgOperand(0), false))
12068 return true;
12069 return false;
12070}
12071
12073 const SCEV *LHS,
12074 const SCEV *RHS) {
12075 // Interpret a null as meaning no loop, where there is obviously no guard
12076 // (interprocedural conditions notwithstanding).
12077 if (!L)
12078 return false;
12079
12080 // Both LHS and RHS must be available at loop entry.
12082 "LHS is not available at Loop Entry");
12084 "RHS is not available at Loop Entry");
12085
12086 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12087 return true;
12088
12089 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12090}
12091
12092bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12093 const SCEV *RHS,
12094 const Value *FoundCondValue, bool Inverse,
12095 const Instruction *CtxI) {
12096 // False conditions implies anything. Do not bother analyzing it further.
12097 if (FoundCondValue ==
12098 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12099 return true;
12100
12101 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12102 return false;
12103
12104 llvm::scope_exit ClearOnExit(
12105 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12106
12107 // Recursively handle And and Or conditions.
12108 const Value *Op0, *Op1;
12109 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12110 if (!Inverse)
12111 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12112 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12113 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12114 if (Inverse)
12115 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12116 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12117 }
12118
12119 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12120 if (!ICI) return false;
12121
12122 // Now that we found a conditional branch that dominates the loop or controls
12123 // the loop latch. Check to see if it is the comparison we are looking for.
12124 CmpPredicate FoundPred;
12125 if (Inverse)
12126 FoundPred = ICI->getInverseCmpPredicate();
12127 else
12128 FoundPred = ICI->getCmpPredicate();
12129
12130 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12131 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12132
12133 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12134}
12135
12136bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12137 const SCEV *RHS, CmpPredicate FoundPred,
12138 const SCEV *FoundLHS, const SCEV *FoundRHS,
12139 const Instruction *CtxI) {
12140 // Balance the types.
12141 if (getTypeSizeInBits(LHS->getType()) <
12142 getTypeSizeInBits(FoundLHS->getType())) {
12143 // For unsigned and equality predicates, try to prove that both found
12144 // operands fit into narrow unsigned range. If so, try to prove facts in
12145 // narrow types.
12146 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12147 !FoundRHS->getType()->isPointerTy()) {
12148 auto *NarrowType = LHS->getType();
12149 auto *WideType = FoundLHS->getType();
12150 auto BitWidth = getTypeSizeInBits(NarrowType);
12151 const SCEV *MaxValue = getZeroExtendExpr(
12153 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12154 MaxValue) &&
12155 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12156 MaxValue)) {
12157 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12158 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12159 // We cannot preserve samesign after truncation.
12160 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12161 TruncFoundLHS, TruncFoundRHS, CtxI))
12162 return true;
12163 }
12164 }
12165
12166 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12167 return false;
12168 if (CmpInst::isSigned(Pred)) {
12169 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12170 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12171 } else {
12172 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12173 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12174 }
12175 } else if (getTypeSizeInBits(LHS->getType()) >
12176 getTypeSizeInBits(FoundLHS->getType())) {
12177 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12178 return false;
12179 if (CmpInst::isSigned(FoundPred)) {
12180 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12181 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12182 } else {
12183 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12184 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12185 }
12186 }
12187 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12188 FoundRHS, CtxI);
12189}
12190
12191bool ScalarEvolution::isImpliedCondBalancedTypes(
12192 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12193 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
12195 getTypeSizeInBits(FoundLHS->getType()) &&
12196 "Types should be balanced!");
12197 // Canonicalize the query to match the way instcombine will have
12198 // canonicalized the comparison.
12199 if (SimplifyICmpOperands(Pred, LHS, RHS))
12200 if (LHS == RHS)
12201 return CmpInst::isTrueWhenEqual(Pred);
12202 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12203 if (FoundLHS == FoundRHS)
12204 return CmpInst::isFalseWhenEqual(FoundPred);
12205
12206 // Check to see if we can make the LHS or RHS match.
12207 if (LHS == FoundRHS || RHS == FoundLHS) {
12208 if (isa<SCEVConstant>(RHS)) {
12209 std::swap(FoundLHS, FoundRHS);
12210 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12211 } else {
12212 std::swap(LHS, RHS);
12214 }
12215 }
12216
12217 // Check whether the found predicate is the same as the desired predicate.
12218 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12219 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12220
12221 // Check whether swapping the found predicate makes it the same as the
12222 // desired predicate.
12223 if (auto P = CmpPredicate::getMatching(
12224 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12225 // We can write the implication
12226 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12227 // using one of the following ways:
12228 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12229 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12230 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12231 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12232 // Forms 1. and 2. require swapping the operands of one condition. Don't
12233 // do this if it would break canonical constant/addrec ordering.
12235 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12236 LHS, FoundLHS, FoundRHS, CtxI);
12237 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12238 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12239
12240 // There's no clear preference between forms 3. and 4., try both. Avoid
12241 // forming getNotSCEV of pointer values as the resulting subtract is
12242 // not legal.
12243 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12244 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12245 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12246 FoundRHS, CtxI))
12247 return true;
12248
12249 if (!FoundLHS->getType()->isPointerTy() &&
12250 !FoundRHS->getType()->isPointerTy() &&
12251 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12252 getNotSCEV(FoundRHS), CtxI))
12253 return true;
12254
12255 return false;
12256 }
12257
12258 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
12260 assert(P1 != P2 && "Handled earlier!");
12261 return CmpInst::isRelational(P2) &&
12263 };
12264 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12265 // Unsigned comparison is the same as signed comparison when both the
12266 // operands are non-negative or negative.
12267 if (haveSameSign(FoundLHS, FoundRHS))
12268 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12269 // Create local copies that we can freely swap and canonicalize our
12270 // conditions to "le/lt".
12271 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12272 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12273 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12274 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12275 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12276 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12277 std::swap(CanonicalLHS, CanonicalRHS);
12278 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12279 }
12280 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12281 "Must be!");
12282 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12283 ICmpInst::isLE(CanonicalFoundPred)) &&
12284 "Must be!");
12285 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12286 // Use implication:
12287 // x <u y && y >=s 0 --> x <s y.
12288 // If we can prove the left part, the right part is also proven.
12289 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12290 CanonicalRHS, CanonicalFoundLHS,
12291 CanonicalFoundRHS);
12292 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12293 // Use implication:
12294 // x <s y && y <s 0 --> x <u y.
12295 // If we can prove the left part, the right part is also proven.
12296 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12297 CanonicalRHS, CanonicalFoundLHS,
12298 CanonicalFoundRHS);
12299 }
12300
12301 // Check if we can make progress by sharpening ranges.
12302 if (FoundPred == ICmpInst::ICMP_NE &&
12303 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12304
12305 const SCEVConstant *C = nullptr;
12306 const SCEV *V = nullptr;
12307
12308 if (isa<SCEVConstant>(FoundLHS)) {
12309 C = cast<SCEVConstant>(FoundLHS);
12310 V = FoundRHS;
12311 } else {
12312 C = cast<SCEVConstant>(FoundRHS);
12313 V = FoundLHS;
12314 }
12315
12316 // The guarding predicate tells us that C != V. If the known range
12317 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12318 // range we consider has to correspond to same signedness as the
12319 // predicate we're interested in folding.
12320
12321 APInt Min = ICmpInst::isSigned(Pred) ?
12323
12324 if (Min == C->getAPInt()) {
12325 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12326 // This is true even if (Min + 1) wraps around -- in case of
12327 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12328
12329 APInt SharperMin = Min + 1;
12330
12331 switch (Pred) {
12332 case ICmpInst::ICMP_SGE:
12333 case ICmpInst::ICMP_UGE:
12334 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12335 // RHS, we're done.
12336 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12337 CtxI))
12338 return true;
12339 [[fallthrough]];
12340
12341 case ICmpInst::ICMP_SGT:
12342 case ICmpInst::ICMP_UGT:
12343 // We know from the range information that (V `Pred` Min ||
12344 // V == Min). We know from the guarding condition that !(V
12345 // == Min). This gives us
12346 //
12347 // V `Pred` Min || V == Min && !(V == Min)
12348 // => V `Pred` Min
12349 //
12350 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12351
12352 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12353 return true;
12354 break;
12355
12356 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12357 case ICmpInst::ICMP_SLE:
12358 case ICmpInst::ICMP_ULE:
12359 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12360 LHS, V, getConstant(SharperMin), CtxI))
12361 return true;
12362 [[fallthrough]];
12363
12364 case ICmpInst::ICMP_SLT:
12365 case ICmpInst::ICMP_ULT:
12366 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12367 LHS, V, getConstant(Min), CtxI))
12368 return true;
12369 break;
12370
12371 default:
12372 // No change
12373 break;
12374 }
12375 }
12376 }
12377
12378 // Check whether the actual condition is beyond sufficient.
12379 if (FoundPred == ICmpInst::ICMP_EQ)
12380 if (ICmpInst::isTrueWhenEqual(Pred))
12381 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12382 return true;
12383 if (Pred == ICmpInst::ICMP_NE)
12384 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12385 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12386 return true;
12387
12388 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12389 return true;
12390
12391 // Otherwise assume the worst.
12392 return false;
12393}
12394
12395bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12396 SCEV::NoWrapFlags &Flags) {
12397 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12398 return false;
12399
12400 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12401 return true;
12402}
12403
std::optional<APInt>
  // We avoid subtracting expressions here because this function is usually
  // fairly deep in the call stack (i.e. is called many times).

  // Invariant maintained by the loop below: the requested difference equals
  // Diff + DiffMul * (More - Less) for the current More/Less values.
  unsigned BW = getTypeSizeInBits(More->getType());
  APInt Diff(BW, 0);
  APInt DiffMul(BW, 1);
  // Try various simplifications to reduce the difference to a constant. Limit
  // the number of allowed simplifications to keep compile-time low.
  for (unsigned I = 0; I < 8; ++I) {
    // Identical expressions: the remaining difference is zero, so the
    // accumulated constant is the answer.
    if (More == Less)
      return Diff;

    // Reduce addrecs with identical steps to their start value.
      const auto *LAR = cast<SCEVAddRecExpr>(Less);
      const auto *MAR = cast<SCEVAddRecExpr>(More);

      // Recurrences over different loops cannot be related here.
      if (LAR->getLoop() != MAR->getLoop())
        return std::nullopt;

      // We look at affine expressions only; not for correctness but to keep
      // getStepRecurrence cheap.
      if (!LAR->isAffine() || !MAR->isAffine())
        return std::nullopt;

      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
        return std::nullopt;

      // Equal steps: the recurrences differ exactly by their start values.
      Less = LAR->getStart();
      More = MAR->getStart();
      continue;
    }

    // Try to match a common constant multiply.
    auto MatchConstMul =
        [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
      const APInt *C;
      const SCEV *Op;
      if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
        return {{Op, *C}};
      return std::nullopt;
    };
    // If both sides are scaled by the same constant, strip the scale and
    // fold it into DiffMul so later constant contributions stay correct.
    if (auto MatchedMore = MatchConstMul(More)) {
      if (auto MatchedLess = MatchConstMul(Less)) {
        if (MatchedMore->second == MatchedLess->second) {
          More = MatchedMore->first;
          Less = MatchedLess->first;
          DiffMul *= MatchedMore->second;
          continue;
        }
      }
    }

    // Try to cancel out common factors in two add expressions.
    auto Add = [&](const SCEV *S, int Mul) {
      // Constants fold straight into Diff (scaled by DiffMul); everything
      // else has its signed multiplicity counted.
      if (auto *C = dyn_cast<SCEVConstant>(S)) {
        if (Mul == 1) {
          Diff += C->getAPInt() * DiffMul;
        } else {
          assert(Mul == -1);
          Diff -= C->getAPInt() * DiffMul;
        }
      } else
        Multiplicity[S] += Mul;
    };
    // Split add expressions into their operands; any other expression is
    // treated as a single opaque term.
    auto Decompose = [&](const SCEV *S, int Mul) {
      if (isa<SCEVAddExpr>(S)) {
        for (const SCEV *Op : S->operands())
          Add(Op, Mul);
      } else
        Add(S, Mul);
    };
    Decompose(More, 1);
    Decompose(Less, -1);

    // Check whether all the non-constants cancel out, or reduce to new
    // More/Less values.
    const SCEV *NewMore = nullptr, *NewLess = nullptr;
    for (const auto &[S, Mul] : Multiplicity) {
      if (Mul == 0)
        continue;
      if (Mul == 1) {
        if (NewMore)
          return std::nullopt;
        NewMore = S;
      } else if (Mul == -1) {
        if (NewLess)
          return std::nullopt;
        NewLess = S;
      } else
        return std::nullopt;
    }

    // Values stayed the same, no point in trying further.
    if (NewMore == More || NewLess == Less)
      return std::nullopt;

    More = NewMore;
    Less = NewLess;

    // Reduced to constant.
    if (!More && !Less)
      return Diff;

    // Left with variable on only one side, bail out.
    if (!More || !Less)
      return std::nullopt;
  }

  // Did not reduce to constant.
  return std::nullopt;
}
12519
12520bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12521 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12522 const SCEV *FoundRHS, const Instruction *CtxI) {
12523 // Try to recognize the following pattern:
12524 //
12525 // FoundRHS = ...
12526 // ...
12527 // loop:
12528 // FoundLHS = {Start,+,W}
12529 // context_bb: // Basic block from the same loop
12530 // known(Pred, FoundLHS, FoundRHS)
12531 //
12532 // If some predicate is known in the context of a loop, it is also known on
12533 // each iteration of this loop, including the first iteration. Therefore, in
12534 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12535 // prove the original pred using this fact.
12536 if (!CtxI)
12537 return false;
12538 const BasicBlock *ContextBB = CtxI->getParent();
12539 // Make sure AR varies in the context block.
12540 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12541 const Loop *L = AR->getLoop();
12542 const auto *Latch = L->getLoopLatch();
12543 // Make sure that context belongs to the loop and executes on 1st iteration
12544 // (if it ever executes at all).
12545 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12546 return false;
12547 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12548 return false;
12549 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12550 }
12551
12552 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12553 const Loop *L = AR->getLoop();
12554 const auto *Latch = L->getLoopLatch();
12555 // Make sure that context belongs to the loop and executes on 1st iteration
12556 // (if it ever executes at all).
12557 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12558 return false;
12559 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12560 return false;
12561 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12562 }
12563
12564 return false;
12565}
12566
12567bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
12568 const SCEV *LHS,
12569 const SCEV *RHS,
12570 const SCEV *FoundLHS,
12571 const SCEV *FoundRHS) {
12572 if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
12573 return false;
12574
12575 const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
12576 if (!AddRecLHS)
12577 return false;
12578
12579 const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
12580 if (!AddRecFoundLHS)
12581 return false;
12582
12583 // We'd like to let SCEV reason about control dependencies, so we constrain
12584 // both the inequalities to be about add recurrences on the same loop. This
12585 // way we can use isLoopEntryGuardedByCond later.
12586
12587 const Loop *L = AddRecFoundLHS->getLoop();
12588 if (L != AddRecLHS->getLoop())
12589 return false;
12590
12591 // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
12592 //
12593 // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
12594 // ... (2)
12595 //
12596 // Informal proof for (2), assuming (1) [*]:
12597 //
12598 // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
12599 //
12600 // Then
12601 //
12602 // FoundLHS s< FoundRHS s< INT_MIN - C
12603 // <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C [ using (3) ]
12604 // <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
12605 // <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
12606 // (FoundRHS + INT_MIN + C + INT_MIN) [ using (3) ]
12607 // <=> FoundLHS + C s< FoundRHS + C
12608 //
12609 // [*]: (1) can be proved by ruling out overflow.
12610 //
12611 // [**]: This can be proved by analyzing all the four possibilities:
12612 // (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
12613 // (A s>= 0, B s>= 0).
12614 //
12615 // Note:
12616 // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
12617 // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
12618 // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
12619 // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
12620 // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
12621 // C)".
12622
12623 std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
12624 if (!LDiff)
12625 return false;
12626 std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
12627 if (!RDiff || *LDiff != *RDiff)
12628 return false;
12629
12630 if (LDiff->isMinValue())
12631 return true;
12632
12633 APInt FoundRHSLimit;
12634
12635 if (Pred == CmpInst::ICMP_ULT) {
12636 FoundRHSLimit = -(*RDiff);
12637 } else {
12638 assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
12639 FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
12640 }
12641
12642 // Try to prove (1) or (2), as needed.
12643 return isAvailableAtLoopEntry(FoundRHS, L) &&
12644 isLoopEntryGuardedByCond(L, Pred, FoundRHS,
12645 getConstant(FoundRHSLimit));
12646}
12647
12648bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
12649 const SCEV *RHS, const SCEV *FoundLHS,
12650 const SCEV *FoundRHS, unsigned Depth) {
12651 const PHINode *LPhi = nullptr, *RPhi = nullptr;
12652
12653 llvm::scope_exit ClearOnExit([&]() {
12654 if (LPhi) {
12655 bool Erased = PendingMerges.erase(LPhi);
12656 assert(Erased && "Failed to erase LPhi!");
12657 (void)Erased;
12658 }
12659 if (RPhi) {
12660 bool Erased = PendingMerges.erase(RPhi);
12661 assert(Erased && "Failed to erase RPhi!");
12662 (void)Erased;
12663 }
12664 });
12665
12666 // Find respective Phis and check that they are not being pending.
12667 if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
12668 if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
12669 if (!PendingMerges.insert(Phi).second)
12670 return false;
12671 LPhi = Phi;
12672 }
12673 if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
12674 if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
12675 // If we detect a loop of Phi nodes being processed by this method, for
12676 // example:
12677 //
12678 // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
12679 // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
12680 //
12681 // we don't want to deal with a case that complex, so return conservative
12682 // answer false.
12683 if (!PendingMerges.insert(Phi).second)
12684 return false;
12685 RPhi = Phi;
12686 }
12687
12688 // If none of LHS, RHS is a Phi, nothing to do here.
12689 if (!LPhi && !RPhi)
12690 return false;
12691
12692 // If there is a SCEVUnknown Phi we are interested in, make it left.
12693 if (!LPhi) {
12694 std::swap(LHS, RHS);
12695 std::swap(FoundLHS, FoundRHS);
12696 std::swap(LPhi, RPhi);
12698 }
12699
12700 assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
12701 const BasicBlock *LBB = LPhi->getParent();
12702 const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12703
12704 auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
12705 return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
12706 isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
12707 isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
12708 };
12709
12710 if (RPhi && RPhi->getParent() == LBB) {
12711 // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
12712 // If we compare two Phis from the same block, and for each entry block
12713 // the predicate is true for incoming values from this block, then the
12714 // predicate is also true for the Phis.
12715 for (const BasicBlock *IncBB : predecessors(LBB)) {
12716 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12717 const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
12718 if (!ProvedEasily(L, R))
12719 return false;
12720 }
12721 } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
12722 // Case two: RHS is also a Phi from the same basic block, and it is an
12723 // AddRec. It means that there is a loop which has both AddRec and Unknown
12724 // PHIs, for it we can compare incoming values of AddRec from above the loop
12725 // and latch with their respective incoming values of LPhi.
12726 // TODO: Generalize to handle loops with many inputs in a header.
12727 if (LPhi->getNumIncomingValues() != 2) return false;
12728
12729 auto *RLoop = RAR->getLoop();
12730 auto *Predecessor = RLoop->getLoopPredecessor();
12731 assert(Predecessor && "Loop with AddRec with no predecessor?");
12732 const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
12733 if (!ProvedEasily(L1, RAR->getStart()))
12734 return false;
12735 auto *Latch = RLoop->getLoopLatch();
12736 assert(Latch && "Loop with AddRec with no latch?");
12737 const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
12738 if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
12739 return false;
12740 } else {
12741 // In all other cases go over inputs of LHS and compare each of them to RHS,
12742 // the predicate is true for (LHS, RHS) if it is true for all such pairs.
12743 // At this point RHS is either a non-Phi, or it is a Phi from some block
12744 // different from LBB.
12745 for (const BasicBlock *IncBB : predecessors(LBB)) {
12746 // Check that RHS is available in this block.
12747 if (!dominates(RHS, IncBB))
12748 return false;
12749 const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
12750 // Make sure L does not refer to a value from a potentially previous
12751 // iteration of a loop.
12752 if (!properlyDominates(L, LBB))
12753 return false;
12754 // Addrecs are considered to properly dominate their loop, so are missed
12755 // by the previous check. Discard any values that have computable
12756 // evolution in this loop.
12757 if (auto *Loop = LI.getLoopFor(LBB))
12758 if (hasComputableLoopEvolution(L, Loop))
12759 return false;
12760 if (!ProvedEasily(L, RHS))
12761 return false;
12762 }
12763 }
12764 return true;
12765}
12766
12767bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
12768 const SCEV *LHS,
12769 const SCEV *RHS,
12770 const SCEV *FoundLHS,
12771 const SCEV *FoundRHS) {
12772 // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
12773 // sure that we are dealing with same LHS.
12774 if (RHS == FoundRHS) {
12775 std::swap(LHS, RHS);
12776 std::swap(FoundLHS, FoundRHS);
12778 }
12779 if (LHS != FoundLHS)
12780 return false;
12781
12782 auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
12783 if (!SUFoundRHS)
12784 return false;
12785
12786 Value *Shiftee, *ShiftValue;
12787
12788 using namespace PatternMatch;
12789 if (match(SUFoundRHS->getValue(),
12790 m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
12791 auto *ShifteeS = getSCEV(Shiftee);
12792 // Prove one of the following:
12793 // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
12794 // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
12795 // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12796 // ---> LHS <s RHS
12797 // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
12798 // ---> LHS <=s RHS
12799 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
12800 return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
12801 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
12802 if (isKnownNonNegative(ShifteeS))
12803 return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
12804 }
12805
12806 return false;
12807}
12808
12809bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12810 const SCEV *RHS,
12811 const SCEV *FoundLHS,
12812 const SCEV *FoundRHS,
12813 const Instruction *CtxI) {
12814 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12815 FoundRHS) ||
12816 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12817 FoundRHS) ||
12818 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12819 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12820 CtxI) ||
12821 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12822}
12823
12824/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12825template <typename MinMaxExprType>
12826static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12827 const SCEV *Candidate) {
12828 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12829 if (!MinMaxExpr)
12830 return false;
12831
12832 return is_contained(MinMaxExpr->operands(), Candidate);
12833}
12834
                                       CmpPredicate Pred, const SCEV *LHS,
                                       const SCEV *RHS) {
  // If both sides are affine addrecs for the same loop, with equal
  // steps, and we know the recurrences don't wrap, then we only
  // need to check the predicate on the starting values.

  // Equality predicates are not handled by this transformation.
  if (!ICmpInst::isRelational(Pred))
    return false;

  const SCEV *LStart, *RStart, *Step;
  const Loop *L;
  // Match both sides as affine addrecs over the same loop with the same step.
  if (!match(LHS,
             m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
      m_SpecificLoop(L))))
    return false;
  if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
    return false;

  // Same loop, same step, no wrapping: the predicate on the recurrences
  // reduces to the predicate on their start values.
  return SE.isKnownPredicate(Pred, LStart, RStart);
}
12861
/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
/// expression?
/// Only the non-strict orderings (U|S)(LE|GE) are handled; every other
/// predicate conservatively returns false.
  const SCEV *LHS, const SCEV *RHS) {
  switch (Pred) {
  default:
    return false;

  // SGE is reduced to SLE by swapping the operands.
  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE:
    return
        // min(A, ...) <= A
        // A <= max(A, ...)

  // UGE is reduced to ULE by swapping the operands.
  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE:
    return
        // min(A, ...) <= A
        // FIXME: what about umin_seq?
        // A <= max(A, ...)
  }

  llvm_unreachable("covered switch fell through?!");
}
12894
bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
                                             const SCEV *RHS,
                                             const SCEV *FoundLHS,
                                             const SCEV *FoundRHS,
                                             unsigned Depth) {
         "LHS and RHS have different sizes?");
  assert(getTypeSizeInBits(FoundLHS->getType()) ==
             getTypeSizeInBits(FoundRHS->getType()) &&
         "FoundLHS and FoundRHS have different sizes?");
  // We want to avoid hurting the compile time with analysis of too big trees.
    return false;

  // We only want to work with GT comparison so far.
  if (ICmpInst::isLT(Pred)) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
  }


  // For unsigned, try to reduce it to corresponding signed comparison.
  if (P == ICmpInst::ICMP_UGT)
    // We can replace unsigned predicate with its signed counterpart if all
    // involved values are non-negative.
    // TODO: We could have better support for unsigned.
    if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
      // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
      // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
      // use this fact to prove that LHS and RHS are non-negative.
      const SCEV *MinusOne = getMinusOne(LHS->getType());
      if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
                                FoundRHS) &&
          isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
                                FoundRHS))
    }

  if (P != ICmpInst::ICMP_SGT)
    return false;

  // Strip a sign-extension so we can reason about the underlying operand.
  auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
    if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
      return Ext->getOperand();
    // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
    // the constant in some cases.
    return S;
  };

  // Acquire values from extensions.
  auto *OrigLHS = LHS;
  auto *OrigFoundLHS = FoundLHS;
  LHS = GetOpFromSExt(LHS);
  FoundLHS = GetOpFromSExt(FoundLHS);

  // Check whether the SGT predicate can be proved trivially or by using the
  // found context.
  auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                  FoundRHS, Depth + 1);
  };

  // Case 1: LHS is an nsw sum — prove one addend nonnegative and the other
  // greater than RHS.
  if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
    // We want to avoid creation of any new non-constant SCEV. Since we are
    // going to compare the operands to RHS, we should be certain that we don't
    // need any size extensions for this. So let's decline all cases when the
    // sizes of types of LHS and RHS do not match.
    // TODO: Maybe try to get RHS from sext to catch more cases?
      return false;

    // Should not overflow.
    if (!LHSAddExpr->hasNoSignedWrap())
      return false;

    SCEVUse LL = LHSAddExpr->getOperand(0);
    SCEVUse LR = LHSAddExpr->getOperand(1);
    auto *MinusOne = getMinusOne(RHS->getType());

    // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
    auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
      return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
    };
    // Try to prove the following rule:
    // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
    // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
    if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
      return true;
  } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
    Value *LL, *LR;
    // FIXME: Once we have SDiv implemented, we can get rid of this matching.

    using namespace llvm::PatternMatch;

    if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
      // Rules for division.
      // We are going to perform some comparisons with Denominator and its
      // derivative expressions. In general case, creating a SCEV for it may
      // lead to a complex analysis of the entire graph, and in particular it
      // can request trip count recalculation for the same loop. This would
      // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
      // this, we only want to create SCEVs that are constants in this section.
      // So we bail if Denominator is not a constant.
      if (!isa<ConstantInt>(LR))
        return false;

      auto *Denominator = cast<SCEVConstant>(getSCEV(LR));

      // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
      // then a SCEV for the numerator already exists and matches with FoundLHS.
      auto *Numerator = getExistingSCEV(LL);
      if (!Numerator || Numerator->getType() != FoundLHS->getType())
        return false;

      // Make sure that the numerator matches with FoundLHS and the denominator
      // is positive.
      if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
        return false;

      auto *DTy = Denominator->getType();
      auto *FRHSTy = FoundRHS->getType();
      if (DTy->isPointerTy() != FRHSTy->isPointerTy())
        // One of types is a pointer and another one is not. We cannot extend
        // them properly to a wider type, so let us just reject this case.
        // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
        // to avoid this check.
        return false;

      // Given that:
      // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
      auto *WTy = getWiderType(DTy, FRHSTy);
      auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
      auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);

      // Try to prove the following rule:
      // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
      // For example, given that FoundLHS > 2. It means that FoundLHS is at
      // least 3. If we divide it by Denominator < 4, we will have at least 1.
      auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
      if (isKnownNonPositive(RHS) &&
          IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
        return true;

      // Try to prove the following rule:
      // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
      // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
      // If we divide it by Denominator > 2, then:
      // 1. If FoundLHS is negative, then the result is 0.
      // 2. If FoundLHS is non-negative, then the result is non-negative.
      // Anyways, the result is non-negative.
      auto *MinusOne = getMinusOne(WTy);
      auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
      if (isKnownNegative(RHS) &&
          IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
        return true;
    }
  }

  // If our expression contained SCEVUnknown Phis, and we split it down and now
  // need to prove something for them, try to prove the predicate for every
  // possible incoming values of those Phis.
  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
    return true;

  return false;
}
13064
                                        const SCEV *RHS) {
  // zext x u<= sext x, sext x s<= zext x
  const SCEV *Op;
  // GE predicates are reduced to the corresponding LE with operands swapped.
  switch (Pred) {
  case ICmpInst::ICMP_SGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SLE: {
    // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
    return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
  }
  case ICmpInst::ICMP_UGE:
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_ULE: {
    // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
    return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
  }
  default:
    return false;
  };
  llvm_unreachable("unhandled case");
}
13091
13092bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13093 SCEVUse LHS,
13094 SCEVUse RHS) {
13095 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13096 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13097 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13098 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13099 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13100}
13101
13102bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13103 const SCEV *LHS,
13104 const SCEV *RHS,
13105 const SCEV *FoundLHS,
13106 const SCEV *FoundRHS) {
13107 switch (Pred) {
13108 default:
13109 llvm_unreachable("Unexpected CmpPredicate value!");
13110 case ICmpInst::ICMP_EQ:
13111 case ICmpInst::ICMP_NE:
13112 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13113 return true;
13114 break;
13115 case ICmpInst::ICMP_SLT:
13116 case ICmpInst::ICMP_SLE:
13117 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13118 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13119 return true;
13120 break;
13121 case ICmpInst::ICMP_SGT:
13122 case ICmpInst::ICMP_SGE:
13123 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13124 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13125 return true;
13126 break;
13127 case ICmpInst::ICMP_ULT:
13128 case ICmpInst::ICMP_ULE:
13129 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13130 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13131 return true;
13132 break;
13133 case ICmpInst::ICMP_UGT:
13134 case ICmpInst::ICMP_UGE:
13135 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13136 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13137 return true;
13138 break;
13139 }
13140
13141 // Maybe it can be proved via operations?
13142 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13143 return true;
13144
13145 return false;
13146}
13147
13148bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13149 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13150 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13151 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13152 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13153 // reduce the compile time impact of this optimization.
13154 return false;
13155
13156 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13157 if (!Addend)
13158 return false;
13159
13160 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13161
13162 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13163 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13164 ConstantRange FoundLHSRange =
13165 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13166
13167 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13168 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13169
13170 // We can also compute the range of values for `LHS` that satisfy the
13171 // consequent, "`LHS` `Pred` `RHS`":
13172 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13173 // The antecedent implies the consequent if every value of `LHS` that
13174 // satisfies the antecedent also satisfies the consequent.
13175 return LHSRange.icmp(Pred, ConstRHS);
13176}
13177
13178bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13179 bool IsSigned) {
13180 assert(isKnownPositive(Stride) && "Positive stride expected!");
13181
13182 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13183 const SCEV *One = getOne(Stride->getType());
13184
13185 if (IsSigned) {
13186 APInt MaxRHS = getSignedRangeMax(RHS);
13187 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13188 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13189
13190 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13191 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13192 }
13193
13194 APInt MaxRHS = getUnsignedRangeMax(RHS);
13195 APInt MaxValue = APInt::getMaxValue(BitWidth);
13196 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13197
13198 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13199 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13200}
13201
13202bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13203 bool IsSigned) {
13204
13205 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13206 const SCEV *One = getOne(Stride->getType());
13207
13208 if (IsSigned) {
13209 APInt MinRHS = getSignedRangeMin(RHS);
13210 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13211 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13212
13213 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13214 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13215 }
13216
13217 APInt MinRHS = getUnsignedRangeMin(RHS);
13218 APInt MinValue = APInt::getMinValue(BitWidth);
13219 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13220
13221 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13222 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13223}
13224
  // Build a SCEV for the ceiling division "N /u D" as:
  //
  // umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N=0.
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType())); // umin(N, 1)
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);           // N - umin(N, 1)
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
13233
13234const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13235 const SCEV *Stride,
13236 const SCEV *End,
13237 unsigned BitWidth,
13238 bool IsSigned) {
13239 // The logic in this function assumes we can represent a positive stride.
13240 // If we can't, the backedge-taken count must be zero.
13241 if (IsSigned && BitWidth == 1)
13242 return getZero(Stride->getType());
13243
13244 // This code below only been closely audited for negative strides in the
13245 // unsigned comparison case, it may be correct for signed comparison, but
13246 // that needs to be established.
13247 if (IsSigned && isKnownNegative(Stride))
13248 return getCouldNotCompute();
13249
13250 // Calculate the maximum backedge count based on the range of values
13251 // permitted by Start, End, and Stride.
13252 APInt MinStart =
13253 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13254
13255 APInt MinStride =
13256 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13257
13258 // We assume either the stride is positive, or the backedge-taken count
13259 // is zero. So force StrideForMaxBECount to be at least one.
13260 APInt One(BitWidth, 1);
13261 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13262 : APIntOps::umax(One, MinStride);
13263
13264 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13265 : APInt::getMaxValue(BitWidth);
13266 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13267
13268 // Although End can be a MAX expression we estimate MaxEnd considering only
13269 // the case End = RHS of the loop termination condition. This is safe because
13270 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13271 // taken count.
13272 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13273 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13274
13275 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13276 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13277 : APIntOps::umax(MaxEnd, MinStart);
13278
13279 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13280 getConstant(StrideForMaxBECount) /* Step */);
13281}
13282
// Compute the exit limit for a loop whose exit condition compares LHS with
// RHS via a (signed or unsigned, per IsSigned) less-than predicate.
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsOnlyExit, bool AllowPredicates) {

  bool PredicatedIV = false;
  // If LHS is not directly an addrec in L, try to look through a zero-extend
  // of an addrec: if we can prove the inner addrec does not wrap unsigned,
  // the zext can be pushed inside and an analyzable addrec recovered.
  if (!IV) {
    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
      if (AR && AR->getLoop() == L && AR->isAffine()) {
        auto canProveNUW = [&]() {
          // We can use the comparison to infer no-wrap flags only if it fully
          // controls the loop exit.
          if (!ControlsOnlyExit)
            return false;

          if (!isLoopInvariant(RHS, L))
            return false;

          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
            // We need the sequence defined by AR to strictly increase in the
            // unsigned integer domain for the logic below to hold.
            return false;

          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
          // If RHS <=u Limit, then there must exist a value V in the sequence
          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
          // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
          // overflow occurs. This limit also implies that a signed comparison
          // (in the wide bitwidth) is equivalent to an unsigned comparison as
          // the high bits on both sides must be zero.
          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
          Limit = Limit.zext(OuterBitWidth);
          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
        };
        auto Flags = AR->getNoWrapFlags();
        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
          Flags = setFlags(Flags, SCEV::FlagNUW);

        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
        if (AR->hasNoUnsignedWrap()) {
          // Emulate what getZeroExtendExpr would have done during construction
          // if we'd been able to infer the fact just above at that time.
          const SCEV *Step = AR->getStepRecurrence(*this);
          Type *Ty = ZExt->getType();
          auto *S = getAddRecExpr(
              getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
        }
      }
    }
  }


  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch. Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);

  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects. Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula is as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    NoWrap flag).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop)
    // c) loop has a single static exit (with no abnormal exits)
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop. The backedge taken count formula reduces to zero in this case.
    //
    // Precondition b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride) returning
    // true (original behavior of the function).
    //
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
      return getCouldNotCompute();

    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration. We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below. Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
      // we know the numerator in the divides below must be zero, so we can
      // pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction. Suppose the stride were zero. If we can
        // prove that the backedge *is* taken on the first iteration, then since
        // we know this condition controls the sole exit, we must have an
        // infinite loop. We can't have a (well defined) infinite loop per
        // check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride). Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!NoWrap) {
    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned))
      return getCouldNotCompute();
  }

  // On all paths just preceding, we established the following invariant:
  // IV can be assumed not to overflow up to and including the exiting
  // iteration. We proved this in one of two ways:
  // 1) We can show overflow doesn't occur before the exiting iteration
  //    1a) canIVOverflowOnLT, and b) step of one
  // 2) We can show that if overflow occurs, the loop must execute UB
  //    before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).

  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
      return RHS;
  }

  const SCEV *End = nullptr, *BECount = nullptr,
             *BECountIfBackedgeTaken = nullptr;
  if (!isLoopInvariant(RHS, L)) {
    const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
    if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
        RHSAddRec->getNoWrapFlags()) {
      // The structure of loop we are trying to calculate backedge count of:
      //
      // left = left_start
      // right = right_start
      //
      // while(left < right){
      //   ... do something here ...
      //   left += s1; // stride of left is s1 (s1 > 0)
      //   right += s2; // stride of right is s2 (s2 < 0)
      // }
      //

      const SCEV *RHSStart = RHSAddRec->getStart();
      const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);

      // If Stride - RHSStride is positive and does not overflow, we can write
      // backedge count as ->
      //    ceil((End - Start) /u (Stride - RHSStride))
      //    Where, End = max(RHSStart, Start)

      // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
      if (isKnownNegative(RHSStride) &&
          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
                          RHSStride)) {

        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
        if (isKnownPositive(Denominator)) {
          End = IsSigned ? getSMaxExpr(RHSStart, Start)
                         : getUMaxExpr(RHSStart, Start);

          // We can do this because End >= Start, as End = max(RHSStart, Start)
          const SCEV *Delta = getMinusSCEV(End, Start);

          BECount = getUDivCeilSCEV(Delta, Denominator);
          BECountIfBackedgeTaken =
              getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
        }
      }
    }
    // Exact count not derivable: fall back to a conservative constant max.
    if (BECount == nullptr) {
      // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
      // given the start, stride and max value for the end bound of the
      // loop (RHS), and the fact that IV does not overflow (which is
      // checked above).
      const SCEV *MaxBECount = computeMaxBECountForLT(
          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                       MaxBECount, false /*MaxOrZero*/, Predicates);
    }
  } else {
    // We use the expression (max(End,Start)-Start)/Stride to describe the
    // backedge count, as if the backedge is taken at least once
    // max(End,Start) is End and so the result is as above, and if not
    // max(End,Start) is Start so we get a backedge count of zero.
    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
    assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
    // Can we prove max(RHS,Start) > Start - Stride?
    if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
        isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
      // In this case, we can use a refined formula for computing backedge
      // taken count. The general formula remains:
      //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
      // We want to use the alternate formula:
      //   "((End - 1) - (Start - Stride)) /u Stride"
      // Let's do a quick case analysis to show these are equivalent under
      // our precondition that max(RHS,Start) > Start - Stride.
      // * For RHS <= Start, the backedge-taken count must be zero.
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
      //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
      //   of Stride. For 0 stride, we've used umax(Stride, 1) above,
      //   reducing this to the stride of 1 case.
      // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
      //   Stride".
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
      //   "((RHS - (Start - Stride) - 1) /u Stride".
      //   Our preconditions trivially imply no overflow in that form.
      const SCEV *MinusOne = getMinusOne(Stride->getType());
      const SCEV *Numerator =
          getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
      BECount = getUDivExpr(Numerator, Stride);
    }

    if (!BECount) {
      auto canProveRHSGreaterThanEqualStart = [&]() {
        auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

        if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
            isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
          return true;

        // (RHS > Start - 1) implies RHS >= Start.
        // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
        //   "Start - 1" doesn't overflow.
        // * For signed comparison, if Start - 1 does overflow, it's equal
        //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
        // * For unsigned comparison, if Start - 1 does overflow, it's equal
        //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
        //
        // FIXME: Should isLoopEntryGuardedByCond do this for us?
        auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
        auto *StartMinusOne =
            getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
        return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
      };

      // If we know that RHS >= Start in the context of loop, then we know
      // that max(RHS, Start) = RHS at this point.
      if (canProveRHSGreaterThanEqualStart()) {
        End = RHS;
      } else {
        // If RHS < Start, the backedge will be taken zero times. So in
        // general, we can write the backedge-taken count as:
        //
        //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
        //
        // We convert it to the following to make it more convenient for SCEV:
        //
        //     ceil(max(RHS, Start) - Start) / Stride
        End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

        // See what would happen if we assume the backedge is taken. This is
        // used to compute MaxBECount.
        BECountIfBackedgeTaken =
            getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
      }

      // At this point, we know:
      //
      // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
      // 2. The index variable doesn't overflow.
      //
      // Therefore, we know N exists such that
      // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
      // doesn't overflow.
      //
      // Using this information, try to prove whether the addition in
      // "(Start - End) + (Stride - 1)" has unsigned overflow.
      const SCEV *One = getOne(Stride->getType());
      bool MayAddOverflow = [&] {
        if (isKnownToBeAPowerOfTwo(Stride)) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers. Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride. Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by
          // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
          // write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride - 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride - 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
          // to use signed max instead of unsigned max. Note that we're
          // trying to prove a lack of unsigned overflow in either case.
          return false;
        }
        if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
          // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
          // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
          // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
          // 1 <s End.
          //
          // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
          // End.
          return false;
        }
        return true;
      }();

      const SCEV *Delta = getMinusSCEV(End, Start);
      if (!MayAddOverflow) {
        // floor((D + (S - 1)) / S)
        // We prefer this formulation if it's legal because it's fewer
        // operations.
        BECount =
            getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
      } else {
        BECount = getUDivCeilSCEV(Delta, Stride);
      }
    }
  }

  // Derive the constant and symbolic maximum counts, preferring the exact
  // result when it is already a constant.
  const SCEV *ConstantMaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    ConstantMaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    ConstantMaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    ConstantMaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
                   Predicates);
}
13723
// Compute the exit limit for a loop whose exit condition compares a
// decrementing IV (LHS) against a loop-invariant bound (RHS) via a (signed or
// unsigned, per IsSigned) greater-than predicate.
ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // We handle only IV > Invariant
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!IV && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && IV->getNoWrapFlags(WrapType);

  // The IV counts down; negate its step so Stride is the positive decrement.
  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

  // Avoid negative or zero stride values
  if (!isKnownPositive(Stride))
    return getCouldNotCompute();

  // Avoid proven overflow cases: this will ensure that the backedge taken count
  // will not generate any unsigned overflow. Relaxed no-overflow conditions
  // exploit NoWrapFlags, allowing to optimize in presence of undefined
  // behaviors like the case of C language.
  if (!Stride->isOne() && !NoWrap)
    if (canIVOverflowOnGT(RHS, Stride, IsSigned))
      return getCouldNotCompute();

  const SCEV *Start = IV->getStart();
  const SCEV *End = RHS;
  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
    // If we know that Start >= RHS in the context of loop, then we know that
    // min(RHS, Start) = RHS at this point.
        L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
      End = RHS;
    else
      End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
  }

  // Pointer-typed bounds must be converted to integers before subtraction.
  if (Start->getType()->isPointerTy()) {
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (End->getType()->isPointerTy()) {
    End = getLosslessPtrToIntExpr(End);
    if (isa<SCEVCouldNotCompute>(End))
      return End;
  }

  // Compute ((Start - End) + (Stride - 1)) / Stride.
  // FIXME: This can overflow. Holding off on fixing this for now;
  // howManyGreaterThans will hopefully be gone soon.
  const SCEV *One = getOne(Stride->getType());
  const SCEV *BECount = getUDivExpr(
      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);

  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)

  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
                             : getUnsignedRangeMin(Stride);

  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
  // Clamp the smallest reachable End so "End - (Stride - 1)" cannot wrap
  // below the type's minimum value.
  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
                         : APInt::getMinValue(BitWidth) + (MinStride - 1);

  // Although End can be a MIN expression we estimate MinEnd considering only
  // the case End = RHS. This is safe because in the other case (Start - End)
  // is zero, leading to a zero maximum backedge taken count.
  APInt MinEnd =
      IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
               : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

  const SCEV *ConstantMaxBECount =
      isa<SCEVConstant>(BECount)
          ? BECount
          : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
                            getConstant(MinStride));

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
    ConstantMaxBECount = BECount;
  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;

  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   Predicates);
}
13822
                                                    ScalarEvolution &SE) const {
  // A full range can never be exited, so the loop is effectively infinite.
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      Operands[0] = SE.getZero(SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
            Range.subtract(SC->getAPInt()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
    return SE.getCouldNotCompute();

  // Okay at this point we know that all elements of the chrec are constants and
  // that the start element is zero.

  // First check to see if the range contains zero. If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getZero(getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range

    // We know that zero is in the range. If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value. Also note that we already checked for a full range.
    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);

    // Evaluate at the exit value. If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened

    // Ensure that the previous value is in the range.
    assert(Range.contains(
        ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
        "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  }

  // Quadratic recurrences can sometimes be solved in closed form.
  if (isQuadratic()) {
    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
      return SE.getConstant(*S);
  }

  return SE.getCouldNotCompute();
}
13893
// Build the post-increment form of this addrec, i.e. the addrec describing
// "this + step" on every iteration.
const SCEVAddRecExpr *
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
  // may happen if we reach arithmetic depth limit while simplifying. So we
  // construct the returned value explicitly.
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+...,+,N}.
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier). This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrency with zero step?");
  Ops.push_back(Last);
}
13918
// Return true when S contains at least one undef or poison value.
  return SCEVExprContains(
      S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
}
13924
// Return true when S contains a value that is a nullptr.
// NOTE(review): presumably a SCEVUnknown's value is nulled out when the
// underlying IR Value is erased — confirm against SCEVUnknown's value handle.
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      return SU->getValue() == nullptr;
    return false;
  });
}
13933
/// Return the size of an element read or written by Inst.
/// Returns nullptr for instructions other than loads and stores.
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    return nullptr;

  return getSizeOfExpr(ETy, Ty);
}
13947
13948//===----------------------------------------------------------------------===//
13949// SCEVCallbackVH Class Implementation
13950//===----------------------------------------------------------------------===//
13951
  // The tracked IR Value is being deleted: purge every SCEV cached for it.
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  SE->eraseValueFromMap(getValPtr());
  // this now dangles!
}
13959
13960void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
13961 assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
13962
13963 // Forget all the expressions associated with users of the old value,
13964 // so that future queries will recompute the expressions using the new
13965 // value.
13966 SE->forgetValue(getValPtr());
13967 // this now dangles!
13968}
13969
// Register a callback handle on \p V that keeps \p se notified of deletion
// and replacement events for that value.
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}
13972
13973//===----------------------------------------------------------------------===//
13974// ScalarEvolution Class Implementation
13975//===----------------------------------------------------------------------===//
13976
                                 LoopInfo &LI)
    : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
      LoopDispositions(64), BlockDispositions(64) {
  // To use guards for proving predicates, we need to scan every instruction in
  // relevant basic blocks, and not just terminators. Doing this is a waste of
  // time if the IR does not actually contain any calls to
  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
  //
  // This pessimizes the case where a pass that preserves ScalarEvolution wants
  // to _add_ guards to the module when there weren't any before, and wants
  // ScalarEvolution to optimize based on those guards. For now we prefer to be
  // efficient in lieu of being smart in that rather obscure case.

  // Look up the guard intrinsic's declaration in the module, if present.
  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      F.getParent(), Intrinsic::experimental_guard);
  HasGuards = GuardDecl && !GuardDecl->use_empty();
}
13997
// NOTE(review): signature line (13998) missing from this extraction; this is
// the move constructor, transferring every cache and map from Arg wholesale.
13999 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
14000 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
14001 ValueExprMap(std::move(Arg.ValueExprMap)),
14002 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
14003 PendingMerges(std::move(Arg.PendingMerges)),
14004 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
14005 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
14006 PredicatedBackedgeTakenCounts(
14007 std::move(Arg.PredicatedBackedgeTakenCounts)),
14008 BECountUsers(std::move(Arg.BECountUsers)),
14009 ConstantEvolutionLoopExitValue(
14010 std::move(Arg.ConstantEvolutionLoopExitValue)),
14011 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
14012 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
14013 LoopDispositions(std::move(Arg.LoopDispositions)),
14014 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
14015 BlockDispositions(std::move(Arg.BlockDispositions)),
14016 SCEVUsers(std::move(Arg.SCEVUsers)),
14017 UnsignedRanges(std::move(Arg.UnsignedRanges)),
14018 SignedRanges(std::move(Arg.SignedRanges)),
14019 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14020 UniquePreds(std::move(Arg.UniquePreds)),
14021 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14022 LoopUsers(std::move(Arg.LoopUsers)),
14023 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14024 FirstUnknown(Arg.FirstUnknown) {
// Null out the source's SCEVUnknown list head so only the destination's
// destructor walks it.
14025 Arg.FirstUnknown = nullptr;
14026}
14027
// NOTE(review): signature line (14028) missing from this extraction; this is
// the ScalarEvolution destructor.
14029 // Iterate through all the SCEVUnknown instances and call their
14030 // destructors, so that they release their references to their values.
14031 for (SCEVUnknown *U = FirstUnknown; U;) {
// Advance before destroying: ~SCEVUnknown invalidates the current node.
14032 SCEVUnknown *Tmp = U;
14033 U = U->Next;
14034 Tmp->~SCEVUnknown();
14035 }
14036 FirstUnknown = nullptr;
14037
14038 ExprValueMap.clear();
14039 ValueExprMap.clear();
14040 HasRecMap.clear();
14041 BackedgeTakenCounts.clear();
14042 PredicatedBackedgeTakenCounts.clear();
14043
// These containers/flags must only be non-empty transiently inside the
// corresponding queries; anything left over indicates a bug.
14044 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14045 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14046 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14047 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14048}
14049
14053
14054/// When printing a top-level SCEV for trip counts, it's helpful to include
14055/// a type for constants which are otherwise hard to disambiguate.
14056static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14057 if (isa<SCEVConstant>(S))
14058 OS << *S->getType() << " ";
14059 OS << *S;
14060}
14061
14063 const Loop *L) {
14064 // Print all inner loops first
14065 for (Loop *I : *L)
14066 PrintLoopInfo(OS, SE, I);
14067
14068 OS << "Loop ";
14069 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14070 OS << ": ";
14071
14072 SmallVector<BasicBlock *, 8> ExitingBlocks;
14073 L->getExitingBlocks(ExitingBlocks);
14074 if (ExitingBlocks.size() != 1)
14075 OS << "<multiple exits> ";
14076
14077 auto *BTC = SE->getBackedgeTakenCount(L);
14078 if (!isa<SCEVCouldNotCompute>(BTC)) {
14079 OS << "backedge-taken count is ";
14080 PrintSCEVWithTypeHint(OS, BTC);
14081 } else
14082 OS << "Unpredictable backedge-taken count.";
14083 OS << "\n";
14084
14085 if (ExitingBlocks.size() > 1)
14086 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14087 OS << " exit count for " << ExitingBlock->getName() << ": ";
14088 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14089 PrintSCEVWithTypeHint(OS, EC);
14090 if (isa<SCEVCouldNotCompute>(EC)) {
14091 // Retry with predicates.
14093 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14094 if (!isa<SCEVCouldNotCompute>(EC)) {
14095 OS << "\n predicated exit count for " << ExitingBlock->getName()
14096 << ": ";
14097 PrintSCEVWithTypeHint(OS, EC);
14098 OS << "\n Predicates:\n";
14099 for (const auto *P : Predicates)
14100 P->print(OS, 4);
14101 }
14102 }
14103 OS << "\n";
14104 }
14105
14106 OS << "Loop ";
14107 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14108 OS << ": ";
14109
14110 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14111 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14112 OS << "constant max backedge-taken count is ";
14113 PrintSCEVWithTypeHint(OS, ConstantBTC);
14115 OS << ", actual taken count either this or zero.";
14116 } else {
14117 OS << "Unpredictable constant max backedge-taken count. ";
14118 }
14119
14120 OS << "\n"
14121 "Loop ";
14122 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14123 OS << ": ";
14124
14125 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14126 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14127 OS << "symbolic max backedge-taken count is ";
14128 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14130 OS << ", actual taken count either this or zero.";
14131 } else {
14132 OS << "Unpredictable symbolic max backedge-taken count. ";
14133 }
14134 OS << "\n";
14135
14136 if (ExitingBlocks.size() > 1)
14137 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14138 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14139 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14141 PrintSCEVWithTypeHint(OS, ExitBTC);
14142 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14143 // Retry with predicates.
14145 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14147 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14148 OS << "\n predicated symbolic max exit count for "
14149 << ExitingBlock->getName() << ": ";
14150 PrintSCEVWithTypeHint(OS, ExitBTC);
14151 OS << "\n Predicates:\n";
14152 for (const auto *P : Predicates)
14153 P->print(OS, 4);
14154 }
14155 }
14156 OS << "\n";
14157 }
14158
14160 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14161 if (PBT != BTC) {
14162 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14163 OS << "Loop ";
14164 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14165 OS << ": ";
14166 if (!isa<SCEVCouldNotCompute>(PBT)) {
14167 OS << "Predicated backedge-taken count is ";
14168 PrintSCEVWithTypeHint(OS, PBT);
14169 } else
14170 OS << "Unpredictable predicated backedge-taken count.";
14171 OS << "\n";
14172 OS << " Predicates:\n";
14173 for (const auto *P : Preds)
14174 P->print(OS, 4);
14175 }
14176 Preds.clear();
14177
14178 auto *PredConstantMax =
14180 if (PredConstantMax != ConstantBTC) {
14181 assert(!Preds.empty() &&
14182 "different predicated constant max BTC but no predicates");
14183 OS << "Loop ";
14184 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14185 OS << ": ";
14186 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14187 OS << "Predicated constant max backedge-taken count is ";
14188 PrintSCEVWithTypeHint(OS, PredConstantMax);
14189 } else
14190 OS << "Unpredictable predicated constant max backedge-taken count.";
14191 OS << "\n";
14192 OS << " Predicates:\n";
14193 for (const auto *P : Preds)
14194 P->print(OS, 4);
14195 }
14196 Preds.clear();
14197
14198 auto *PredSymbolicMax =
14200 if (SymbolicBTC != PredSymbolicMax) {
14201 assert(!Preds.empty() &&
14202 "Different predicated symbolic max BTC, but no predicates");
14203 OS << "Loop ";
14204 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14205 OS << ": ";
14206 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14207 OS << "Predicated symbolic max backedge-taken count is ";
14208 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14209 } else
14210 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14211 OS << "\n";
14212 OS << " Predicates:\n";
14213 for (const auto *P : Preds)
14214 P->print(OS, 4);
14215 }
14216
14218 OS << "Loop ";
14219 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14220 OS << ": ";
14221 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14222 }
14223}
14224
14225namespace llvm {
14226// Note: these overloaded operators need to be in the llvm namespace for them
14227// to be resolved correctly. If we put them outside the llvm namespace, the
14228//
14229// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14230//
14231// code below "breaks" and start printing raw enum values as opposed to the
14232// string values.
// NOTE(review): the operator<< signature (14233-14234) and the three case
// labels (14236/14239/14242, presumably LoopVariant / LoopInvariant /
// LoopComputable) are missing from this extraction — confirm against source.
14235 switch (LD) {
14237 OS << "Variant";
14238 break;
14240 OS << "Invariant";
14241 break;
14243 OS << "Computable";
14244 break;
14245 }
14246 return OS;
14247}
14248
// NOTE(review): the operator<< signature (14249-14250) and the three case
// labels (14252/14255/14258, presumably DoesNotDominateBlock /
// DominatesBlock / ProperlyDominatesBlock) are missing from this extraction.
14251 switch (BD) {
14253 OS << "DoesNotDominate";
14254 break;
14256 OS << "Dominates";
14257 break;
14259 OS << "ProperlyDominates";
14260 break;
14261 }
14262 return OS;
14263}
14264} // namespace llvm
14265
14267 // ScalarEvolution's implementation of the print method is to print
14268 // out SCEV values of all instructions that are interesting. Doing
14269 // this potentially causes it to create new SCEV objects though,
14270 // which technically conflicts with the const qualifier. This isn't
14271 // observable from outside the class though, so casting away the
14272 // const isn't dangerous.
14273 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14274
14275 if (ClassifyExpressions) {
14276 OS << "Classifying expressions for: ";
14277 F.printAsOperand(OS, /*PrintType=*/false);
14278 OS << "\n";
14279 for (Instruction &I : instructions(F))
14280 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14281 OS << I << '\n';
14282 OS << " --> ";
14283 const SCEV *SV = SE.getSCEV(&I);
14284 SV->print(OS);
14285 if (!isa<SCEVCouldNotCompute>(SV)) {
14286 OS << " U: ";
14287 SE.getUnsignedRange(SV).print(OS);
14288 OS << " S: ";
14289 SE.getSignedRange(SV).print(OS);
14290 }
14291
14292 const Loop *L = LI.getLoopFor(I.getParent());
14293
14294 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14295 if (AtUse != SV) {
14296 OS << " --> ";
14297 AtUse->print(OS);
14298 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14299 OS << " U: ";
14300 SE.getUnsignedRange(AtUse).print(OS);
14301 OS << " S: ";
14302 SE.getSignedRange(AtUse).print(OS);
14303 }
14304 }
14305
14306 if (L) {
14307 OS << "\t\t" "Exits: ";
14308 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14309 if (!SE.isLoopInvariant(ExitValue, L)) {
14310 OS << "<<Unknown>>";
14311 } else {
14312 OS << *ExitValue;
14313 }
14314
14315 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14316 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14317 OS << LS;
14318 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14319 OS << ": " << SE.getLoopDisposition(SV, Iter);
14320 }
14321
14322 for (const auto *InnerL : depth_first(L)) {
14323 if (InnerL == L)
14324 continue;
14325 OS << LS;
14326 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14327 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14328 }
14329
14330 OS << " }";
14331 }
14332
14333 OS << "\n";
14334 }
14335 }
14336
14337 OS << "Determining loop execution counts for: ";
14338 F.printAsOperand(OS, /*PrintType=*/false);
14339 OS << "\n";
14340 for (Loop *I : LI)
14341 PrintLoopInfo(OS, &SE, I);
14342}
14343
// NOTE(review): signature line (14345) missing from this extraction; this is
// the memoizing wrapper around computeLoopDisposition.
// Fast path: return a previously cached disposition for (S, L).
14346 auto &Values = LoopDispositions[S];
14347 for (auto &V : Values) {
14348 if (V.getPointer() == L)
14349 return V.getInt();
14350 }
// Seed the cache with the conservative LoopVariant answer before recursing,
// so re-entrant queries on the same (S, L) terminate.
14351 Values.emplace_back(L, LoopVariant);
14352 LoopDisposition D = computeLoopDisposition(S, L);
// Re-fetch the bucket: the recursive computation may have grown the map and
// invalidated the earlier reference.
14353 auto &Values2 = LoopDispositions[S];
// Our placeholder was appended last, so search from the back.
14354 for (auto &V : llvm::reverse(Values2)) {
14355 if (V.getPointer() == L) {
14356 V.setInt(D);
14357 break;
14358 }
14359 }
14360 return D;
14361}
14362
// NOTE(review): the return-type line (14363, ScalarEvolution::LoopDisposition)
// is missing from this extraction.
14364ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14365 switch (S->getSCEVType()) {
14366 case scConstant:
14367 case scVScale:
14368 return LoopInvariant;
14369 case scAddRecExpr: {
14370 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14371
14372 // If L is the addrec's loop, it's computable.
14373 if (AR->getLoop() == L)
14374 return LoopComputable;
14375
14376 // Add recurrences are never invariant in the function-body (null loop).
14377 if (!L)
14378 return LoopVariant;
14379
14380 // Everything that is not defined at loop entry is variant.
14381 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14382 return LoopVariant;
14383 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14384 " dominate the contained loop's header?");
14385
14386 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14387 if (AR->getLoop()->contains(L))
14388 return LoopInvariant;
14389
14390 // This recurrence is variant w.r.t. L if any of its operands
14391 // are variant.
14392 for (SCEVUse Op : AR->operands())
14393 if (!isLoopInvariant(Op, L))
14394 return LoopVariant;
14395
14396 // Otherwise it's loop-invariant.
14397 return LoopInvariant;
14398 }
14399 case scTruncate:
14400 case scZeroExtend:
14401 case scSignExtend:
14402 case scPtrToAddr:
14403 case scPtrToInt:
14404 case scAddExpr:
14405 case scMulExpr:
14406 case scUDivExpr:
14407 case scUMaxExpr:
14408 case scSMaxExpr:
14409 case scUMinExpr:
14410 case scSMinExpr:
// Generic n-ary case: the result is the "weakest" disposition over all
// operands — any variant operand makes the whole expression variant.
14411 case scSequentialUMinExpr: {
14412 bool HasVarying = false;
14413 for (SCEVUse Op : S->operands()) {
// NOTE(review): line 14414 (the declaration of D, presumably
// getLoopDisposition(Op, L)) is missing from this extraction.
14415 if (D == LoopVariant)
14416 return LoopVariant;
14417 if (D == LoopComputable)
14418 HasVarying = true;
14419 }
14420 return HasVarying ? LoopComputable : LoopInvariant;
14421 }
14422 case scUnknown:
14423 // All non-instruction values are loop invariant. All instructions are loop
14424 // invariant if they are not contained in the specified loop.
14425 // Instructions are never considered invariant in the function body
14426 // (null loop) because they are defined within the "loop".
14427 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14428 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14429 return LoopInvariant;
14430 case scCouldNotCompute:
14431 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14432 }
14433 llvm_unreachable("Unknown SCEV kind!");
14434}
14435
// NOTE(review): signature line (14436) missing from this extraction;
// presumably bool ScalarEvolution::isLoopInvariant(const SCEV *, const Loop *).
14437 return getLoopDisposition(S, L) == LoopInvariant;
14438}
14439
// NOTE(review): signature line (14440) missing from this extraction; this is
// the predicate for the LoopComputable disposition — confirm name in source.
14441 return getLoopDisposition(S, L) == LoopComputable;
14442}
14443
// NOTE(review): signature lines (14444-14445) missing from this extraction;
// this is the memoizing wrapper around computeBlockDisposition, mirroring
// getLoopDisposition above.
14446 auto &Values = BlockDispositions[S];
14447 for (auto &V : Values) {
14448 if (V.getPointer() == BB)
14449 return V.getInt();
14450 }
// Seed with the conservative answer so re-entrant queries terminate.
14451 Values.emplace_back(BB, DoesNotDominateBlock);
14452 BlockDisposition D = computeBlockDisposition(S, BB);
// Re-fetch: the recursive computation may have invalidated Values.
14453 auto &Values2 = BlockDispositions[S];
14454 for (auto &V : llvm::reverse(Values2)) {
14455 if (V.getPointer() == BB) {
14456 V.setInt(D);
14457 break;
14458 }
14459 }
14460 return D;
14461}
14462
// NOTE(review): the return-type line (14463) is missing from this extraction.
14464ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14465 switch (S->getSCEVType()) {
14466 case scConstant:
14467 case scVScale:
// NOTE(review): line 14468 (presumably "return ProperlyDominatesBlock;") is
// missing from this extraction.
14469 case scAddRecExpr: {
14470 // This uses a "dominates" query instead of "properly dominates" query
14471 // to test for proper dominance too, because the instruction which
14472 // produces the addrec's value is a PHI, and a PHI effectively properly
14473 // dominates its entire containing block.
14474 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14475 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14476 return DoesNotDominateBlock;
14477
14478 // Fall through into SCEVNAryExpr handling.
14479 [[fallthrough]];
14480 }
14481 case scTruncate:
14482 case scZeroExtend:
14483 case scSignExtend:
14484 case scPtrToAddr:
14485 case scPtrToInt:
14486 case scAddExpr:
14487 case scMulExpr:
14488 case scUDivExpr:
14489 case scUMaxExpr:
14490 case scSMaxExpr:
14491 case scUMinExpr:
14492 case scSMinExpr:
// Generic n-ary case: the weakest disposition over all operands wins.
14493 case scSequentialUMinExpr: {
14494 bool Proper = true;
14495 for (const SCEV *NAryOp : S->operands()) {
// NOTE(review): line 14496 (the declaration of D, presumably
// getBlockDisposition(NAryOp, BB)) is missing from this extraction.
14497 if (D == DoesNotDominateBlock)
14498 return DoesNotDominateBlock;
14499 if (D == DominatesBlock)
14500 Proper = false;
14501 }
14502 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14503 }
14504 case scUnknown:
14505 if (Instruction *I =
14506 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
// An instruction only dominates (not properly) its own block.
14507 if (I->getParent() == BB)
14508 return DominatesBlock;
14509 if (DT.properlyDominates(I->getParent(), BB))
// NOTE(review): line 14510 (presumably "return ProperlyDominatesBlock;") is
// missing from this extraction.
14511 return DoesNotDominateBlock;
14512 }
// NOTE(review): line 14513 (presumably "return ProperlyDominatesBlock;" for
// non-instruction values) is missing from this extraction.
14514 case scCouldNotCompute:
14515 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14516 }
14517 llvm_unreachable("Unknown SCEV kind!");
14518}
14519
14520bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
14521 return getBlockDisposition(S, BB) >= DominatesBlock;
14522}
14523
14526}
14527
14528bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14529 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14530}
14531
14532void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14533 bool Predicated) {
14534 auto &BECounts =
14535 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14536 auto It = BECounts.find(L);
14537 if (It != BECounts.end()) {
14538 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14539 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14540 if (!isa<SCEVConstant>(S)) {
14541 auto UserIt = BECountUsers.find(S);
14542 assert(UserIt != BECountUsers.end());
14543 UserIt->second.erase({L, Predicated});
14544 }
14545 }
14546 }
14547 BECounts.erase(It);
14548 }
14549}
14550
/// Forget all memoized results for the given SCEVs and for everything that
/// (transitively) uses them, including any predicated rewrites keyed on them.
14551void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14552 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14553 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14554
// Transitively collect all users of the initial set via SCEVUsers.
14555 while (!Worklist.empty()) {
14556 const SCEV *Curr = Worklist.pop_back_val();
14557 auto Users = SCEVUsers.find(Curr);
14558 if (Users != SCEVUsers.end())
14559 for (const auto *User : Users->second)
14560 if (ToForget.insert(User).second)
14561 Worklist.push_back(User);
14562 }
14563
14564 for (const auto *S : ToForget)
14565 forgetMemoizedResultsImpl(S);
14566
// Drop predicated rewrites whose source expression is being forgotten.
// (erase(I++) keeps the iterator valid across the DenseMap erase.)
14567 for (auto I = PredicatedSCEVRewrites.begin();
14568 I != PredicatedSCEVRewrites.end();) {
14569 std::pair<const SCEV *, const Loop *> Entry = I->first;
14570 if (ToForget.count(Entry.first))
14571 PredicatedSCEVRewrites.erase(I++);
14572 else
14573 ++I;
14574 }
14575}
14576
/// Remove every cached fact about the single expression S from all of
/// ScalarEvolution's side tables (dispositions, ranges, value/expr maps,
/// values-at-scopes, backedge-count users, and the fold cache).
14577void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
14578 LoopDispositions.erase(S);
14579 BlockDispositions.erase(S);
14580 UnsignedRanges.erase(S);
14581 SignedRanges.erase(S);
14582 HasRecMap.erase(S);
14583 ConstantMultipleCache.erase(S);
14584
14585 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
14586 UnsignedWrapViaInductionTried.erase(AR);
14587 SignedWrapViaInductionTried.erase(AR);
14588 }
14589
// Remove S from the bidirectional expr<->value mapping.
14590 auto ExprIt = ExprValueMap.find(S);
14591 if (ExprIt != ExprValueMap.end()) {
14592 for (Value *V : ExprIt->second) {
14593 auto ValueIt = ValueExprMap.find_as(V);
14594 if (ValueIt != ValueExprMap.end())
14595 ValueExprMap.erase(ValueIt);
14596 }
14597 ExprValueMap.erase(ExprIt);
14598 }
14599
// Unlink S from the values-at-scope table and its reverse use-list
// (constants are never tracked in ValuesAtScopesUsers).
14600 auto ScopeIt = ValuesAtScopes.find(S);
14601 if (ScopeIt != ValuesAtScopes.end()) {
14602 for (const auto &Pair : ScopeIt->second)
14603 if (!isa_and_nonnull<SCEVConstant>(Pair.second))
14604 llvm::erase(ValuesAtScopesUsers[Pair.second],
14605 std::make_pair(Pair.first, S));
14606 ValuesAtScopes.erase(ScopeIt);
14607 }
14608
14609 auto ScopeUserIt = ValuesAtScopesUsers.find(S);
14610 if (ScopeUserIt != ValuesAtScopesUsers.end()) {
14611 for (const auto &Pair : ScopeUserIt->second)
14612 llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
14613 ValuesAtScopesUsers.erase(ScopeUserIt);
14614 }
14615
// Invalidate every backedge-taken count that referenced S.
14616 auto BEUsersIt = BECountUsers.find(S);
14617 if (BEUsersIt != BECountUsers.end()) {
14618 // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
14619 auto Copy = BEUsersIt->second;
14620 for (const auto &Pair : Copy)
14621 forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
14622 BECountUsers.erase(BEUsersIt);
14623 }
14624
// Drop every fold-cache entry whose result is S, then S's own user list.
14625 auto FoldUser = FoldCacheUser.find(S);
14626 if (FoldUser != FoldCacheUser.end())
14627 for (auto &KV : FoldUser->second)
14628 FoldCache.erase(KV);
14629 FoldCacheUser.erase(S);
14630}
14631
14632void
14633ScalarEvolution::getUsedLoops(const SCEV *S,
14634 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14635 struct FindUsedLoops {
14636 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14637 : LoopsUsed(LoopsUsed) {}
14638 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14639 bool follow(const SCEV *S) {
14640 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14641 LoopsUsed.insert(AR->getLoop());
14642 return true;
14643 }
14644
14645 bool isDone() const { return false; }
14646 };
14647
14648 FindUsedLoops F(LoopsUsed);
14649 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14650}
14651
/// Compute the set of basic blocks reachable from F's entry, pruning branch
/// successors that SCEV can prove are never taken.
14652void ScalarEvolution::getReachableBlocks(
// NOTE(review): lines 14653-14654 (the remaining parameters — presumably the
// Reachable set and Function &F — and the Worklist declaration) are missing
// from this extraction.
14655 Worklist.push_back(&F.getEntryBlock());
14656 while (!Worklist.empty()) {
14657 BasicBlock *BB = Worklist.pop_back_val();
14658 if (!Reachable.insert(BB).second)
14659 continue;
14660
// For conditional branches, try to follow only the feasible successor.
14661 Value *Cond;
14662 BasicBlock *TrueBB, *FalseBB;
14663 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14664 m_BasicBlock(FalseBB)))) {
// Constant condition: exactly one successor is reachable.
14665 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14666 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14667 continue;
14668 }
14669
// Comparison condition: use constant-range reasoning to prove the branch
// direction; if the predicate (or its inverse) is known, take one side only.
14670 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14671 const SCEV *L = getSCEV(Cmp->getOperand(0));
14672 const SCEV *R = getSCEV(Cmp->getOperand(1));
14673 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14674 Worklist.push_back(TrueBB);
14675 continue;
14676 }
14677 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14678 R)) {
14679 Worklist.push_back(FalseBB);
14680 continue;
14681 }
14682 }
14683 }
14684
// Otherwise conservatively treat all successors as reachable.
14685 append_range(Worklist, successors(BB));
14686 }
14687}
14688
14690 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14691 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14692
14693 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14694
14695 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14696 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14697 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14698
14699 const SCEV *visitConstant(const SCEVConstant *Constant) {
14700 return SE.getConstant(Constant->getAPInt());
14701 }
14702
14703 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14704 return SE.getUnknown(Expr->getValue());
14705 }
14706
14707 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14708 return SE.getCouldNotCompute();
14709 }
14710 };
14711
14712 SCEVMapper SCM(SE2);
14713 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14714 SE2.getReachableBlocks(ReachableBlocks, F);
14715
14716 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14717 if (containsUndefs(Old) || containsUndefs(New)) {
14718 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14719 // not propagate undef aggressively). This means we can (and do) fail
14720 // verification in cases where a transform makes a value go from "undef"
14721 // to "undef+1" (say). The transform is fine, since in both cases the
14722 // result is "undef", but SCEV thinks the value increased by 1.
14723 return nullptr;
14724 }
14725
14726 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14727 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14728 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14729 return nullptr;
14730
14731 return Delta;
14732 };
14733
14734 while (!LoopStack.empty()) {
14735 auto *L = LoopStack.pop_back_val();
14736 llvm::append_range(LoopStack, *L);
14737
14738 // Only verify BECounts in reachable loops. For an unreachable loop,
14739 // any BECount is legal.
14740 if (!ReachableBlocks.contains(L->getHeader()))
14741 continue;
14742
14743 // Only verify cached BECounts. Computing new BECounts may change the
14744 // results of subsequent SCEV uses.
14745 auto It = BackedgeTakenCounts.find(L);
14746 if (It == BackedgeTakenCounts.end())
14747 continue;
14748
14749 auto *CurBECount =
14750 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14751 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14752
14753 if (CurBECount == SE2.getCouldNotCompute() ||
14754 NewBECount == SE2.getCouldNotCompute()) {
14755 // NB! This situation is legal, but is very suspicious -- whatever pass
14756 // change the loop to make a trip count go from could not compute to
14757 // computable or vice-versa *should have* invalidated SCEV. However, we
14758 // choose not to assert here (for now) since we don't want false
14759 // positives.
14760 continue;
14761 }
14762
14763 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14764 SE.getTypeSizeInBits(NewBECount->getType()))
14765 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14766 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14767 SE.getTypeSizeInBits(NewBECount->getType()))
14768 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14769
14770 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14771 if (Delta && !Delta->isZero()) {
14772 dbgs() << "Trip Count for " << *L << " Changed!\n";
14773 dbgs() << "Old: " << *CurBECount << "\n";
14774 dbgs() << "New: " << *NewBECount << "\n";
14775 dbgs() << "Delta: " << *Delta << "\n";
14776 std::abort();
14777 }
14778 }
14779
14780 // Collect all valid loops currently in LoopInfo.
14781 SmallPtrSet<Loop *, 32> ValidLoops;
14782 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14783 while (!Worklist.empty()) {
14784 Loop *L = Worklist.pop_back_val();
14785 if (ValidLoops.insert(L).second)
14786 Worklist.append(L->begin(), L->end());
14787 }
14788 for (const auto &KV : ValueExprMap) {
14789#ifndef NDEBUG
14790 // Check for SCEV expressions referencing invalid/deleted loops.
14791 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14792 assert(ValidLoops.contains(AR->getLoop()) &&
14793 "AddRec references invalid loop");
14794 }
14795#endif
14796
14797 // Check that the value is also part of the reverse map.
14798 auto It = ExprValueMap.find(KV.second);
14799 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14800 dbgs() << "Value " << *KV.first
14801 << " is in ValueExprMap but not in ExprValueMap\n";
14802 std::abort();
14803 }
14804
14805 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14806 if (!ReachableBlocks.contains(I->getParent()))
14807 continue;
14808 const SCEV *OldSCEV = SCM.visit(KV.second);
14809 const SCEV *NewSCEV = SE2.getSCEV(I);
14810 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14811 if (Delta && !Delta->isZero()) {
14812 dbgs() << "SCEV for value " << *I << " changed!\n"
14813 << "Old: " << *OldSCEV << "\n"
14814 << "New: " << *NewSCEV << "\n"
14815 << "Delta: " << *Delta << "\n";
14816 std::abort();
14817 }
14818 }
14819 }
14820
14821 for (const auto &KV : ExprValueMap) {
14822 for (Value *V : KV.second) {
14823 const SCEV *S = ValueExprMap.lookup(V);
14824 if (!S) {
14825 dbgs() << "Value " << *V
14826 << " is in ExprValueMap but not in ValueExprMap\n";
14827 std::abort();
14828 }
14829 if (S != KV.first) {
14830 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14831 << *KV.first << "\n";
14832 std::abort();
14833 }
14834 }
14835 }
14836
14837 // Verify integrity of SCEV users.
14838 for (const auto &S : UniqueSCEVs) {
14839 for (SCEVUse Op : S.operands()) {
14840 // We do not store dependencies of constants.
14841 if (isa<SCEVConstant>(Op))
14842 continue;
14843 auto It = SCEVUsers.find(Op);
14844 if (It != SCEVUsers.end() && It->second.count(&S))
14845 continue;
14846 dbgs() << "Use of operand " << *Op << " by user " << S
14847 << " is not being tracked!\n";
14848 std::abort();
14849 }
14850 }
14851
14852 // Verify integrity of ValuesAtScopes users.
14853 for (const auto &ValueAndVec : ValuesAtScopes) {
14854 const SCEV *Value = ValueAndVec.first;
14855 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14856 const Loop *L = LoopAndValueAtScope.first;
14857 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14858 if (!isa<SCEVConstant>(ValueAtScope)) {
14859 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14860 if (It != ValuesAtScopesUsers.end() &&
14861 is_contained(It->second, std::make_pair(L, Value)))
14862 continue;
14863 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14864 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14865 std::abort();
14866 }
14867 }
14868 }
14869
14870 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14871 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14872 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14873 const Loop *L = LoopAndValue.first;
14874 const SCEV *Value = LoopAndValue.second;
14876 auto It = ValuesAtScopes.find(Value);
14877 if (It != ValuesAtScopes.end() &&
14878 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14879 continue;
14880 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14881 << *ValueAtScope << " missing in ValuesAtScopes\n";
14882 std::abort();
14883 }
14884 }
14885
14886 // Verify integrity of BECountUsers.
14887 auto VerifyBECountUsers = [&](bool Predicated) {
14888 auto &BECounts =
14889 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14890 for (const auto &LoopAndBEInfo : BECounts) {
14891 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14892 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14893 if (!isa<SCEVConstant>(S)) {
14894 auto UserIt = BECountUsers.find(S);
14895 if (UserIt != BECountUsers.end() &&
14896 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14897 continue;
14898 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14899 << " missing from BECountUsers\n";
14900 std::abort();
14901 }
14902 }
14903 }
14904 }
14905 };
14906 VerifyBECountUsers(/* Predicated */ false);
14907 VerifyBECountUsers(/* Predicated */ true);
14908
14909 // Verify integrity of loop disposition cache.
14910 for (auto &[S, Values] : LoopDispositions) {
14911 for (auto [Loop, CachedDisposition] : Values) {
14912 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14913 if (CachedDisposition != RecomputedDisposition) {
14914 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14915 << " is incorrect: cached " << CachedDisposition << ", actual "
14916 << RecomputedDisposition << "\n";
14917 std::abort();
14918 }
14919 }
14920 }
14921
14922 // Verify integrity of the block disposition cache.
14923 for (auto &[S, Values] : BlockDispositions) {
14924 for (auto [BB, CachedDisposition] : Values) {
14925 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14926 if (CachedDisposition != RecomputedDisposition) {
14927 dbgs() << "Cached disposition of " << *S << " for block %"
14928 << BB->getName() << " is incorrect: cached " << CachedDisposition
14929 << ", actual " << RecomputedDisposition << "\n";
14930 std::abort();
14931 }
14932 }
14933 }
14934
14935 // Verify FoldCache/FoldCacheUser caches.
14936 for (auto [FoldID, Expr] : FoldCache) {
14937 auto I = FoldCacheUser.find(Expr);
14938 if (I == FoldCacheUser.end()) {
14939 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14940 << "!\n";
14941 std::abort();
14942 }
14943 if (!is_contained(I->second, FoldID)) {
14944 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14945 std::abort();
14946 }
14947 }
14948 for (auto [Expr, IDs] : FoldCacheUser) {
14949 for (auto &FoldID : IDs) {
14950 const SCEV *S = FoldCache.lookup(FoldID);
14951 if (!S) {
14952 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14953 << "!\n";
14954 std::abort();
14955 }
14956 if (S != Expr) {
14957 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14958 << " != " << *Expr << "!\n";
14959 std::abort();
14960 }
14961 }
14962 }
14963
14964 // Verify that ConstantMultipleCache computations are correct. We check that
14965 // cached multiples and recomputed multiples are multiples of each other to
14966 // verify correctness. It is possible that a recomputed multiple is different
14967 // from the cached multiple due to strengthened no wrap flags or changes in
14968 // KnownBits computations.
14969 for (auto [S, Multiple] : ConstantMultipleCache) {
14970 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14971 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14972 Multiple.urem(RecomputedMultiple) != 0 &&
14973 RecomputedMultiple.urem(Multiple) != 0)) {
14974 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14975 << *S << " : Computed " << RecomputedMultiple
14976 << " but cache contains " << Multiple << "!\n";
14977 std::abort();
14978 }
14979 }
14980}
14981
14983 Function &F, const PreservedAnalyses &PA,
14984 FunctionAnalysisManager::Invalidator &Inv) {
14985 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14986 // of its dependencies is invalidated.
14987 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14988 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14989 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14990 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14991 Inv.invalidate<LoopAnalysis>(F, PA);
14992}
14993
14994AnalysisKey ScalarEvolutionAnalysis::Key;
14995
14998 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14999 auto &AC = AM.getResult<AssumptionAnalysis>(F);
15000 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
15001 auto &LI = AM.getResult<LoopAnalysis>(F);
15002 return ScalarEvolution(F, TLI, AC, DT, LI);
15003}
15004
15010
15013 // For compatibility with opt's -analyze feature under legacy pass manager
15014 // which was not ported to NPM. This keeps tests using
15015 // update_analyze_test_checks.py working.
15016 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
15017 << F.getName() << "':\n";
15019 return PreservedAnalyses::all();
15020}
15021
15023 "Scalar Evolution Analysis", false, true)
15029 "Scalar Evolution Analysis", false, true)
15030
15032
15034
15036 SE.reset(new ScalarEvolution(
15038 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15040 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15041 return false;
15042}
15043
15045
15047 SE->print(OS);
15048}
15049
15051 if (!VerifySCEV)
15052 return;
15053
15054 SE->verify();
15055}
15056
15064
15066 const SCEV *RHS) {
15067 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15068}
15069
15070const SCEVPredicate *
15072 const SCEV *LHS, const SCEV *RHS) {
15074 assert(LHS->getType() == RHS->getType() &&
15075 "Type mismatch between LHS and RHS");
15076 // Unique this node based on the arguments
15077 ID.AddInteger(SCEVPredicate::P_Compare);
15078 ID.AddInteger(Pred);
15079 ID.AddPointer(LHS);
15080 ID.AddPointer(RHS);
15081 void *IP = nullptr;
15082 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15083 return S;
15084 SCEVComparePredicate *Eq = new (SCEVAllocator)
15085 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15086 UniquePreds.InsertNode(Eq, IP);
15087 return Eq;
15088}
15089
15091 const SCEVAddRecExpr *AR,
15094 // Unique this node based on the arguments
15095 ID.AddInteger(SCEVPredicate::P_Wrap);
15096 ID.AddPointer(AR);
15097 ID.AddInteger(AddedFlags);
15098 void *IP = nullptr;
15099 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15100 return S;
15101 auto *OF = new (SCEVAllocator)
15102 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15103 UniquePreds.InsertNode(OF, IP);
15104 return OF;
15105}
15106
15107namespace {
15108
15109class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15110public:
15111
15112 /// Rewrites \p S in the context of a loop L and the SCEV predication
15113 /// infrastructure.
15114 ///
15115 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15116 /// equivalences present in \p Pred.
15117 ///
15118 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15119 /// \p NewPreds such that the result will be an AddRecExpr.
15120 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15122 const SCEVPredicate *Pred) {
15123 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15124 return Rewriter.visit(S);
15125 }
15126
15127 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15128 if (Pred) {
15129 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15130 for (const auto *Pred : U->getPredicates())
15131 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15132 if (IPred->getLHS() == Expr &&
15133 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15134 return IPred->getRHS();
15135 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15136 if (IPred->getLHS() == Expr &&
15137 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15138 return IPred->getRHS();
15139 }
15140 }
15141 return convertToAddRecWithPreds(Expr);
15142 }
15143
15144 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15145 const SCEV *Operand = visit(Expr->getOperand());
15146 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15147 if (AR && AR->getLoop() == L && AR->isAffine()) {
15148 // This couldn't be folded because the operand didn't have the nuw
15149 // flag. Add the nusw flag as an assumption that we could make.
15150 const SCEV *Step = AR->getStepRecurrence(SE);
15151 Type *Ty = Expr->getType();
15152 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15153 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15154 SE.getSignExtendExpr(Step, Ty), L,
15155 AR->getNoWrapFlags());
15156 }
15157 return SE.getZeroExtendExpr(Operand, Expr->getType());
15158 }
15159
15160 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15161 const SCEV *Operand = visit(Expr->getOperand());
15162 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15163 if (AR && AR->getLoop() == L && AR->isAffine()) {
15164 // This couldn't be folded because the operand didn't have the nsw
15165 // flag. Add the nssw flag as an assumption that we could make.
15166 const SCEV *Step = AR->getStepRecurrence(SE);
15167 Type *Ty = Expr->getType();
15168 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15169 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15170 SE.getSignExtendExpr(Step, Ty), L,
15171 AR->getNoWrapFlags());
15172 }
15173 return SE.getSignExtendExpr(Operand, Expr->getType());
15174 }
15175
15176private:
15177 explicit SCEVPredicateRewriter(
15178 const Loop *L, ScalarEvolution &SE,
15179 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15180 const SCEVPredicate *Pred)
15181 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15182
15183 bool addOverflowAssumption(const SCEVPredicate *P) {
15184 if (!NewPreds) {
15185 // Check if we've already made this assumption.
15186 return Pred && Pred->implies(P, SE);
15187 }
15188 NewPreds->push_back(P);
15189 return true;
15190 }
15191
15192 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15194 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15195 return addOverflowAssumption(A);
15196 }
15197
15198 // If \p Expr represents a PHINode, we try to see if it can be represented
15199 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15200 // to add this predicate as a runtime overflow check, we return the AddRec.
15201 // If \p Expr does not meet these conditions (is not a PHI node, or we
15202 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15203 // return \p Expr.
15204 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15205 if (!isa<PHINode>(Expr->getValue()))
15206 return Expr;
15207 std::optional<
15208 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15209 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15210 if (!PredicatedRewrite)
15211 return Expr;
15212 for (const auto *P : PredicatedRewrite->second){
15213 // Wrap predicates from outer loops are not supported.
15214 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15215 if (L != WP->getExpr()->getLoop())
15216 return Expr;
15217 }
15218 if (!addOverflowAssumption(P))
15219 return Expr;
15220 }
15221 return PredicatedRewrite->first;
15222 }
15223
15224 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15225 const SCEVPredicate *Pred;
15226 const Loop *L;
15227};
15228
15229} // end anonymous namespace
15230
15231const SCEV *
15233 const SCEVPredicate &Preds) {
15234 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15235}
15236
15238 const SCEV *S, const Loop *L,
15241 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15242 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15243
15244 if (!AddRec)
15245 return nullptr;
15246
15247 // Check if any of the transformed predicates is known to be false. In that
15248 // case, it doesn't make sense to convert to a predicated AddRec, as the
15249 // versioned loop will never execute.
15250 for (const SCEVPredicate *Pred : TransformPreds) {
15251 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15252 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15253 continue;
15254
15255 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15256 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15257 if (isa<SCEVCouldNotCompute>(ExitCount))
15258 continue;
15259
15260 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15261 if (!Step->isOne())
15262 continue;
15263
15264 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15265 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15266 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15267 return nullptr;
15268 }
15269
15270 // Since the transformation was successful, we can now transfer the SCEV
15271 // predicates.
15272 Preds.append(TransformPreds.begin(), TransformPreds.end());
15273
15274 return AddRec;
15275}
15276
15277/// SCEV predicates
15281
15283 const ICmpInst::Predicate Pred,
15284 const SCEV *LHS, const SCEV *RHS)
15285 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15286 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15287 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15288}
15289
15291 ScalarEvolution &SE) const {
15292 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15293
15294 if (!Op)
15295 return false;
15296
15297 if (Pred != ICmpInst::ICMP_EQ)
15298 return false;
15299
15300 return Op->LHS == LHS && Op->RHS == RHS;
15301}
15302
15303bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15304
15306 if (Pred == ICmpInst::ICMP_EQ)
15307 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15308 else
15309 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15310 << *RHS << "\n";
15311
15312}
15313
15315 const SCEVAddRecExpr *AR,
15316 IncrementWrapFlags Flags)
15317 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15318
15319const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15320
15322 ScalarEvolution &SE) const {
15323 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15324 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15325 return false;
15326
15327 if (Op->AR == AR)
15328 return true;
15329
15330 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15332 return false;
15333
15334 const SCEV *Start = AR->getStart();
15335 const SCEV *OpStart = Op->AR->getStart();
15336 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15337 return false;
15338
15339 // Reject pointers to different address spaces.
15340 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15341 return false;
15342
15343 const SCEV *Step = AR->getStepRecurrence(SE);
15344 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15345 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15346 return false;
15347
15348 // If both steps are positive, this implies N, if N's start and step are
15349 // ULE/SLE (for NSUW/NSSW) than this'.
15350 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15351 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15352 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15353
15354 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15355 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15356 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15357 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15358 : SE.getNoopOrSignExtend(Start, WiderTy);
15360 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15361 SE.isKnownPredicate(Pred, OpStart, Start);
15362}
15363
15365 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15366 IncrementWrapFlags IFlags = Flags;
15367
15368 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15369 IFlags = clearFlags(IFlags, IncrementNSSW);
15370
15371 return IFlags == IncrementAnyWrap;
15372}
15373
15374void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15375 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15377 OS << "<nusw>";
15379 OS << "<nssw>";
15380 OS << "\n";
15381}
15382
15385 ScalarEvolution &SE) {
15386 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15387 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15388
15389 // We can safely transfer the NSW flag as NSSW.
15390 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15391 ImpliedFlags = IncrementNSSW;
15392
15393 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15394 // If the increment is positive, the SCEV NUW flag will also imply the
15395 // WrapPredicate NUSW flag.
15396 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15397 if (Step->getValue()->getValue().isNonNegative())
15398 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15399 }
15400
15401 return ImpliedFlags;
15402}
15403
15404/// Union predicates don't get cached so create a dummy set ID for it.
15406 ScalarEvolution &SE)
15407 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15408 for (const auto *P : Preds)
15409 add(P, SE);
15410}
15411
15413 return all_of(Preds,
15414 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15415}
15416
15418 ScalarEvolution &SE) const {
15419 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15420 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15421 return this->implies(I, SE);
15422 });
15423
15424 return any_of(Preds,
15425 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15426}
15427
15429 for (const auto *Pred : Preds)
15430 Pred->print(OS, Depth);
15431}
15432
15433void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15434 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15435 for (const auto *Pred : Set->Preds)
15436 add(Pred, SE);
15437 return;
15438 }
15439
15440 // Implication checks are quadratic in the number of predicates. Stop doing
15441 // them if there are many predicates, as they should be too expensive to use
15442 // anyway at that point.
15443 bool CheckImplies = Preds.size() < 16;
15444
15445 // Only add predicate if it is not already implied by this union predicate.
15446 if (CheckImplies && implies(N, SE))
15447 return;
15448
15449 // Build a new vector containing the current predicates, except the ones that
15450 // are implied by the new predicate N.
15452 for (auto *P : Preds) {
15453 if (CheckImplies && N->implies(P, SE))
15454 continue;
15455 PrunedPreds.push_back(P);
15456 }
15457 Preds = std::move(PrunedPreds);
15458 Preds.push_back(N);
15459}
15460
15462 Loop &L)
15463 : SE(SE), L(L) {
15465 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15466}
15467
15470 for (const auto *Op : Ops)
15471 // We do not expect that forgetting cached data for SCEVConstants will ever
15472 // open any prospects for sharpening or introduce any correctness issues,
15473 // so we don't bother storing their dependencies.
15474 if (!isa<SCEVConstant>(Op))
15475 SCEVUsers[Op].insert(User);
15476}
15477
15479 for (const SCEV *Op : Ops)
15480 // We do not expect that forgetting cached data for SCEVConstants will ever
15481 // open any prospects for sharpening or introduce any correctness issues,
15482 // so we don't bother storing their dependencies.
15483 if (!isa<SCEVConstant>(Op))
15484 SCEVUsers[Op].insert(User);
15485}
15486
15488 const SCEV *Expr = SE.getSCEV(V);
15489 return getPredicatedSCEV(Expr);
15490}
15491
15493 RewriteEntry &Entry = RewriteMap[Expr];
15494
15495 // If we already have an entry and the version matches, return it.
15496 if (Entry.second && Generation == Entry.first)
15497 return Entry.second;
15498
15499 // We found an entry but it's stale. Rewrite the stale entry
15500 // according to the current predicate.
15501 if (Entry.second)
15502 Expr = Entry.second;
15503
15504 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15505 Entry = {Generation, NewSCEV};
15506
15507 return NewSCEV;
15508}
15509
15511 if (!BackedgeCount) {
15513 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15514 for (const auto *P : Preds)
15515 addPredicate(*P);
15516 }
15517 return BackedgeCount;
15518}
15519
15521 if (!SymbolicMaxBackedgeCount) {
15523 SymbolicMaxBackedgeCount =
15524 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15525 for (const auto *P : Preds)
15526 addPredicate(*P);
15527 }
15528 return SymbolicMaxBackedgeCount;
15529}
15530
15532 if (!SmallConstantMaxTripCount) {
15534 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15535 for (const auto *P : Preds)
15536 addPredicate(*P);
15537 }
15538 return *SmallConstantMaxTripCount;
15539}
15540
15542 if (Preds->implies(&Pred, SE))
15543 return;
15544
15545 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15546 NewPreds.push_back(&Pred);
15547 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15548 updateGeneration();
15549}
15550
15552 return *Preds;
15553}
15554
15555void PredicatedScalarEvolution::updateGeneration() {
15556 // If the generation number wrapped recompute everything.
15557 if (++Generation == 0) {
15558 for (auto &II : RewriteMap) {
15559 const SCEV *Rewritten = II.second.second;
15560 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15561 }
15562 }
15563}
15564
15567 const SCEV *Expr = getSCEV(V);
15568 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15569
15570 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15571
15572 // Clear the statically implied flags.
15573 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15574 addPredicate(*SE.getWrapPredicate(AR, Flags));
15575
15576 auto II = FlagsMap.insert({V, Flags});
15577 if (!II.second)
15578 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15579}
15580
15583 const SCEV *Expr = getSCEV(V);
15584 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15585
15587 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15588
15589 auto II = FlagsMap.find(V);
15590
15591 if (II != FlagsMap.end())
15592 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15593
15595}
15596
15598 const SCEV *Expr = this->getSCEV(V);
15600 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15601
15602 if (!New)
15603 return nullptr;
15604
15605 for (const auto *P : NewPreds)
15606 addPredicate(*P);
15607
15608 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15609 return New;
15610}
15611
15614 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15615 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15616 SE)),
15617 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15618 for (auto I : Init.FlagsMap)
15619 FlagsMap.insert(I);
15620}
15621
15623 // For each block.
15624 for (auto *BB : L.getBlocks())
15625 for (auto &I : *BB) {
15626 if (!SE.isSCEVable(I.getType()))
15627 continue;
15628
15629 auto *Expr = SE.getSCEV(&I);
15630 auto II = RewriteMap.find(Expr);
15631
15632 if (II == RewriteMap.end())
15633 continue;
15634
15635 // Don't print things that are not interesting.
15636 if (II->second.second == Expr)
15637 continue;
15638
15639 OS.indent(Depth) << "[PSE]" << I << ":\n";
15640 OS.indent(Depth + 2) << *Expr << "\n";
15641 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15642 }
15643}
15644
15647 BasicBlock *Header = L->getHeader();
15648 BasicBlock *Pred = L->getLoopPredecessor();
15649 LoopGuards Guards(SE);
15650 if (!Pred)
15651 return Guards;
15653 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15654 return Guards;
15655}
15656
15657void ScalarEvolution::LoopGuards::collectFromPHI(
15661 unsigned Depth) {
15662 if (!SE.isSCEVable(Phi.getType()))
15663 return;
15664
15665 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15666 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15667 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15668 if (!VisitedBlocks.insert(InBlock).second)
15669 return {nullptr, scCouldNotCompute};
15670
15671 // Avoid analyzing unreachable blocks so that we don't get trapped
15672 // traversing cycles with ill-formed dominance or infinite cycles
15673 if (!SE.DT.isReachableFromEntry(InBlock))
15674 return {nullptr, scCouldNotCompute};
15675
15676 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15677 if (Inserted)
15678 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15679 Depth + 1);
15680 auto &RewriteMap = G->second.RewriteMap;
15681 if (RewriteMap.empty())
15682 return {nullptr, scCouldNotCompute};
15683 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15684 if (S == RewriteMap.end())
15685 return {nullptr, scCouldNotCompute};
15686 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15687 if (!SM)
15688 return {nullptr, scCouldNotCompute};
15689 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15690 return {C0, SM->getSCEVType()};
15691 return {nullptr, scCouldNotCompute};
15692 };
15693 auto MergeMinMaxConst = [](MinMaxPattern P1,
15694 MinMaxPattern P2) -> MinMaxPattern {
15695 auto [C1, T1] = P1;
15696 auto [C2, T2] = P2;
15697 if (!C1 || !C2 || T1 != T2)
15698 return {nullptr, scCouldNotCompute};
15699 switch (T1) {
15700 case scUMaxExpr:
15701 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15702 case scSMaxExpr:
15703 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15704 case scUMinExpr:
15705 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15706 case scSMinExpr:
15707 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15708 default:
15709 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15710 }
15711 };
15712 auto P = GetMinMaxConst(0);
15713 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15714 if (!P.first)
15715 break;
15716 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15717 }
15718 if (P.first) {
15719 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15720 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15721 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15722 Guards.RewriteMap.insert({LHS, RHS});
15723 }
15724}
15725
15726// Return a new SCEV that modifies \p Expr to the closest number divides by
15727// \p Divisor and less or equal than Expr. For now, only handle constant
15728// Expr.
15730 const APInt &DivisorVal,
15731 ScalarEvolution &SE) {
15732 const APInt *ExprVal;
15733 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15734 DivisorVal.isNonPositive())
15735 return Expr;
15736 APInt Rem = ExprVal->urem(DivisorVal);
15737 // return the SCEV: Expr - Expr % Divisor
15738 return SE.getConstant(*ExprVal - Rem);
15739}
15740
15741// Return a new SCEV that modifies \p Expr to the closest number divides by
15742// \p Divisor and greater or equal than Expr. For now, only handle constant
15743// Expr.
15744static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15745 const APInt &DivisorVal,
15746 ScalarEvolution &SE) {
15747 const APInt *ExprVal;
15748 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15749 DivisorVal.isNonPositive())
15750 return Expr;
15751 APInt Rem = ExprVal->urem(DivisorVal);
15752 if (Rem.isZero())
15753 return Expr;
15754 // return the SCEV: Expr + Divisor - Expr % Divisor
15755 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15756}
15757
15759 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15762 // If we have LHS == 0, check if LHS is computing a property of some unknown
15763 // SCEV %v which we can rewrite %v to express explicitly.
15765 return false;
15766 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15767 // explicitly express that.
15768 const SCEVUnknown *URemLHS = nullptr;
15769 const SCEV *URemRHS = nullptr;
15770 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15771 return false;
15772
15773 const SCEV *Multiple =
15774 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15775 DivInfo[URemLHS] = Multiple;
15776 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15777 Multiples[URemLHS] = C->getAPInt();
15778 return true;
15779}
15780
15781// Check if the condition is a divisibility guard (A % B == 0).
15782static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15783 ScalarEvolution &SE) {
15784 const SCEV *X, *Y;
15785 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15786}
15787
15788// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15789// recursively. This is done by aligning up/down the constant value to the
15790// Divisor.
15791static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15792 APInt Divisor,
15793 ScalarEvolution &SE) {
15794 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15795 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15796 // the non-constant operand and in \p LHS the constant operand.
15797 auto IsMinMaxSCEVWithNonNegativeConstant =
15798 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15799 const SCEV *&RHS) {
15800 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15801 if (MinMax->getNumOperands() != 2)
15802 return false;
15803 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15804 if (C->getAPInt().isNegative())
15805 return false;
15806 SCTy = MinMax->getSCEVType();
15807 LHS = MinMax->getOperand(0);
15808 RHS = MinMax->getOperand(1);
15809 return true;
15810 }
15811 }
15812 return false;
15813 };
15814
15815 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15816 SCEVTypes SCTy;
15817 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15818 MinMaxRHS))
15819 return MinMaxExpr;
15820 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15821 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15822 auto *DivisibleExpr =
15823 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15824 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15826 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15827 return SE.getMinMaxExpr(SCTy, Ops);
15828}
15829
15830void ScalarEvolution::LoopGuards::collectFromBlock(
15831 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15832 const BasicBlock *Block, const BasicBlock *Pred,
15833 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15834
15836
15837 SmallVector<SCEVUse> ExprsToRewrite;
15838 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15839 const SCEV *RHS,
15840 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15841 const LoopGuards &DivGuards) {
15842 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15843 // replacement SCEV which isn't directly implied by the structure of that
15844 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15845 // legal. See the scoping rules for flags in the header to understand why.
15846
15847 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15848 // create this form when combining two checks of the form (X u< C2 + C1) and
15849 // (X >=u C1).
15850 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15851 &ExprsToRewrite]() {
15852 const SCEVConstant *C1;
15853 const SCEVUnknown *LHSUnknown;
15854 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15855 if (!match(LHS,
15856 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15857 !C2)
15858 return false;
15859
15860 auto ExactRegion =
15861 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15862 .sub(C1->getAPInt());
15863
15864 // Bail out, unless we have a non-wrapping, monotonic range.
15865 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15866 return false;
15867 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15868 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15869 I->second = SE.getUMaxExpr(
15870 SE.getConstant(ExactRegion.getUnsignedMin()),
15871 SE.getUMinExpr(RewrittenLHS,
15872 SE.getConstant(ExactRegion.getUnsignedMax())));
15873 ExprsToRewrite.push_back(LHSUnknown);
15874 return true;
15875 };
15876 if (MatchRangeCheckIdiom())
15877 return;
15878
15879 // Do not apply information for constants or if RHS contains an AddRec.
15881 return;
15882
15883 // If RHS is SCEVUnknown, make sure the information is applied to it.
15885 std::swap(LHS, RHS);
15887 }
15888
15889 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15890 // and \p FromRewritten are the same (i.e. there has been no rewrite
15891 // registered for \p From), then puts this value in the list of rewritten
15892 // expressions.
15893 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15894 const SCEV *To) {
15895 if (From == FromRewritten)
15896 ExprsToRewrite.push_back(From);
15897 RewriteMap[From] = To;
15898 };
15899
15900 // Checks whether \p S has already been rewritten. In that case returns the
15901 // existing rewrite because we want to chain further rewrites onto the
15902 // already rewritten value. Otherwise returns \p S.
15903 auto GetMaybeRewritten = [&](const SCEV *S) {
15904 return RewriteMap.lookup_or(S, S);
15905 };
15906
15907 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15908 // Apply divisibility information when computing the constant multiple.
15909 const APInt &DividesBy =
15910 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15911
15912 // Collect rewrites for LHS and its transitive operands based on the
15913 // condition.
15914 // For min/max expressions, also apply the guard to its operands:
15915 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15916 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15917 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15918 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15919
15920 // We cannot express strict predicates in SCEV, so instead we replace them
15921 // with non-strict ones against plus or minus one of RHS depending on the
15922 // predicate.
15923 const SCEV *One = SE.getOne(RHS->getType());
15924 switch (Predicate) {
15925 case CmpInst::ICMP_ULT:
15926 if (RHS->getType()->isPointerTy())
15927 return;
15928 RHS = SE.getUMaxExpr(RHS, One);
15929 [[fallthrough]];
15930 case CmpInst::ICMP_SLT: {
15931 RHS = SE.getMinusSCEV(RHS, One);
15932 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15933 break;
15934 }
15935 case CmpInst::ICMP_UGT:
15936 case CmpInst::ICMP_SGT:
15937 RHS = SE.getAddExpr(RHS, One);
15938 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15939 break;
15940 case CmpInst::ICMP_ULE:
15941 case CmpInst::ICMP_SLE:
15942 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15943 break;
15944 case CmpInst::ICMP_UGE:
15945 case CmpInst::ICMP_SGE:
15946 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15947 break;
15948 default:
15949 break;
15950 }
15951
15952 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15953 SmallPtrSet<const SCEV *, 16> Visited;
15954
15955 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15956 append_range(Worklist, S->operands());
15957 };
15958
15959 while (!Worklist.empty()) {
15960 const SCEV *From = Worklist.pop_back_val();
15961 if (isa<SCEVConstant>(From))
15962 continue;
15963 if (!Visited.insert(From).second)
15964 continue;
15965 const SCEV *FromRewritten = GetMaybeRewritten(From);
15966 const SCEV *To = nullptr;
15967
15968 switch (Predicate) {
15969 case CmpInst::ICMP_ULT:
15970 case CmpInst::ICMP_ULE:
15971 To = SE.getUMinExpr(FromRewritten, RHS);
15972 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15973 EnqueueOperands(UMax);
15974 break;
15975 case CmpInst::ICMP_SLT:
15976 case CmpInst::ICMP_SLE:
15977 To = SE.getSMinExpr(FromRewritten, RHS);
15978 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15979 EnqueueOperands(SMax);
15980 break;
15981 case CmpInst::ICMP_UGT:
15982 case CmpInst::ICMP_UGE:
15983 To = SE.getUMaxExpr(FromRewritten, RHS);
15984 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15985 EnqueueOperands(UMin);
15986 break;
15987 case CmpInst::ICMP_SGT:
15988 case CmpInst::ICMP_SGE:
15989 To = SE.getSMaxExpr(FromRewritten, RHS);
15990 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15991 EnqueueOperands(SMin);
15992 break;
15993 case CmpInst::ICMP_EQ:
15995 To = RHS;
15996 break;
15997 case CmpInst::ICMP_NE:
15998 if (match(RHS, m_scev_Zero())) {
15999 const SCEV *OneAlignedUp =
16000 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
16001 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
16002 } else {
16003 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
16004 // but creating the subtraction eagerly is expensive. Track the
16005 // inequalities in a separate map, and materialize the rewrite lazily
16006 // when encountering a suitable subtraction while re-writing.
16007 if (LHS->getType()->isPointerTy()) {
16011 break;
16012 }
16013 const SCEVConstant *C;
16014 const SCEV *A, *B;
16017 RHS = A;
16018 LHS = B;
16019 }
16020 if (LHS > RHS)
16021 std::swap(LHS, RHS);
16022 Guards.NotEqual.insert({LHS, RHS});
16023 continue;
16024 }
16025 break;
16026 default:
16027 break;
16028 }
16029
16030 if (To)
16031 AddRewrite(From, FromRewritten, To);
16032 }
16033 };
16034
16036 // First, collect information from assumptions dominating the loop.
16037 for (auto &AssumeVH : SE.AC.assumptions()) {
16038 if (!AssumeVH)
16039 continue;
16040 auto *AssumeI = cast<CallInst>(AssumeVH);
16041 if (!SE.DT.dominates(AssumeI, Block))
16042 continue;
16043 Terms.emplace_back(AssumeI->getOperand(0), true);
16044 }
16045
16046 // Second, collect information from llvm.experimental.guards dominating the loop.
16047 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16048 SE.F.getParent(), Intrinsic::experimental_guard);
16049 if (GuardDecl)
16050 for (const auto *GU : GuardDecl->users())
16051 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16052 if (Guard->getFunction() == Block->getParent() &&
16053 SE.DT.dominates(Guard, Block))
16054 Terms.emplace_back(Guard->getArgOperand(0), true);
16055
16056 // Third, collect conditions from dominating branches. Starting at the loop
16057 // predecessor, climb up the predecessor chain, as long as there are
16058 // predecessors that can be found that have unique successors leading to the
16059 // original header.
16060 // TODO: share this logic with isLoopEntryGuardedByCond.
16061 unsigned NumCollectedConditions = 0;
16063 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16064 for (; Pair.first;
16065 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16066 VisitedBlocks.insert(Pair.second);
16067 const CondBrInst *LoopEntryPredicate =
16068 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16069 if (!LoopEntryPredicate)
16070 continue;
16071
16072 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16073 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16074 NumCollectedConditions++;
16075
16076 // If we are recursively collecting guards stop after 2
16077 // conditions to limit compile-time impact for now.
16078 if (Depth > 0 && NumCollectedConditions == 2)
16079 break;
16080 }
16081 // Finally, if we stopped climbing the predecessor chain because
16082 // there wasn't a unique one to continue, try to collect conditions
16083 // for PHINodes by recursively following all of their incoming
16084 // blocks and try to merge the found conditions to build a new one
16085 // for the Phi.
16086 if (Pair.second->hasNPredecessorsOrMore(2) &&
16088 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16089 for (auto &Phi : Pair.second->phis())
16090 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16091 }
16092
16093 // Now apply the information from the collected conditions to
16094 // Guards.RewriteMap. Conditions are processed in reverse order, so the
 16095 // earliest condition is processed first, except guards with divisibility
16096 // information, which are moved to the back. This ensures the SCEVs with the
16097 // shortest dependency chains are constructed first.
16099 GuardsToProcess;
16100 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16101 SmallVector<Value *, 8> Worklist;
16102 SmallPtrSet<Value *, 8> Visited;
16103 Worklist.push_back(Term);
16104 while (!Worklist.empty()) {
16105 Value *Cond = Worklist.pop_back_val();
16106 if (!Visited.insert(Cond).second)
16107 continue;
16108
16109 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16110 auto Predicate =
16111 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16112 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16113 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16114 // If LHS is a constant, apply information to the other expression.
16115 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16116 // can improve results.
16117 if (isa<SCEVConstant>(LHS)) {
16118 std::swap(LHS, RHS);
16120 }
16121 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16122 continue;
16123 }
16124
16125 Value *L, *R;
16126 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16127 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16128 Worklist.push_back(L);
16129 Worklist.push_back(R);
16130 }
16131 }
16132 }
16133
16134 // Process divisibility guards in reverse order to populate DivGuards early.
16135 DenseMap<const SCEV *, APInt> Multiples;
16136 LoopGuards DivGuards(SE);
16137 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16138 if (!isDivisibilityGuard(LHS, RHS, SE))
16139 continue;
16140 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16141 Multiples, SE);
16142 }
16143
16144 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16145 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16146
16147 // Apply divisibility information last. This ensures it is applied to the
16148 // outermost expression after other rewrites for the given value.
16149 for (const auto &[K, Divisor] : Multiples) {
16150 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16151 Guards.RewriteMap[K] =
16153 Guards.rewrite(K), Divisor, SE),
16154 DivisorSCEV),
16155 DivisorSCEV);
16156 ExprsToRewrite.push_back(K);
16157 }
16158
16159 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16160 // the replacement expressions are contained in the ranges of the replaced
16161 // expressions.
16162 Guards.PreserveNUW = true;
16163 Guards.PreserveNSW = true;
16164 for (const SCEV *Expr : ExprsToRewrite) {
16165 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16166 Guards.PreserveNUW &=
16167 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16168 Guards.PreserveNSW &=
16169 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16170 }
16171
 16172 // Now that all rewrite information is collected, rewrite the collected
16173 // expressions with the information in the map. This applies information to
16174 // sub-expressions.
16175 if (ExprsToRewrite.size() > 1) {
16176 for (const SCEV *Expr : ExprsToRewrite) {
16177 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16178 Guards.RewriteMap.erase(Expr);
16179 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16180 }
16181 }
16182}
16183
16185 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16186 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16187 /// replacement is loop invariant in the loop of the AddRec.
16188 class SCEVLoopGuardRewriter
16189 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16192
16194
16195 public:
16196 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16197 const ScalarEvolution::LoopGuards &Guards)
16198 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16199 NotEqual(Guards.NotEqual) {
16200 if (Guards.PreserveNUW)
16201 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16202 if (Guards.PreserveNSW)
16203 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16204 }
16205
16206 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16207
16208 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16209 return Map.lookup_or(Expr, Expr);
16210 }
16211
16212 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16213 if (const SCEV *S = Map.lookup(Expr))
16214 return S;
16215
 16216 // If we didn't find the exact ZExt expr in the map, check if there's
16217 // an entry for a smaller ZExt we can use instead.
16218 Type *Ty = Expr->getType();
16219 const SCEV *Op = Expr->getOperand(0);
16220 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16221 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16222 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16223 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16224 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16225 if (const SCEV *S = Map.lookup(NarrowExt))
16226 return SE.getZeroExtendExpr(S, Ty);
16227 Bitwidth = Bitwidth / 2;
16228 }
16229
16231 Expr);
16232 }
16233
16234 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16235 if (const SCEV *S = Map.lookup(Expr))
16236 return S;
16238 Expr);
16239 }
16240
16241 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16242 if (const SCEV *S = Map.lookup(Expr))
16243 return S;
16245 }
16246
16247 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16248 if (const SCEV *S = Map.lookup(Expr))
16249 return S;
16251 }
16252
16253 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16254 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16255 // return UMax(S, 1).
16256 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16257 SCEVUse LHS, RHS;
16258 if (MatchBinarySub(S, LHS, RHS)) {
16259 if (LHS > RHS)
16260 std::swap(LHS, RHS);
16261 if (NotEqual.contains({LHS, RHS})) {
16262 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16263 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16264 return SE.getUMaxExpr(OneAlignedUp, S);
16265 }
16266 }
16267 return nullptr;
16268 };
16269
16270 // Check if Expr itself is a subtraction pattern with guard info.
16271 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16272 return Rewritten;
16273
16274 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16275 // (Const + A + B). There may be guard info for A + B, and if so, apply
16276 // it.
16277 // TODO: Could more generally apply guards to Add sub-expressions.
16278 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16279 Expr->getNumOperands() == 3) {
16280 const SCEV *Add =
16281 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16282 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16283 return SE.getAddExpr(
16284 Expr->getOperand(0), Rewritten,
16285 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16286 if (const SCEV *S = Map.lookup(Add))
16287 return SE.getAddExpr(Expr->getOperand(0), S);
16288 }
16289 SmallVector<SCEVUse, 2> Operands;
16290 bool Changed = false;
16291 for (SCEVUse Op : Expr->operands()) {
16292 Operands.push_back(
16294 Changed |= Op != Operands.back();
16295 }
16296 // We are only replacing operands with equivalent values, so transfer the
16297 // flags from the original expression.
16298 return !Changed ? Expr
16299 : SE.getAddExpr(Operands,
16301 Expr->getNoWrapFlags(), FlagMask));
16302 }
16303
16304 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16305 SmallVector<SCEVUse, 2> Operands;
16306 bool Changed = false;
16307 for (SCEVUse Op : Expr->operands()) {
16308 Operands.push_back(
16310 Changed |= Op != Operands.back();
16311 }
16312 // We are only replacing operands with equivalent values, so transfer the
16313 // flags from the original expression.
16314 return !Changed ? Expr
16315 : SE.getMulExpr(Operands,
16317 Expr->getNoWrapFlags(), FlagMask));
16318 }
16319 };
16320
16321 if (RewriteMap.empty() && NotEqual.empty())
16322 return Expr;
16323
16324 SCEVLoopGuardRewriter Rewriter(SE, *this);
16325 return Rewriter.visit(Expr);
16326}
16327
16328const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16329 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16330}
16331
16333 const LoopGuards &Guards) {
16334 return Guards.rewrite(Expr);
16335}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:851
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< unsigned > getOpcode(ArrayRef< VPValue * > Values)
Returns the opcode of Values or ~0 if they do not all agree.
Definition VPlanSLP.cpp:247
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2011
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1043
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1555
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1406
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing the numBits high bits from this APInt.
Definition APInt.cpp:639
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1527
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:956
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1810
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1697
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1503
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1662
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1776
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1305
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1016
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:131
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator begin() const
Definition ArrayRef.h:130
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1472
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:784
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:256
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:371
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:316
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
FoldingSetNodeIDRef - This class describes a reference to an interned FoldingSetNodeID,...
Definition FoldingSet.h:172
FoldingSetNodeID - This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:209
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:569
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:596
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
NoWrapFlags
NoWrapFlags are bitfield indices into SubclassData.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Is operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed)?
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may effect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, int Mask)
Convenient NoWrapFlags manipulation that hides enum casts and is visible in the ScalarEvolution name ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if elements that makes up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
StringRef - Represent a constant reference to a string, i.e.
Definition StringRef.h:55
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:736
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:767
TypeSize getSizeInBits() const
Definition DataLayout.h:747
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2266
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2271
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2276
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2852
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2281
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:818
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
class_match< ConstantInt > m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
bind_ty< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
brc_match< Cond_t, bind_ty< BasicBlock >, bind_ty< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
class_match< BasicBlock > m_BasicBlock()
Match an arbitrary basic block value and ignore it.
match_combine_or< LTy, RTy > m_CombineOr(const LTy &L, const RTy &R)
Combine two pattern matchers matching L || R.
class_match< const SCEVVScale > m_SCEVVScale()
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
class_match< const SCEVConstant > m_SCEVConstant()
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVAffineAddRec_match< Op0_t, Op1_t, class_match< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
bind_ty< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
class_match< const Loop > m_Loop()
bind_ty< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
bind_ty< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
class_match< const SCEV > m_SCEV()
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2116
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1739
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if its even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff every element of A is also in B (i.e., A is a subset of B).
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2208
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2111
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count the number of 0's from the least significant bit to the most significant bit, stopping at the first 1.
Definition bit.h:202
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2200
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1746
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:408
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:368
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2012
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2088
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1917
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2019
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1947
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:870
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:872
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
Incoming for lane mask phi as machine instruction, incoming register Reg and incoming block Block are...
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:317
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:108
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:202
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:148
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:132
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:105
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.