LLVM 23.0.0git
ScalarEvolution.cpp
Go to the documentation of this file.
1//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the implementation of the scalar evolution analysis
10// engine, which is used primarily to analyze expressions involving induction
11// variables in loops.
12//
13// There are several aspects to this library. First is the representation of
14// scalar expressions, which are represented as subclasses of the SCEV class.
15// These classes are used to represent certain types of subexpressions that we
16// can handle. We only create one SCEV of a particular shape, so
17// pointer-comparisons for equality are legal.
18//
19// One important aspect of the SCEV objects is that they are never cyclic, even
20// if there is a cycle in the dataflow for an expression (ie, a PHI node). If
21// the PHI node is one of the idioms that we can represent (e.g., a polynomial
22// recurrence) then we represent it directly as a recurrence node, otherwise we
23// represent it as a SCEVUnknown node.
24//
25// In addition to being able to represent expressions of various types, we also
26// have folders that are used to build the *canonical* representation for a
27// particular expression. These folders are capable of using a variety of
28// rewrite rules to simplify the expressions.
29//
30// Once the folders are defined, we can implement the more interesting
31// higher-level code, such as the code that recognizes PHI nodes of various
32// types, computes the execution count of a loop, etc.
33//
34// TODO: We should use these routines and value representations to implement
35// dependence analysis!
36//
37//===----------------------------------------------------------------------===//
38//
39// There are several good references for the techniques used in this analysis.
40//
41// Chains of recurrences -- a method to expedite the evaluation
42// of closed-form functions
43// Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44//
45// On computational properties of chains of recurrences
46// Eugene V. Zima
47//
48// Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49// Robert A. van Engelen
50//
51// Efficient Symbolic Analysis for Optimizing Compilers
52// Robert A. van Engelen
53//
54// Using the chains of recurrences algebra for data dependence testing and
55// induction variable substitution
56// MS Thesis, Johnie Birch
57//
58//===----------------------------------------------------------------------===//
59
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
65#include "llvm/ADT/FoldingSet.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/ScopeExit.h"
68#include "llvm/ADT/Sequence.h"
71#include "llvm/ADT/Statistic.h"
73#include "llvm/ADT/StringRef.h"
83#include "llvm/Config/llvm-config.h"
84#include "llvm/IR/Argument.h"
85#include "llvm/IR/BasicBlock.h"
86#include "llvm/IR/CFG.h"
87#include "llvm/IR/Constant.h"
89#include "llvm/IR/Constants.h"
90#include "llvm/IR/DataLayout.h"
92#include "llvm/IR/Dominators.h"
93#include "llvm/IR/Function.h"
94#include "llvm/IR/GlobalAlias.h"
95#include "llvm/IR/GlobalValue.h"
97#include "llvm/IR/InstrTypes.h"
98#include "llvm/IR/Instruction.h"
101#include "llvm/IR/Intrinsics.h"
102#include "llvm/IR/LLVMContext.h"
103#include "llvm/IR/Operator.h"
104#include "llvm/IR/PatternMatch.h"
105#include "llvm/IR/Type.h"
106#include "llvm/IR/Use.h"
107#include "llvm/IR/User.h"
108#include "llvm/IR/Value.h"
109#include "llvm/IR/Verifier.h"
111#include "llvm/Pass.h"
112#include "llvm/Support/Casting.h"
115#include "llvm/Support/Debug.h"
121#include <algorithm>
122#include <cassert>
123#include <climits>
124#include <cstdint>
125#include <cstdlib>
126#include <map>
127#include <memory>
128#include <numeric>
129#include <optional>
130#include <tuple>
131#include <utility>
132#include <vector>
133
134using namespace llvm;
135using namespace PatternMatch;
136using namespace SCEVPatternMatch;
137
138#define DEBUG_TYPE "scalar-evolution"
139
140STATISTIC(NumExitCountsComputed,
141 "Number of loop exits with predictable exit counts");
142STATISTIC(NumExitCountsNotComputed,
143 "Number of loop exits without predictable exit counts");
144STATISTIC(NumBruteForceTripCountsComputed,
145 "Number of loops with trip counts computed by force");
146
147#ifdef EXPENSIVE_CHECKS
148bool llvm::VerifySCEV = true;
149#else
150bool llvm::VerifySCEV = false;
151#endif
152
154 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
155 cl::desc("Maximum number of iterations SCEV will "
156 "symbolically execute a constant "
157 "derived loop"),
158 cl::init(100));
159
161 "verify-scev", cl::Hidden, cl::location(VerifySCEV),
162 cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
164 "verify-scev-strict", cl::Hidden,
165 cl::desc("Enable stricter verification with -verify-scev is passed"));
166
168 "scev-verify-ir", cl::Hidden,
169 cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
170 cl::init(false));
171
173 "scev-mulops-inline-threshold", cl::Hidden,
174 cl::desc("Threshold for inlining multiplication operands into a SCEV"),
175 cl::init(32));
176
178 "scev-addops-inline-threshold", cl::Hidden,
179 cl::desc("Threshold for inlining addition operands into a SCEV"),
180 cl::init(500));
181
183 "scalar-evolution-max-scev-compare-depth", cl::Hidden,
184 cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
185 cl::init(32));
186
188 "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
189 cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
190 cl::init(2));
191
193 "scalar-evolution-max-value-compare-depth", cl::Hidden,
194 cl::desc("Maximum depth of recursive value complexity comparisons"),
195 cl::init(2));
196
198 MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
199 cl::desc("Maximum depth of recursive arithmetics"),
200 cl::init(32));
201
203 "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
204 cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
205
207 MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
208 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
209 cl::init(8));
210
212 MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
213 cl::desc("Max coefficients in AddRec during evolving"),
214 cl::init(8));
215
217 HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
218 cl::desc("Size of the expression which is considered huge"),
219 cl::init(4096));
220
222 "scev-range-iter-threshold", cl::Hidden,
223 cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
224 cl::init(32));
225
227 "scalar-evolution-max-loop-guard-collection-depth", cl::Hidden,
228 cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1));
229
230static cl::opt<bool>
231ClassifyExpressions("scalar-evolution-classify-expressions",
232 cl::Hidden, cl::init(true),
233 cl::desc("When printing analysis, include information on every instruction"));
234
236 "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
237 cl::init(false),
238 cl::desc("Use more powerful methods of sharpening expression ranges. May "
239 "be costly in terms of compile time"));
240
242 "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
243 cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
244 "Phi strongly connected components"),
245 cl::init(8));
246
247static cl::opt<bool>
248 EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
249 cl::desc("Handle <= and >= in finite loops"),
250 cl::init(true));
251
253 "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
254 cl::desc("Infer nuw/nsw flags using context where suitable"),
255 cl::init(true));
256
257//===----------------------------------------------------------------------===//
258// SCEV class definitions
259//===----------------------------------------------------------------------===//
260
// NOTE(review): doc-extraction artifact — the enclosing signature (original
// line 261, presumably SCEV::computeAndSetCanonical(ScalarEvolution &SE)) and
// interior lines 277 (the CanonOps vector declaration), 316 (the
// SE.getAddRecExpr call head) and 331 (the scSequentialUMinExpr case label)
// are missing from this listing.
262 // Leaf nodes are always their own canonical.
263 switch (getSCEVType()) {
264 case scConstant:
265 case scVScale:
266 case scUnknown:
267 CanonicalSCEV = this;
268 return;
269 default:
270 break;
271 }
272
273 // For all other expressions, check whether any immediate operand has a
274 // different canonical. Since operands are always created before their parent,
275 // their canonical pointers are already set — no recursion needed.
276 bool Changed = false;
278 for (SCEVUse Op : operands()) {
279 CanonOps.push_back(Op->getCanonical());
280 Changed |= CanonOps.back() != Op.getPointer();
281 }
282
// If every operand was already canonical, this expression is its own
// canonical form — avoid rebuilding it.
283 if (!Changed) {
284 CanonicalSCEV = this;
285 return;
286 }
287
// Rebuild the expression from the canonical operands through
// ScalarEvolution's factory methods so the canonical form is uniqued.
// No-wrap flags only exist on n-ary expressions, hence the dyn_cast.
288 auto *NAry = dyn_cast<SCEVNAryExpr>(this);
289 SCEV::NoWrapFlags Flags = NAry ? NAry->getNoWrapFlags() : SCEV::FlagAnyWrap;
290 switch (getSCEVType()) {
291 case scPtrToAddr:
292 CanonicalSCEV = SE.getPtrToAddrExpr(CanonOps[0]);
293 return;
294 case scPtrToInt:
295 CanonicalSCEV = SE.getPtrToIntExpr(CanonOps[0], getType());
296 return;
297 case scTruncate:
298 CanonicalSCEV = SE.getTruncateExpr(CanonOps[0], getType());
299 return;
300 case scZeroExtend:
301 CanonicalSCEV = SE.getZeroExtendExpr(CanonOps[0], getType());
302 return;
303 case scSignExtend:
304 CanonicalSCEV = SE.getSignExtendExpr(CanonOps[0], getType());
305 return;
306 case scUDivExpr:
307 CanonicalSCEV = SE.getUDivExpr(CanonOps[0], CanonOps[1]);
308 return;
309 case scAddExpr:
310 CanonicalSCEV = SE.getAddExpr(CanonOps, Flags);
311 return;
312 case scMulExpr:
313 CanonicalSCEV = SE.getMulExpr(CanonOps, Flags);
314 return;
315 case scAddRecExpr:
317 CanonOps, cast<SCEVAddRecExpr>(this)->getLoop(), Flags);
318 return;
319 case scSMaxExpr:
320 CanonicalSCEV = SE.getSMaxExpr(CanonOps);
321 return;
322 case scUMaxExpr:
323 CanonicalSCEV = SE.getUMaxExpr(CanonOps);
324 return;
325 case scSMinExpr:
326 CanonicalSCEV = SE.getSMinExpr(CanonOps);
327 return;
328 case scUMinExpr:
329 CanonicalSCEV = SE.getUMinExpr(CanonOps);
330 return;
332 CanonicalSCEV = SE.getUMinExpr(CanonOps, /*Sequential=*/true);
333 return;
334 default:
335 llvm_unreachable("Unknown SCEV type");
336 }
337}
338
339//===----------------------------------------------------------------------===//
340// Implementation of the SCEV class.
341//
342
343#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
345 print(dbgs());
346 dbgs() << '\n';
347}
348#endif
349
/// Print a textual representation of this SCEV to \p OS by switching on the
/// concrete expression kind.
/// NOTE(review): extraction artifact — original lines 375 and 382 (the
/// ZExt/SExt local declarations), 411 and 425 (the scSequentialUMinExpr case
/// label and its OpStr assignment's label), 432 (the interleaved operand
/// print) and 456 (the scCouldNotCompute case label) are missing from this
/// listing.
350void SCEV::print(raw_ostream &OS) const {
351 switch (getSCEVType()) {
352 case scConstant:
353 cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
354 return;
355 case scVScale:
356 OS << "vscale";
357 return;
358 case scPtrToAddr:
359 case scPtrToInt: {
360 const SCEVCastExpr *PtrCast = cast<SCEVCastExpr>(this);
361 const SCEV *Op = PtrCast->getOperand();
362 StringRef OpS = getSCEVType() == scPtrToAddr ? "addr" : "int";
363 OS << "(ptrto" << OpS << " " << *Op->getType() << " " << *Op << " to "
364 << *PtrCast->getType() << ")";
365 return;
366 }
367 case scTruncate: {
368 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
369 const SCEV *Op = Trunc->getOperand();
370 OS << "(trunc " << *Op->getType() << " " << *Op << " to "
371 << *Trunc->getType() << ")";
372 return;
373 }
374 case scZeroExtend: {
376 const SCEV *Op = ZExt->getOperand();
377 OS << "(zext " << *Op->getType() << " " << *Op << " to "
378 << *ZExt->getType() << ")";
379 return;
380 }
381 case scSignExtend: {
383 const SCEV *Op = SExt->getOperand();
384 OS << "(sext " << *Op->getType() << " " << *Op << " to "
385 << *SExt->getType() << ")";
386 return;
387 }
// Recurrences render as {start,+,step,...} followed by wrap flags and the
// loop's header block.
388 case scAddRecExpr: {
389 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
390 OS << "{" << *AR->getOperand(0);
391 for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
392 OS << ",+," << *AR->getOperand(i);
393 OS << "}<";
394 if (AR->hasNoUnsignedWrap())
395 OS << "nuw><";
396 if (AR->hasNoSignedWrap())
397 OS << "nsw><";
// "nw" is only printed when it is not already implied by nuw/nsw.
398 if (AR->hasNoSelfWrap() && !AR->hasNoUnsignedWrap() &&
399 !AR->hasNoSignedWrap())
400 OS << "nw><";
401 AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
402 OS << ">";
403 return;
404 }
405 case scAddExpr:
406 case scMulExpr:
407 case scUMaxExpr:
408 case scSMaxExpr:
409 case scUMinExpr:
410 case scSMinExpr:
412 const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
413 const char *OpStr = nullptr;
414 switch (NAry->getSCEVType()) {
415 case scAddExpr: OpStr = " + "; break;
416 case scMulExpr: OpStr = " * "; break;
417 case scUMaxExpr: OpStr = " umax "; break;
418 case scSMaxExpr: OpStr = " smax "; break;
419 case scUMinExpr:
420 OpStr = " umin ";
421 break;
422 case scSMinExpr:
423 OpStr = " smin ";
424 break;
426 OpStr = " umin_seq ";
427 break;
428 default:
429 llvm_unreachable("There are no other nary expression types.");
430 }
431 OS << "("
433 << ")";
// Only add/mul carry printable no-wrap flags.
434 switch (NAry->getSCEVType()) {
435 case scAddExpr:
436 case scMulExpr:
437 if (NAry->hasNoUnsignedWrap())
438 OS << "<nuw>";
439 if (NAry->hasNoSignedWrap())
440 OS << "<nsw>";
441 break;
442 default:
443 // Nothing to print for other nary expressions.
444 break;
445 }
446 return;
447 }
448 case scUDivExpr: {
449 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
450 OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
451 return;
452 }
453 case scUnknown:
454 cast<SCEVUnknown>(this)->getValue()->printAsOperand(OS, false);
455 return;
457 OS << "***COULDNOTCOMPUTE***";
458 return;
459 }
460 llvm_unreachable("Unknown SCEV kind!");
461}
462
// Return the IR type of this expression by delegating to the concrete
// subclass. NOTE(review): extraction artifact — the signature (original line
// 463, presumably Type *SCEV::getType() const) and the case labels on
// original lines 484 (scSequentialUMinExpr) and 492 (scCouldNotCompute) are
// missing from this listing.
464 switch (getSCEVType()) {
465 case scConstant:
466 return cast<SCEVConstant>(this)->getType();
467 case scVScale:
468 return cast<SCEVVScale>(this)->getType();
469 case scPtrToAddr:
470 case scPtrToInt:
471 case scTruncate:
472 case scZeroExtend:
473 case scSignExtend:
474 return cast<SCEVCastExpr>(this)->getType();
475 case scAddRecExpr:
476 return cast<SCEVAddRecExpr>(this)->getType();
477 case scMulExpr:
478 return cast<SCEVMulExpr>(this)->getType();
479 case scUMaxExpr:
480 case scSMaxExpr:
481 case scUMinExpr:
482 case scSMinExpr:
483 return cast<SCEVMinMaxExpr>(this)->getType();
485 return cast<SCEVSequentialMinMaxExpr>(this)->getType();
486 case scAddExpr:
487 return cast<SCEVAddExpr>(this)->getType();
488 case scUDivExpr:
489 return cast<SCEVUDivExpr>(this)->getType();
490 case scUnknown:
491 return cast<SCEVUnknown>(this)->getType();
493 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
494 }
495 llvm_unreachable("Unknown SCEV kind!");
496}
497
// Return this expression's operand list; leaf nodes (constants, vscale,
// unknowns) have none. NOTE(review): extraction artifact — the signature
// (original line 498, presumably ArrayRef<SCEVUse> SCEV::operands() const)
// and the case labels on original lines 517 (scSequentialUMinExpr) and 521
// (scCouldNotCompute) are missing from this listing.
499 switch (getSCEVType()) {
500 case scConstant:
501 case scVScale:
502 case scUnknown:
503 return {};
504 case scPtrToAddr:
505 case scPtrToInt:
506 case scTruncate:
507 case scZeroExtend:
508 case scSignExtend:
509 return cast<SCEVCastExpr>(this)->operands();
510 case scAddRecExpr:
511 case scAddExpr:
512 case scMulExpr:
513 case scUMaxExpr:
514 case scSMaxExpr:
515 case scUMinExpr:
516 case scSMinExpr:
518 return cast<SCEVNAryExpr>(this)->operands();
519 case scUDivExpr:
520 return cast<SCEVUDivExpr>(this)->operands();
522 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
523 }
524 llvm_unreachable("Unknown SCEV kind!");
525}
526
527bool SCEV::isZero() const { return match(this, m_scev_Zero()); }
528
529bool SCEV::isOne() const { return match(this, m_scev_One()); }
530
531bool SCEV::isAllOnesValue() const { return match(this, m_scev_AllOnes()); }
532
535 if (!Mul) return false;
536
537 // If there is a constant factor, it will be first.
538 const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
539 if (!SC) return false;
540
541 // Return true if the value is negative, this matches things like (-42 * V).
542 return SC->getAPInt().isNegative();
543}
544
547
549 return S->getSCEVType() == scCouldNotCompute;
550}
551
554 ID.AddInteger(scConstant);
555 ID.AddPointer(V);
556 void *IP = nullptr;
557 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
558 SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
559 UniqueSCEVs.InsertNode(S, IP);
560 S->computeAndSetCanonical(*this);
561 return S;
562}
563
565 return getConstant(ConstantInt::get(getContext(), Val));
566}
567
568const SCEV *
571 // TODO: Avoid implicit trunc?
572 // See https://github.com/llvm/llvm-project/issues/112510.
573 return getConstant(
574 ConstantInt::get(ITy, V, isSigned, /*ImplicitTrunc=*/true));
575}
576
579 ID.AddInteger(scVScale);
580 ID.AddPointer(Ty);
581 void *IP = nullptr;
582 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
583 return S;
584 SCEV *S = new (SCEVAllocator) SCEVVScale(ID.Intern(SCEVAllocator), Ty);
585 UniqueSCEVs.InsertNode(S, IP);
586 S->computeAndSetCanonical(*this);
587 return S;
588}
589
591 SCEV::NoWrapFlags Flags) {
592 const SCEV *Res = getConstant(Ty, EC.getKnownMinValue());
593 if (EC.isScalable())
594 Res = getMulExpr(Res, getVScale(Ty), Flags);
595 return Res;
596}
597
601
/// Construct a ptrtoaddr cast expression; the operand must be a pointer and
/// the result type an integer (the assert below enforces both).
602SCEVPtrToAddrExpr::SCEVPtrToAddrExpr(const FoldingSetNodeIDRef ID,
603 const SCEV *Op, Type *ITy)
604 : SCEVCastExpr(ID, scPtrToAddr, Op, ITy) {
605 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
606 "Must be a non-bit-width-changing pointer-to-integer cast!");
607}
608
/// Construct a ptrtoint cast expression; the operand must be a pointer and
/// the result type an integer (the assert below enforces both).
609SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, SCEVUse Op,
610 Type *ITy)
611 : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
612 assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
613 "Must be a non-bit-width-changing pointer-to-integer cast!");
614}
615
620
621SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
622 Type *ty)
624 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
625 "Cannot truncate non-integer value!");
626}
627
628SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
629 Type *ty)
631 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
632 "Cannot zero extend non-integer value!");
633}
634
635SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, SCEVUse op,
636 Type *ty)
638 assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
639 "Cannot sign extend non-integer value!");
640}
641
643 // Clear this SCEVUnknown from various maps.
644 SE->forgetMemoizedResults({this});
645
646 // Remove this SCEVUnknown from the uniquing map.
647 SE->UniqueSCEVs.RemoveNode(this);
648
649 // Release the value.
650 setValPtr(nullptr);
651}
652
/// ValueHandle callback: the IR value underlying this SCEVUnknown has been
/// replaced everywhere by \p New. Drop all cached state derived from the old
/// value, then repoint the handle. The teardown order matters: memoized
/// results are forgotten before the node leaves the uniquing map.
653void SCEVUnknown::allUsesReplacedWith(Value *New) {
654 // Clear this SCEVUnknown from various maps.
655 SE->forgetMemoizedResults({this});
656
657 // Remove this SCEVUnknown from the uniquing map.
658 SE->UniqueSCEVs.RemoveNode(this);
659
660 // Replace the value pointer in case someone is still using this SCEVUnknown.
661 setValPtr(New);
662}
663
664//===----------------------------------------------------------------------===//
665// SCEV Utilities
666//===----------------------------------------------------------------------===//
667
668/// Compare the two values \p LV and \p RV in terms of their "complexity" where
669/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
670/// operands in SCEV expressions.
/// Returns a three-way result (<0, 0, >0); 0 also means "don't know".
/// NOTE(review): extraction artifact — original line 673 (the recursion-depth
/// cutoff guard preceding "return 0") and line 704 (the second linkage term
/// of IsGVNameSemantic) are missing from this listing.
671static int CompareValueComplexity(const LoopInfo *const LI, Value *LV,
672 Value *RV, unsigned Depth) {
674 return 0;
675
676 // Order pointer values after integer values. This helps SCEVExpander form
677 // GEPs.
678 bool LIsPointer = LV->getType()->isPointerTy(),
679 RIsPointer = RV->getType()->isPointerTy();
680 if (LIsPointer != RIsPointer)
681 return (int)LIsPointer - (int)RIsPointer;
682
683 // Compare getValueID values.
684 unsigned LID = LV->getValueID(), RID = RV->getValueID();
685 if (LID != RID)
686 return (int)LID - (int)RID;
687
688 // Sort arguments by their position.
689 if (const auto *LA = dyn_cast<Argument>(LV)) {
// Equal value IDs above guarantee RV is also an Argument here.
690 const auto *RA = cast<Argument>(RV);
691 unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
692 return (int)LArgNo - (int)RArgNo;
693 }
694
695 if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
696 const auto *RGV = cast<GlobalValue>(RV);
697
698 if (auto L = LGV->getLinkage() - RGV->getLinkage())
699 return L;
700
701 const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
702 auto LT = GV->getLinkage();
703 return !(GlobalValue::isPrivateLinkage(LT) ||
705 };
706
707 // Use the names to distinguish the two values, but only if the
708 // names are semantically important.
709 if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
710 return LGV->getName().compare(RGV->getName());
711 }
712
713 // For instructions, compare their loop depth, and their operand count. This
714 // is pretty loose.
715 if (const auto *LInst = dyn_cast<Instruction>(LV)) {
716 const auto *RInst = cast<Instruction>(RV);
717
718 // Compare loop depths.
719 const BasicBlock *LParent = LInst->getParent(),
720 *RParent = RInst->getParent();
721 if (LParent != RParent) {
722 unsigned LDepth = LI->getLoopDepth(LParent),
723 RDepth = LI->getLoopDepth(RParent);
724 if (LDepth != RDepth)
725 return (int)LDepth - (int)RDepth;
726 }
727
728 // Compare the number of operands.
729 unsigned LNumOps = LInst->getNumOperands(),
730 RNumOps = RInst->getNumOperands();
731 if (LNumOps != RNumOps)
732 return (int)LNumOps - (int)RNumOps;
733
// Same opcode/arity: recurse lexicographically over the operands.
734 for (unsigned Idx : seq(LNumOps)) {
735 int Result = CompareValueComplexity(LI, LInst->getOperand(Idx),
736 RInst->getOperand(Idx), Depth + 1);
737 if (Result != 0)
738 return Result;
739 }
740 }
741
742 return 0;
743}
744
745// Return negative, zero, or positive, if LHS is less than, equal to, or greater
746// than RHS, respectively. A three-way result allows recursive comparisons to be
747// more efficient.
748// If the max analysis depth was reached, return std::nullopt, assuming we do
749// not know if they are equivalent for sure.
// NOTE(review): extraction artifact — original lines 762 (the depth-cutoff
// condition preceding "return std::nullopt"), 779-780 (the LC/RC
// SCEVConstant casts), 798-799 (the LA/RA SCEVAddRecExpr casts), 830 (the
// scSequentialUMinExpr case label) and 848 (the scCouldNotCompute case
// label) are missing from this listing.
750static std::optional<int>
751CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS,
752 const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
753 // Fast-path: SCEVs are uniqued so we can do a quick equality check.
754 if (LHS == RHS)
755 return 0;
756
757 // Primarily, sort the SCEVs by their getSCEVType().
758 SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
759 if (LType != RType)
760 return (int)LType - (int)RType;
761
763 return std::nullopt;
764
765 // Aside from the getSCEVType() ordering, the particular ordering
766 // isn't very important except that it's beneficial to be consistent,
767 // so that (a + b) and (b + a) don't end up as different expressions.
768 switch (LType) {
769 case scUnknown: {
770 const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
771 const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
772
773 int X =
774 CompareValueComplexity(LI, LU->getValue(), RU->getValue(), Depth + 1);
775 return X;
776 }
777
778 case scConstant: {
781
782 // Compare constant values.
783 const APInt &LA = LC->getAPInt();
784 const APInt &RA = RC->getAPInt();
785 unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
786 if (LBitWidth != RBitWidth)
787 return (int)LBitWidth - (int)RBitWidth;
// Equality was ruled out by the uniqued-pointer fast path above, so an
// unsigned less-than check suffices for the three-way answer.
788 return LA.ult(RA) ? -1 : 1;
789 }
790
791 case scVScale: {
792 const auto *LTy = cast<IntegerType>(cast<SCEVVScale>(LHS)->getType());
793 const auto *RTy = cast<IntegerType>(cast<SCEVVScale>(RHS)->getType());
794 return LTy->getBitWidth() - RTy->getBitWidth();
795 }
796
797 case scAddRecExpr: {
800
801 // There is always a dominance between two recs that are used by one SCEV,
802 // so we can safely sort recs by loop header dominance. We require such
803 // order in getAddExpr.
804 const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
805 if (LLoop != RLoop) {
806 const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
807 assert(LHead != RHead && "Two loops share the same header?");
808 if (DT.dominates(LHead, RHead))
809 return 1;
810 assert(DT.dominates(RHead, LHead) &&
811 "No dominance between recurrences used by one SCEV?");
812 return -1;
813 }
814
// Same loop: fall through and compare the recurrences operand-wise.
815 [[fallthrough]];
816 }
817
818 case scTruncate:
819 case scZeroExtend:
820 case scSignExtend:
821 case scPtrToAddr:
822 case scPtrToInt:
823 case scAddExpr:
824 case scMulExpr:
825 case scUDivExpr:
826 case scSMaxExpr:
827 case scUMaxExpr:
828 case scSMinExpr:
829 case scUMinExpr:
831 ArrayRef<SCEVUse> LOps = LHS->operands();
832 ArrayRef<SCEVUse> ROps = RHS->operands();
833
834 // Lexicographically compare n-ary-like expressions.
835 unsigned LNumOps = LOps.size(), RNumOps = ROps.size();
836 if (LNumOps != RNumOps)
837 return (int)LNumOps - (int)RNumOps;
838
839 for (unsigned i = 0; i != LNumOps; ++i) {
840 auto X = CompareSCEVComplexity(LI, LOps[i].getPointer(),
841 ROps[i].getPointer(), DT, Depth + 1);
842 if (X != 0)
843 return X;
844 }
845 return 0;
846 }
847
849 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
850 }
851 llvm_unreachable("Unknown SCEV kind!");
852}
853
854/// Given a list of SCEV objects, order them by their complexity, and group
855/// objects of the same complexity together by value. When this routine is
856/// finished, we know that any duplicates in the vector are consecutive and that
857/// complexity is monotonically increasing.
858///
859/// Note that we take special precautions to ensure that we get deterministic
860/// results from this routine. In other words, we don't want the results of
861/// this to depend on where the addresses of various SCEV objects happened to
862/// land in memory.
/// NOTE(review): extraction artifact — original line 863 (the first signature
/// line carrying the Ops and LI parameters) and line 882 (the stable-sort
/// call head) are missing from this listing.
864 DominatorTree &DT) {
865 if (Ops.size() < 2) return; // Noop
866
867 // Whether LHS has provably less complexity than RHS.
868 auto IsLessComplex = [&](SCEVUse LHS, SCEVUse RHS) {
869 auto Complexity = CompareSCEVComplexity(LI, LHS, RHS, DT);
870 return Complexity && *Complexity < 0;
871 };
872 if (Ops.size() == 2) {
873 // This is the common case, which also happens to be trivially simple.
874 // Special case it.
875 SCEVUse &LHS = Ops[0], &RHS = Ops[1];
876 if (IsLessComplex(RHS, LHS))
877 std::swap(LHS, RHS);
878 return;
879 }
880
881 // Do the rough sort by complexity.
883 Ops, [&](SCEVUse LHS, SCEVUse RHS) { return IsLessComplex(LHS, RHS); });
884
885 // Now that we are sorted by complexity, group elements of the same
886 // complexity. Note that this is, at worst, N^2, but the vector is likely to
887 // be extremely short in practice. Note that we take this approach because we
888 // do not want to depend on the addresses of the objects we are grouping.
889 for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
890 const SCEV *S = Ops[i];
891 unsigned Complexity = S->getSCEVType();
892
893 // If there are any objects of the same complexity and same value as this
894 // one, group them.
895 for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
896 if (Ops[j] == S) { // Found a duplicate.
897 // Move it to immediately after i'th element.
898 std::swap(Ops[i+1], Ops[j]);
899 ++i; // no need to rescan it.
900 if (i == e-2) return; // Done!
901 }
902 }
903 }
904}
905
906/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
907/// least HugeExprThreshold nodes).
909 return any_of(Ops, [](const SCEV *S) {
911 });
912}
913
914/// Performs a number of common optimizations on the passed \p Ops. If the
915/// whole expression reduces down to a single operand, it will be returned.
916///
917/// The following optimizations are performed:
918/// * Fold constants using the \p Fold function.
919/// * Remove identity constants satisfying \p IsIdentity.
920/// * If a constant satisfies \p IsAbsorber, return it.
921/// * Sort operands by complexity.
///
/// Returns nullptr when more than one operand remains after folding.
/// NOTE(review): extraction artifact — original line 924 (the function name
/// line carrying the SE, LI and DT parameters) is missing from this listing.
922template <typename FoldT, typename IsIdentityT, typename IsAbsorberT>
923static const SCEV *
925 SmallVectorImpl<SCEVUse> &Ops, FoldT Fold,
926 IsIdentityT IsIdentity, IsAbsorberT IsAbsorber) {
// Accumulate all constant operands into a single folded constant, erasing
// them from Ops as we go (hence the manual index loop without ++ on erase).
927 const SCEVConstant *Folded = nullptr;
928 for (unsigned Idx = 0; Idx < Ops.size();) {
929 const SCEV *Op = Ops[Idx];
930 if (const auto *C = dyn_cast<SCEVConstant>(Op)) {
931 if (!Folded)
932 Folded = C;
933 else
934 Folded = cast<SCEVConstant>(
935 SE.getConstant(Fold(Folded->getAPInt(), C->getAPInt())));
936 Ops.erase(Ops.begin() + Idx);
937 continue;
938 }
939 ++Idx;
940 }
941
942 if (Ops.empty()) {
943 assert(Folded && "Must have folded value");
944 return Folded;
945 }
946
// An absorbing constant (e.g. 0 for mul) makes the whole expression.
947 if (Folded && IsAbsorber(Folded->getAPInt()))
948 return Folded;
949
// Canonical operand order; a non-identity folded constant goes first.
950 GroupByComplexity(Ops, &LI, DT);
951 if (Folded && !IsIdentity(Folded->getAPInt()))
952 Ops.insert(Ops.begin(), Folded);
953
954 return Ops.size() == 1 ? Ops[0] : nullptr;
955}
956
957//===----------------------------------------------------------------------===//
958// Simple SCEV method implementations
959//===----------------------------------------------------------------------===//
960
961/// Compute BC(It, K). The result has width W. Assume, K > 0.
962static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
963 ScalarEvolution &SE,
964 Type *ResultTy) {
965 // Handle the simplest case efficiently.
966 if (K == 1)
967 return SE.getTruncateOrZeroExtend(It, ResultTy);
968
969 // We are using the following formula for BC(It, K):
970 //
971 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
972 //
973 // Suppose, W is the bitwidth of the return value. We must be prepared for
974 // overflow. Hence, we must assure that the result of our computation is
975 // equal to the accurate one modulo 2^W. Unfortunately, division isn't
976 // safe in modular arithmetic.
977 //
978 // However, this code doesn't use exactly that formula; the formula it uses
979 // is something like the following, where T is the number of factors of 2 in
980 // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
981 // exponentiation:
982 //
983 // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
984 //
985 // This formula is trivially equivalent to the previous formula. However,
986 // this formula can be implemented much more efficiently. The trick is that
987 // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
988 // arithmetic. To do exact division in modular arithmetic, all we have
989 // to do is multiply by the inverse. Therefore, this step can be done at
990 // width W.
991 //
992 // The next issue is how to safely do the division by 2^T. The way this
993 // is done is by doing the multiplication step at a width of at least W + T
994 // bits. This way, the bottom W+T bits of the product are accurate. Then,
995 // when we perform the division by 2^T (which is equivalent to a right shift
996 // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
997 // truncated out after the division by 2^T.
998 //
999 // In comparison to just directly using the first formula, this technique
1000 // is much more efficient; using the first formula requires W * K bits,
1001 // but this formula less than W + K bits. Also, the first formula requires
1002 // a division step, whereas this formula only requires multiplies and shifts.
1003 //
1004 // It doesn't matter whether the subtraction step is done in the calculation
1005 // width or the input iteration count's width; if the subtraction overflows,
1006 // the result must be zero anyway. We prefer here to do it in the width of
1007 // the induction variable because it helps a lot for certain cases; CodeGen
1008 // isn't smart enough to ignore the overflow, which leads to much less
1009 // efficient code if the width of the subtraction is wider than the native
1010 // register width.
1011 //
1012 // (It's possible to not widen at all by pulling out factors of 2 before
1013 // the multiplication; for example, K=2 can be calculated as
1014 // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
1015 // extra arithmetic, so it's not an obvious win, and it gets
1016 // much more complicated for K > 3.)
1017
1018 // Protection from insane SCEVs; this bound is conservative,
1019 // but it probably doesn't matter.
1020 if (K > 1000)
1021 return SE.getCouldNotCompute();
1022
1023 unsigned W = SE.getTypeSizeInBits(ResultTy);
1024
1025 // Calculate K! / 2^T and T; we divide out the factors of two before
1026 // multiplying for calculating K! / 2^T to avoid overflow.
1027 // Other overflow doesn't matter because we only care about the bottom
1028 // W bits of the result.
1029 APInt OddFactorial(W, 1);
1030 unsigned T = 1;
1031 for (unsigned i = 3; i <= K; ++i) {
1032 unsigned TwoFactors = countr_zero(i);
1033 T += TwoFactors;
1034 OddFactorial *= (i >> TwoFactors);
1035 }
1036
1037 // We need at least W + T bits for the multiplication step
1038 unsigned CalculationBits = W + T;
1039
1040 // Calculate 2^T, at width T+W.
1041 APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1042
1043 // Calculate the multiplicative inverse of K! / 2^T;
1044 // this multiplication factor will perform the exact division by
1045 // K! / 2^T.
1046 APInt MultiplyFactor = OddFactorial.multiplicativeInverse();
1047
1048 // Calculate the product, at width T+W
1049 IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1050 CalculationBits);
1051 const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1052 for (unsigned i = 1; i != K; ++i) {
1053 const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1054 Dividend = SE.getMulExpr(Dividend,
1055 SE.getTruncateOrZeroExtend(S, CalculationTy));
1056 }
1057
1058 // Divide by 2^T
1059 const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1060
1061 // Truncate the result, and divide by K! / 2^T.
1062
1063 return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1064 SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1065}
1066
/// Return the value of this chain of recurrences at the specified iteration
/// number. We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it. In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
                                               ScalarEvolution &SE) const {
  // Delegate to the static overload that evaluates an explicit operand list.
  return evaluateAtIteration(operands(), It, SE);
}
1079
                                                const SCEV *It,
                                                ScalarEvolution &SE) {
  // Evaluate the recurrence described by the explicit operand list at
  // iteration `It`, using the binomial-coefficient expansion documented on
  // the member overload above:
  //   Result = Op[0]*BC(It,0) + Op[1]*BC(It,1) + ... + Op[K]*BC(It,K)
  assert(Operands.size() > 0);
  const SCEV *Result = Operands[0].getPointer();
  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
    // BinomialCoefficient can bail out (e.g. for huge K); propagate the
    // SCEVCouldNotCompute marker to the caller rather than guessing.
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result =
        SE.getAddExpr(Result, SE.getMulExpr(Operands[i].getPointer(), Coeff));
  }
  return Result;
}
1098
1099//===----------------------------------------------------------------------===//
1100// SCEV Expression folder implementations
1101//===----------------------------------------------------------------------===//
1102
/// The SCEVCastSinkingRewriter takes a scalar evolution expression,
/// which computes a pointer-typed value, and rewrites the whole expression
/// tree so that *all* the computations are done on integers, and the only
/// pointer-typed operands in the expression are SCEVUnknown.
/// The CreatePtrCast callback is invoked to create the actual conversion
/// (ptrtoint or ptrtoaddr) at the SCEVUnknown leaves.
    : public SCEVRewriteVisitor<SCEVCastSinkingRewriter> {
  using ConversionFn = function_ref<const SCEV *(const SCEVUnknown *)>;
  // The integer type the rewritten (pointer-free) expression must have.
  Type *TargetTy;
  // Callback that materializes the cast at each pointer-typed SCEVUnknown.
  ConversionFn CreatePtrCast;

public:
                          ConversionFn CreatePtrCast)
      : Base(SE), TargetTy(TargetTy), CreatePtrCast(std::move(CreatePtrCast)) {}

  /// Entry point: rewrite \p Scev so that all arithmetic happens on integers
  /// and pointers survive only at SCEVUnknown leaves (converted via
  /// \p CreatePtrCast).
  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
                             Type *TargetTy, ConversionFn CreatePtrCast) {
    SCEVCastSinkingRewriter Rewriter(SE, TargetTy, std::move(CreatePtrCast));
    return Rewriter.visit(Scev);
  }

  const SCEV *visit(const SCEV *S) {
    Type *STy = S->getType();
    // If the expression is not pointer-typed, just keep it as-is.
    if (!STy->isPointerTy())
      return S;
    // Else, recursively sink the cast down into it.
    return Base::visit(S);
  }

  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
    // Preserve wrap flags on rewritten SCEVAddExpr, which the default
    // implementation drops.
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
  }

  // Same pattern as visitAddExpr: rewrite operands while preserving the
  // no-wrap flags, and return the original node untouched when nothing
  // changed (keeps SCEV uniquing intact).
  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
    SmallVector<SCEVUse, 2> Operands;
    bool Changed = false;
    for (SCEVUse Op : Expr->operands()) {
      Operands.push_back(visit(Op.getPointer()));
      Changed |= Op.getPointer() != Operands.back();
    }
    return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    assert(Expr->getType()->isPointerTy() &&
           "Should only reach pointer-typed SCEVUnknown's.");
    // Perform some basic constant folding. If the operand of the cast is a
    // null pointer, don't create a cast SCEV expression (that will be left
    // as-is), but produce a zero constant.
      return SE.getZero(TargetTy);
    return CreatePtrCast(Expr);
  }
};
1169
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // It isn't legal for optimizations to construct new ptrtoint expressions
  // for non-integral pointers.
  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
    return getCouldNotCompute();

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only trivially model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  // We could theoretically teach SCEV to truncate wider pointers, but
  // that isn't implemented for now.
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
      Op, *this, IntPtrTy, [this, IntPtrTy](const SCEVUnknown *U) {
        // Uniquing: look up an existing SCEVPtrToIntExpr over this leaf
        // before allocating a new node.
        ID.AddInteger(scPtrToInt);
        ID.AddPointer(U);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), U, IntPtrTy);
        UniqueSCEVs.InsertNode(S, IP);
        S->computeAndSetCanonical(*this);
        // Track the dependency so invalidating U also invalidates S.
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1209
  assert(Op->getType()->isPointerTy() && "Op must be a pointer");

  // Treat pointers with unstable representation conservatively, since the
  // address bits may change.
  if (DL.hasUnstableRepresentation(Op->getType()))
    return getCouldNotCompute();

  Type *Ty = DL.getAddressType(Op->getType());

  // Use the rewriter to sink the cast down to SCEVUnknown leaves.
  // The rewriter handles null pointer constant folding.
      Op, *this, Ty, [this, Ty](const SCEVUnknown *U) {
        // Uniquing: look up an existing SCEVPtrToAddrExpr for (U, Ty)
        // before allocating a new node.
        ID.AddInteger(scPtrToAddr);
        ID.AddPointer(U);
        ID.AddPointer(Ty);
        void *IP = nullptr;
        if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
          return S;
        SCEV *S = new (SCEVAllocator)
            SCEVPtrToAddrExpr(ID.Intern(SCEVAllocator), U, Ty);
        UniqueSCEVs.InsertNode(S, IP);
        S->computeAndSetCanonical(*this);
        // Track the dependency so invalidating U also invalidates S.
        registerUser(S, U);
        return static_cast<const SCEV *>(S);
      });
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}
1243
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  // First lower the pointer to an integer-typed SCEV at the pointer's
  // natural integer width; this fails (CouldNotCompute) for non-integral
  // pointers.
  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  // Then adjust the integer width to the requested result type.
  return getTruncateOrZeroExtend(IntOp, Ty);
}
1253
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Build the uniquing key and return any previously-created node for
  // trunc(Op) to Ty.
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  // Past the recursion limit, stop folding and just emit a plain truncate
  // node to bound compile time.
  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    registerUser(S, Op);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<SCEVUse, 4> Operands;
    unsigned numTruncs = 0;
    // Stop early (numTruncs < 2) once the transformation can no longer pay
    // off: two or more residual truncates is not an improvement.
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that during recursion and different modification ID was inserted
    // into the cache. So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<SCEVUse, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    // Truncation can invalidate the original wrap flags, so be conservative.
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = getMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  registerUser(S, Op);
  return S;
}
1349
// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
// On success, *Pred is set to the comparison that must hold against the
// returned limit; returns nullptr when the step's sign is unknown.
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  // The limit is computed at the bit width of the step.
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    // Positive step: the recurrence must stay strictly below the limit
    // (signed-max minus the largest possible step).
    *Pred = ICmpInst::ICMP_SLT;
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    // Negative step: the recurrence must stay strictly above the limit
    // (signed-min minus the smallest possible step).
    *Pred = ICmpInst::ICMP_SGT;
                           SE->getSignedRangeMin(Step));
  }
  // Step with unknown sign: no usable limit.
  return nullptr;
}
1369
// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  // The limit is computed at the bit width of the step.
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  // Unsigned wrap can only happen upward, so the condition is always a
  // strict unsigned less-than against the returned limit.
  *Pred = ICmpInst::ICMP_ULT;

                         SE->getUnsignedRangeMax(Step));
}
1382
namespace {

struct ExtendOpTraitsBase {
  // Pointer-to-member type matching ScalarEvolution's extension builders
  // (the zext/sext expression constructors taking (Op, Ty, Depth)).
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

// Sign-extension traits: NSW wrap flag, signed overflow limits.
template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<

// Zero-extension traits: NUW wrap flag, unsigned overflow limits.
template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<

} // end anonymous namespace
1436
// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
// Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
// expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)"
//
// Returns the pre-increment start value (Start - Step) when the congruence
// can be proven, or nullptr when it cannot.
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list. Note, that
  // SA might have repeated ops, like %a + %a + ..., so only remove one.
  SmallVector<SCEVUse, 4> DiffOps(SA->operands());
  for (auto It = DiffOps.begin(); It != DiffOps.end(); ++It)
    if (*It == Step) {
      DiffOps.erase(It);
      break;
    }

  // Step did not appear in Start's operand list -> nothing to peel off.
  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && any(PreAR->getNoWrapFlags(WrapType)) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  // Compare ext(PreStart) + ext(Step) against ext(Start) at double width;
  // equality proves the narrow addition did not overflow.
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && any(AR->getNoWrapFlags(WrapType))) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}
1519
1520// Get the normalized zero or sign extended expression for this AddRec's Start.
1521template <typename ExtendOpTy>
1522static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1523 ScalarEvolution *SE,
1524 unsigned Depth) {
1525 auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1526
1527 const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1528 if (!PreStart)
1529 return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1530
1531 return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1532 Depth),
1533 (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1534}
1535
// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration.  Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here.  It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  // Probe a few nearby start values. Delta is unsigned, so the negative
  // literals convert modularly, and `StartAI - Delta` below is APInt modular
  // arithmetic -- both wraps are intended here.
  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    // Build the uniquing key for the sibling AddRec {PreStart,+,Step}<L>.
    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && any(PreAR->getNoWrapFlags(WrapType))) { // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
        return true;
    }
  }

  return false;
}
1610
// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
// Returns zero (i.e. "no constant can be split off") when no such D exists.
    const SCEVConstant *ConstantTerm,
    const SCEVAddExpr *WholeAddExpr) {
  const APInt &C = ConstantTerm->getAPInt();
  const unsigned BitWidth = C.getBitWidth();
  // Find number of trailing zeros of (x + y + ...) w/o the C first:
  uint32_t TZ = BitWidth;
  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
    TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I)));
  if (TZ) {
    // Set D to be as many least significant bits of C as possible while still
    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
  }
  // The non-constant part may be odd: no bits of C can be split off safely.
  return APInt(BitWidth, 0);
}
1632
// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
// Returns zero (i.e. "no constant can be split off") when no such D exists.
    const APInt &ConstantStart,
    const SCEV *Step) {
  const unsigned BitWidth = ConstantStart.getBitWidth();
  // Every multiple of the step contributes at least TZ trailing zero bits,
  // so the low TZ bits of the start can be peeled off without wrapping.
  const uint32_t TZ = SE.getMinTrailingZeros(Step);
  if (TZ)
    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
                         : ConstantStart;
  // The step may be odd: no bits of the start can be split off safely.
  return APInt(BitWidth, 0);
}
1647
    const ScalarEvolution::FoldID &ID, const SCEV *S,
        &FoldCacheUser) {
  // Record ID -> S in the fold cache, keeping the reverse FoldCacheUser map
  // (S -> all IDs that fold to S) consistent when an entry is replaced.
  auto I = FoldCache.insert({ID, S});
  if (!I.second) {
    // Remove FoldCacheUser entry for ID when replacing an existing FoldCache
    // entry.
    auto &UserIDs = FoldCacheUser[I.first->second];
    assert(count(UserIDs, ID) == 1 && "unexpected duplicates in UserIDs");
    // Swap-and-pop the single matching ID; order within UserIDs is
    // irrelevant, so this avoids an O(n) erase.
    for (unsigned I = 0; I != UserIDs.size(); ++I)
      if (UserIDs[I] == ID) {
        std::swap(UserIDs[I], UserIDs.back());
        break;
      }
    UserIDs.pop_back();
    I.first->second = S;
  }
  // Register ID as a user of the (possibly new) folded result.
  FoldCacheUser[S].push_back(ID);
}
1669
const SCEV *
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
  Ty = getEffectiveSCEVType(Ty);

  // Fast path: return a previously-computed fold for (zext, Op, Ty).
  FoldID ID(scZeroExtend, Op, Ty);
  if (const SCEV *S = FoldCache.lookup(ID))
    return S;

  // Do the real work, then record the result for future lookups.
  const SCEV *S = getZeroExtendExprImpl(Op, Ty, Depth);
    insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
  return S;
}
1688
1690 unsigned Depth) {
1691 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1692 "This is not an extending conversion!");
1693 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
1694 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1695
1696 // Fold if the operand is constant.
1697 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1698 return getConstant(SC->getAPInt().zext(getTypeSizeInBits(Ty)));
1699
1700 // zext(zext(x)) --> zext(x)
1702 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1703
1704 // Before doing any expensive analysis, check to see if we've already
1705 // computed a SCEV for this Op and Ty.
1707 ID.AddInteger(scZeroExtend);
1708 ID.AddPointer(Op);
1709 ID.AddPointer(Ty);
1710 void *IP = nullptr;
1711 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1712 if (Depth > MaxCastDepth) {
1713 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1714 Op, Ty);
1715 UniqueSCEVs.InsertNode(S, IP);
1716 S->computeAndSetCanonical(*this);
1717 registerUser(S, Op);
1718 return S;
1719 }
1720
1721 // zext(trunc(x)) --> zext(x) or x or trunc(x)
1723 // It's possible the bits taken off by the truncate were all zero bits. If
1724 // so, we should be able to simplify this further.
1725 const SCEV *X = ST->getOperand();
1727 unsigned TruncBits = getTypeSizeInBits(ST->getType());
1728 unsigned NewBits = getTypeSizeInBits(Ty);
1729 if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1730 CR.zextOrTrunc(NewBits)))
1731 return getTruncateOrZeroExtend(X, Ty, Depth);
1732 }
1733
1734 // If the input value is a chrec scev, and we can prove that the value
1735 // did not overflow the old, smaller, value, we can zero extend all of the
1736 // operands (often constants). This allows analysis of something like
1737 // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1739 if (AR->isAffine()) {
1740 const SCEV *Start = AR->getStart();
1741 const SCEV *Step = AR->getStepRecurrence(*this);
1742 unsigned BitWidth = getTypeSizeInBits(AR->getType());
1743 const Loop *L = AR->getLoop();
1744
1745 // If we have special knowledge that this addrec won't overflow,
1746 // we don't need to do any further analysis.
1747 if (AR->hasNoUnsignedWrap()) {
1748 Start =
1750 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1751 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1752 }
1753
1754 // Check whether the backedge-taken count is SCEVCouldNotCompute.
1755 // Note that this serves two purposes: It filters out loops that are
1756 // simply not analyzable, and it covers the case where this code is
1757 // being called from within backedge-taken count analysis, such that
1758 // attempting to ask for the backedge-taken count would likely result
1759 // in infinite recursion. In the later case, the analysis code will
1760 // cope with a conservative value, and it will take care to purge
1761 // that value once it has finished.
1762 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1763 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1764 // Manually compute the final value for AR, checking for overflow.
1765
1766 // Check whether the backedge-taken count can be losslessly casted to
1767 // the addrec's type. The count is always unsigned.
1768 const SCEV *CastedMaxBECount =
1769 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1770 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1771 CastedMaxBECount, MaxBECount->getType(), Depth);
1772 if (MaxBECount == RecastedMaxBECount) {
1773 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1774 // Check whether Start+Step*MaxBECount has no unsigned overflow.
1775 const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1777 const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1779 Depth + 1),
1780 WideTy, Depth + 1);
1781 const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1782 const SCEV *WideMaxBECount =
1783 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1784 const SCEV *OperandExtendedAdd =
1785 getAddExpr(WideStart,
1786 getMulExpr(WideMaxBECount,
1787 getZeroExtendExpr(Step, WideTy, Depth + 1),
1790 if (ZAdd == OperandExtendedAdd) {
1791 // Cache knowledge of AR NUW, which is propagated to this AddRec.
1792 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1793 // Return the expression with the addrec on the outside.
1794 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1795 Depth + 1);
1796 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1797 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1798 }
1799 // Similar to above, only this time treat the step value as signed.
1800 // This covers loops that count down.
1801 OperandExtendedAdd =
1802 getAddExpr(WideStart,
1803 getMulExpr(WideMaxBECount,
1804 getSignExtendExpr(Step, WideTy, Depth + 1),
1807 if (ZAdd == OperandExtendedAdd) {
1808 // Cache knowledge of AR NW, which is propagated to this AddRec.
1809 // Negative step causes unsigned wrap, but it still can't self-wrap.
1810 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1811 // Return the expression with the addrec on the outside.
1812 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1813 Depth + 1);
1814 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1815 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1816 }
1817 }
1818 }
1819
1820 // Normally, in the cases we can prove no-overflow via a
1821 // backedge guarding condition, we can also compute a backedge
1822 // taken count for the loop. The exceptions are assumptions and
1823 // guards present in the loop -- SCEV is not great at exploiting
1824 // these to compute max backedge taken counts, but can still use
1825 // these to prove lack of overflow. Use this fact to avoid
1826 // doing extra work that may not pay off.
1827 if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1828 !AC.assumptions().empty()) {
1829
1830 auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1831 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1832 if (AR->hasNoUnsignedWrap()) {
1833 // Same as nuw case above - duplicated here to avoid a compile time
1834 // issue. It's not clear that the order of checks does matter, but
1835 // it's one of two issue possible causes for a change which was
1836 // reverted. Be conservative for the moment.
1837 Start =
1839 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1840 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1841 }
1842
1843 // For a negative step, we can extend the operands iff doing so only
1844 // traverses values in the range zext([0,UINT_MAX]).
1845 if (isKnownNegative(Step)) {
1847 getSignedRangeMin(Step));
1850 // Cache knowledge of AR NW, which is propagated to this
1851 // AddRec. Negative step causes unsigned wrap, but it
1852 // still can't self-wrap.
1853 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1854 // Return the expression with the addrec on the outside.
1855 Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1856 Depth + 1);
1857 Step = getSignExtendExpr(Step, Ty, Depth + 1);
1858 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1859 }
1860 }
1861 }
1862
1863 // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1864 // if D + (C - D + Step * n) could be proven to not unsigned wrap
1865 // where D maximizes the number of trailing zeros of (C - D + Step * n)
1866 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1867 const APInt &C = SC->getAPInt();
1868 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1869 if (D != 0) {
1870 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1871 const SCEV *SResidual =
1872 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1873 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1874 return getAddExpr(SZExtD, SZExtR, SCEV::FlagNSW | SCEV::FlagNUW,
1875 Depth + 1);
1876 }
1877 }
1878
1879 if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1880 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1881 Start =
1883 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1884 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1885 }
1886 }
1887
1888 // zext(A % B) --> zext(A) % zext(B)
1889 {
1890 const SCEV *LHS;
1891 const SCEV *RHS;
1892 if (match(Op, m_scev_URem(m_SCEV(LHS), m_SCEV(RHS), *this)))
1893 return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1894 getZeroExtendExpr(RHS, Ty, Depth + 1));
1895 }
1896
1897 // zext(A / B) --> zext(A) / zext(B).
1898 if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1899 return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1900 getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1901
1902 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1903 // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1904 if (SA->hasNoUnsignedWrap()) {
1905 // If the addition does not unsign overflow then we can, by definition,
1906 // commute the zero extension with the addition operation.
1908 for (SCEVUse Op : SA->operands())
1909 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1910 return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1911 }
1912
1913 // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1914 // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1915 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1916 //
1917 // Often address arithmetics contain expressions like
1918 // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1919 // This transformation is useful while proving that such expressions are
1920 // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1921 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1922 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1923 if (D != 0) {
1924 const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1925 const SCEV *SResidual =
1927 const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1928 return getAddExpr(SZExtD, SZExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
1929 Depth + 1);
1930 }
1931 }
1932 }
1933
1934 if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1935 // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1936 if (SM->hasNoUnsignedWrap()) {
1937 // If the multiply does not unsign overflow then we can, by definition,
1938 // commute the zero extension with the multiply operation.
1940 for (SCEVUse Op : SM->operands())
1941 Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1942 return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1943 }
1944
1945 // zext(2^K * (trunc X to iN)) to iM ->
1946 // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1947 //
1948 // Proof:
1949 //
1950 // zext(2^K * (trunc X to iN)) to iM
1951 // = zext((trunc X to iN) << K) to iM
1952 // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1953 // (because shl removes the top K bits)
1954 // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1955 // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1956 //
1957 const APInt *C;
1958 const SCEV *TruncRHS;
1959 if (match(SM,
1960 m_scev_Mul(m_scev_APInt(C), m_scev_Trunc(m_SCEV(TruncRHS)))) &&
1961 C->isPowerOf2()) {
1962 int NewTruncBits =
1963 getTypeSizeInBits(SM->getOperand(1)->getType()) - C->logBase2();
1964 Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1965 return getMulExpr(
1966 getZeroExtendExpr(SM->getOperand(0), Ty),
1967 getZeroExtendExpr(getTruncateExpr(TruncRHS, NewTruncTy), Ty),
1968 SCEV::FlagNUW, Depth + 1);
1969 }
1970 }
1971
1972 // zext(umin(x, y)) -> umin(zext(x), zext(y))
1973 // zext(umax(x, y)) -> umax(zext(x), zext(y))
1976 SmallVector<SCEVUse, 4> Operands;
1977 for (SCEVUse Operand : MinMax->operands())
1978 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1980 return getUMinExpr(Operands);
1981 return getUMaxExpr(Operands);
1982 }
1983
1984 // zext(umin_seq(x, y)) -> umin_seq(zext(x), zext(y))
1986 assert(isa<SCEVSequentialUMinExpr>(MinMax) && "Not supported!");
1987 SmallVector<SCEVUse, 4> Operands;
1988 for (SCEVUse Operand : MinMax->operands())
1989 Operands.push_back(getZeroExtendExpr(Operand, Ty));
1990 return getUMinExpr(Operands, /*Sequential*/ true);
1991 }
1992
1993 // The cast wasn't folded; create an explicit cast node.
1994 // Recompute the insert position, as it may have been invalidated.
1995 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1996 SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1997 Op, Ty);
1998 UniqueSCEVs.InsertNode(S, IP);
1999 S->computeAndSetCanonical(*this);
2000 registerUser(S, Op);
2001 return S;
2002}
2003
2004const SCEV *
// Public entry point for sext(Op) to Ty. Validates the cast, consults the
// sign-extend FoldCache, and on a miss delegates to getSignExtendExprImpl,
// recording the result via insertFoldCacheEntry.
// NOTE(review): this rendered listing elides some original lines (the
// 2005-numbered signature continuation and the 2018-numbered condition), so
// the signature and the guard around the cache insertion are incomplete here.
2006 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2007 "This is not an extending conversion!");
2008 assert(isSCEVable(Ty) &&
2009 "This is not a conversion to a SCEVable type!");
2010 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2011 Ty = getEffectiveSCEVType(Ty);
2012
// Fast path: a previously computed sext of (Op, Ty) is returned directly.
2013 FoldID ID(scSignExtend, Op, Ty);
2014 if (const SCEV *S = FoldCache.lookup(ID))
2015 return S;
2016
2017 const SCEV *S = getSignExtendExprImpl(Op, Ty, Depth);
2019 insertFoldCacheEntry(ID, S, FoldCache, FoldCacheUser);
2020 return S;
2021}
2022
// Worker for getSignExtendExpr: applies the sext folding rules (constant
// fold, sext(sext)/sext(zext) collapse, trunc simplification, distribution
// over nsw adds, addrec no-wrap reasoning, smin/smax commuting) and only
// creates an explicit SCEVSignExtendExpr node when nothing folds.
// NOTE(review): the listing elides several original lines (e.g. the
// 2023-numbered signature start and various dyn_cast/operand lines), so some
// statements below appear truncated; verify against the full source.
2024 unsigned Depth) {
2025 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2026 "This is not an extending conversion!");
2027 assert(isSCEVable(Ty) && "This is not a conversion to a SCEVable type!");
2028 assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
2029 Ty = getEffectiveSCEVType(Ty);
2030
2031 // Fold if the operand is constant.
2032 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2033 return getConstant(SC->getAPInt().sext(getTypeSizeInBits(Ty)));
2034
2035 // sext(sext(x)) --> sext(x)
2037 return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
2038
2039 // sext(zext(x)) --> zext(x)
2041 return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
2042
2043 // Before doing any expensive analysis, check to see if we've already
2044 // computed a SCEV for this Op and Ty.
2046 ID.AddInteger(scSignExtend);
2047 ID.AddPointer(Op);
2048 ID.AddPointer(Ty);
2049 void *IP = nullptr;
2050 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2051 // Limit recursion depth.
2052 if (Depth > MaxCastDepth) {
// Too deep: give up on folding and intern a plain sext node.
2053 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2054 Op, Ty);
2055 UniqueSCEVs.InsertNode(S, IP);
2056 S->computeAndSetCanonical(*this);
2057 registerUser(S, Op);
2058 return S;
2059 }
2060
2061 // sext(trunc(x)) --> sext(x) or x or trunc(x)
2063 // It's possible the bits taken off by the truncate were all sign bits. If
2064 // so, we should be able to simplify this further.
2065 const SCEV *X = ST->getOperand();
2067 unsigned TruncBits = getTypeSizeInBits(ST->getType());
2068 unsigned NewBits = getTypeSizeInBits(Ty);
2069 if (CR.truncate(TruncBits).signExtend(NewBits).contains(
2070 CR.sextOrTrunc(NewBits)))
2071 return getTruncateOrSignExtend(X, Ty, Depth);
2072 }
2073
2074 if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
2075 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
2076 if (SA->hasNoSignedWrap()) {
2077 // If the addition does not sign overflow then we can, by definition,
2078 // commute the sign extension with the addition operation.
2080 for (SCEVUse Op : SA->operands())
2081 Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
2082 return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
2083 }
2084
2085 // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
2086 // if D + (C - D + x + y + ...) could be proven to not signed wrap
2087 // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
2088 //
2089 // For instance, this will bring two seemingly different expressions:
2090 // 1 + sext(5 + 20 * %x + 24 * %y) and
2091 // sext(6 + 20 * %x + 24 * %y)
2092 // to the same form:
2093 // 2 + sext(4 + 20 * %x + 24 * %y)
2094 if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
2095 const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
2096 if (D != 0) {
2097 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2098 const SCEV *SResidual =
2100 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2101 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2102 Depth + 1);
2103 }
2104 }
2105 }
2106 // If the input value is a chrec scev, and we can prove that the value
2107 // did not overflow the old, smaller, value, we can sign extend all of the
2108 // operands (often constants). This allows analysis of something like
2109 // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
2111 if (AR->isAffine()) {
2112 const SCEV *Start = AR->getStart();
2113 const SCEV *Step = AR->getStepRecurrence(*this);
2114 unsigned BitWidth = getTypeSizeInBits(AR->getType());
2115 const Loop *L = AR->getLoop();
2116
2117 // If we have special knowledge that this addrec won't overflow,
2118 // we don't need to do any further analysis.
2119 if (AR->hasNoSignedWrap()) {
2120 Start =
2122 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2123 return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2124 }
2125
2126 // Check whether the backedge-taken count is SCEVCouldNotCompute.
2127 // Note that this serves two purposes: It filters out loops that are
2128 // simply not analyzable, and it covers the case where this code is
2129 // being called from within backedge-taken count analysis, such that
2130 // attempting to ask for the backedge-taken count would likely result
2131 // in infinite recursion. In the later case, the analysis code will
2132 // cope with a conservative value, and it will take care to purge
2133 // that value once it has finished.
2134 const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2135 if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2136 // Manually compute the final value for AR, checking for
2137 // overflow.
2138
2139 // Check whether the backedge-taken count can be losslessly casted to
2140 // the addrec's type. The count is always unsigned.
2141 const SCEV *CastedMaxBECount =
2142 getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2143 const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2144 CastedMaxBECount, MaxBECount->getType(), Depth);
2145 if (MaxBECount == RecastedMaxBECount) {
2146 Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2147 // Check whether Start+Step*MaxBECount has no signed overflow.
2148 const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2150 const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2152 Depth + 1),
2153 WideTy, Depth + 1);
2154 const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2155 const SCEV *WideMaxBECount =
2156 getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2157 const SCEV *OperandExtendedAdd =
2158 getAddExpr(WideStart,
2159 getMulExpr(WideMaxBECount,
2160 getSignExtendExpr(Step, WideTy, Depth + 1),
2163 if (SAdd == OperandExtendedAdd) {
2164 // Cache knowledge of AR NSW, which is propagated to this AddRec.
2165 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2166 // Return the expression with the addrec on the outside.
2167 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2168 Depth + 1);
2169 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2170 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2171 }
2172 // Similar to above, only this time treat the step value as unsigned.
2173 // This covers loops that count up with an unsigned step.
2174 OperandExtendedAdd =
2175 getAddExpr(WideStart,
2176 getMulExpr(WideMaxBECount,
2177 getZeroExtendExpr(Step, WideTy, Depth + 1),
2180 if (SAdd == OperandExtendedAdd) {
2181 // If AR wraps around then
2182 //
2183 // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2184 // => SAdd != OperandExtendedAdd
2185 //
2186 // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2187 // (SAdd == OperandExtendedAdd => AR is NW)
2188
2189 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2190
2191 // Return the expression with the addrec on the outside.
2192 Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2193 Depth + 1);
2194 Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2195 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2196 }
2197 }
2198 }
2199
2200 auto NewFlags = proveNoSignedWrapViaInduction(AR);
2201 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2202 if (AR->hasNoSignedWrap()) {
2203 // Same as nsw case above - duplicated here to avoid a compile time
2204 // issue. It's not clear that the order of checks does matter, but
2205 // it's one of two issue possible causes for a change which was
2206 // reverted. Be conservative for the moment.
2207 Start =
2209 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2210 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2211 }
2212
2213 // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2214 // if D + (C - D + Step * n) could be proven to not signed wrap
2215 // where D maximizes the number of trailing zeros of (C - D + Step * n)
2216 if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2217 const APInt &C = SC->getAPInt();
2218 const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2219 if (D != 0) {
2220 const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2221 const SCEV *SResidual =
2222 getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2223 const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2224 return getAddExpr(SSExtD, SSExtR, (SCEV::FlagNSW | SCEV::FlagNUW),
2225 Depth + 1);
2226 }
2227 }
2228
2229 if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2230 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2231 Start =
2233 Step = getSignExtendExpr(Step, Ty, Depth + 1);
2234 return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2235 }
2236 }
2237
2238 // If the input value is provably positive and we could not simplify
2239 // away the sext build a zext instead.
2241 return getZeroExtendExpr(Op, Ty, Depth + 1);
2242
2243 // sext(smin(x, y)) -> smin(sext(x), sext(y))
2244 // sext(smax(x, y)) -> smax(sext(x), sext(y))
2247 SmallVector<SCEVUse, 4> Operands;
2248 for (SCEVUse Operand : MinMax->operands())
2249 Operands.push_back(getSignExtendExpr(Operand, Ty));
2251 return getSMinExpr(Operands);
2252 return getSMaxExpr(Operands);
2253 }
2254
2255 // The cast wasn't folded; create an explicit cast node.
2256 // Recompute the insert position, as it may have been invalidated.
2257 if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2258 SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2259 Op, Ty);
2260 UniqueSCEVs.InsertNode(S, IP);
2261 S->computeAndSetCanonical(*this);
2262 registerUser(S, Op);
2263 return S;
2264}
2265
// Dispatch a cast-kind tag to the matching cast builder (truncate, zext,
// sext, ptrtoint). Any other SCEVTypes value is a programming error and hits
// llvm_unreachable.
// NOTE(review): the 2266-numbered signature line is elided in this listing.
2267 Type *Ty) {
2268 switch (Kind) {
2269 case scTruncate:
2270 return getTruncateExpr(Op, Ty);
2271 case scZeroExtend:
2272 return getZeroExtendExpr(Op, Ty);
2273 case scSignExtend:
2274 return getSignExtendExpr(Op, Ty);
2275 case scPtrToInt:
2276 return getPtrToIntExpr(Op, Ty);
2277 default:
2278 llvm_unreachable("Not a SCEV cast expression!");
2279 }
2280}
2281
2282/// getAnyExtendExpr - Return a SCEV for the given operand extended with
2283/// unspecified bits out to the given type.
// Body of getAnyExtendExpr (the 2284-numbered signature line is elided in
// this listing). Strategy: sign-extend negative constants, peel truncates,
// prefer whichever of zext/sext actually folds, push the extension into
// addrec operands, and otherwise default to the zext form.
2285 Type *Ty) {
2286 assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2287 "This is not an extending conversion!");
2288 assert(isSCEVable(Ty) &&
2289 "This is not a conversion to a SCEVable type!");
2290 Ty = getEffectiveSCEVType(Ty);
2291
2292 // Sign-extend negative constants.
2293 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2294 if (SC->getAPInt().isNegative())
2295 return getSignExtendExpr(Op, Ty);
2296
2297 // Peel off a truncate cast.
2299 const SCEV *NewOp = T->getOperand();
2300 if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2301 return getAnyExtendExpr(NewOp, Ty);
2302 return getTruncateOrNoop(NewOp, Ty);
2303 }
2304
2305 // Next try a zext cast. If the cast is folded, use it.
2306 const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2307 if (!isa<SCEVZeroExtendExpr>(ZExt))
2308 return ZExt;
2309
2310 // Next try a sext cast. If the cast is folded, use it.
2311 const SCEV *SExt = getSignExtendExpr(Op, Ty);
2312 if (!isa<SCEVSignExtendExpr>(SExt))
2313 return SExt;
2314
2315 // Force the cast to be folded into the operands of an addrec.
2316 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2318 for (const SCEV *Op : AR->operands())
2319 Ops.push_back(getAnyExtendExpr(Op, Ty));
2320 return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2321 }
2322
2323 // If the expression is obviously signed, use the sext cast value.
2324 if (isa<SCEVSMaxExpr>(Op))
2325 return SExt;
2326
2327 // Absent any other information, use the zext cast value.
2328 return ZExt;
2329}
2330
2331/// Process the given Ops list, which is a list of operands to be added under
2332/// the given scale, update the given map. This is a helper function for
2333/// getAddRecExpr. As an example of what it does, given a sequence of operands
2334/// that would form an add expression like this:
2335///
2336/// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2337///
2338/// where A and B are constants, update the map with these values:
2339///
2340/// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2341///
2342/// and add 13 + A*B*29 to AccumulatedConstant.
2343/// This will allow getAddRecExpr to produce this:
2344///
2345/// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2346///
2347/// This form often exposes folding opportunities that are hidden in
2348/// the original operand list.
2349///
2350/// Return true iff it appears that any interesting folding opportunities
2351/// may be exposed. This helps getAddRecExpr short-circuit extra work in
2352/// the common case where no interesting opportunities are present, and
2353/// is also used as a check to avoid infinite recursion.
// Body of CollectAddOperandsWithScales (see the doc comment above; the
// 2354/2355-numbered signature lines with the map and NewOps parameters are
// elided in this listing). Folds constants into AccumulatedConstant,
// recurses into constant*add products, and accumulates per-operand scales
// in the map M; returns true when a folding opportunity was detected.
2356 APInt &AccumulatedConstant,
2358 const APInt &Scale,
2359 ScalarEvolution &SE) {
2360 bool Interesting = false;
2361
2362 // Iterate over the add operands. They are sorted, with constants first.
2363 unsigned i = 0;
2364 while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2365 ++i;
2366 // Pull a buried constant out to the outside.
2367 if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2368 Interesting = true;
2369 AccumulatedConstant += Scale * C->getAPInt();
2370 }
2371
2372 // Next comes everything else. We're especially interested in multiplies
2373 // here, but they're in the middle, so just visit the rest with one loop.
2374 for (; i != Ops.size(); ++i) {
2376 if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2377 APInt NewScale =
2378 Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2379 if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2380 // A multiplication of a constant with another add; recurse.
2381 const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2382 Interesting |= CollectAddOperandsWithScales(
2383 M, NewOps, AccumulatedConstant, Add->operands(), NewScale, SE);
2384 } else {
2385 // A multiplication of a constant with some other value. Update
2386 // the map.
2387 SmallVector<SCEVUse, 4> MulOps(drop_begin(Mul->operands()));
2388 const SCEV *Key = SE.getMulExpr(MulOps);
2389 auto Pair = M.insert({Key, NewScale});
2390 if (Pair.second) {
2391 NewOps.push_back(Pair.first->first);
2392 } else {
2393 Pair.first->second += NewScale;
2394 // The map already had an entry for this value, which may indicate
2395 // a folding opportunity.
2396 Interesting = true;
2397 }
2398 }
2399 } else {
2400 // An ordinary operand. Update the map.
2401 auto Pair = M.insert({Ops[i], Scale});
2402 if (Pair.second) {
2403 NewOps.push_back(Pair.first->first);
2404 } else {
2405 Pair.first->second += Scale;
2406 // The map already had an entry for this value, which may indicate
2407 // a folding opportunity.
2408 Interesting = true;
2409 }
2410 }
2411 }
2412
2413 return Interesting;
2414}
2415
// Decide whether BinOp(LHS, RHS) provably does not overflow. First compares
// ext(LHS op RHS) against ext(LHS) op ext(RHS) in a doubled bit width (equal
// SCEVs imply no wrap); failing that, and given a context instruction plus a
// constant RHS, falls back to a known-predicate range check against the
// type's min/max. NOTE(review): the listing elides the 2416-numbered
// signature line and the member-pointer assignments for Operation/Extension
// (lines 2419, 2425, 2428, 2431, 2436-2437), so those statements appear
// truncated below.
2417 const SCEV *LHS, const SCEV *RHS,
2418 const Instruction *CtxI) {
2420 unsigned);
2421 switch (BinOp) {
2422 default:
2423 llvm_unreachable("Unsupported binary op");
2424 case Instruction::Add:
2426 break;
2427 case Instruction::Sub:
2429 break;
2430 case Instruction::Mul:
2432 break;
2433 }
2434
2435 const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2438
2439 // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2440 auto *NarrowTy = cast<IntegerType>(LHS->getType());
2441 auto *WideTy =
2442 IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2443
2444 const SCEV *A = (this->*Extension)(
2445 (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2446 const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2447 const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2448 const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2449 if (A == B)
2450 return true;
2451 // Can we use context to prove the fact we need?
2452 if (!CtxI)
2453 return false;
2454 // TODO: Support mul.
2455 if (BinOp == Instruction::Mul)
2456 return false;
2457 auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2458 // TODO: Lift this limitation.
2459 if (!RHSC)
2460 return false;
2461 APInt C = RHSC->getAPInt();
2462 unsigned NumBits = C.getBitWidth();
2463 bool IsSub = (BinOp == Instruction::Sub);
2464 bool IsNegativeConst = (Signed && C.isNegative());
2465 // Compute the direction and magnitude by which we need to check overflow.
2466 bool OverflowDown = IsSub ^ IsNegativeConst;
2467 APInt Magnitude = C;
2468 if (IsNegativeConst) {
2469 if (C == APInt::getSignedMinValue(NumBits))
2470 // TODO: SINT_MIN on inversion gives the same negative value, we don't
2471 // want to deal with that.
2472 return false;
2473 Magnitude = -C;
2474 }
2475
2477 if (OverflowDown) {
2478 // To avoid overflow down, we need to make sure that MIN + Magnitude <= LHS.
2479 APInt Min = Signed ? APInt::getSignedMinValue(NumBits)
2480 : APInt::getMinValue(NumBits);
2481 APInt Limit = Min + Magnitude;
2482 return isKnownPredicateAt(Pred, getConstant(Limit), LHS, CtxI);
2483 } else {
2484 // To avoid overflow up, we need to make sure that LHS <= MAX - Magnitude.
2485 APInt Max = Signed ? APInt::getSignedMaxValue(NumBits)
2486 : APInt::getMaxValue(NumBits);
2487 APInt Limit = Max - Magnitude;
2488 return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2489 }
2490}
2491
2492std::optional<SCEV::NoWrapFlags>
// Try to strengthen the nuw/nsw flags of an add/sub/mul
// OverflowingBinaryOperator using SCEV overflow reasoning on its operands.
// Returns std::nullopt when both flags are already present, the opcode is
// unsupported, or nothing new could be deduced; otherwise returns the
// (possibly strengthened) flag set.
// NOTE(review): several flag-setting lines (2502, 2504, 2517, 2519-2521,
// 2526-2528) are elided in this listing, so some conditions below appear
// truncated.
2494 const OverflowingBinaryOperator *OBO) {
2495 // It cannot be done any better.
2496 if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2497 return std::nullopt;
2498
2499 SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2500
2501 if (OBO->hasNoUnsignedWrap())
2503 if (OBO->hasNoSignedWrap())
2505
2506 bool Deduced = false;
2507
2508 if (OBO->getOpcode() != Instruction::Add &&
2509 OBO->getOpcode() != Instruction::Sub &&
2510 OBO->getOpcode() != Instruction::Mul)
2511 return std::nullopt;
2512
2513 const SCEV *LHS = getSCEV(OBO->getOperand(0));
2514 const SCEV *RHS = getSCEV(OBO->getOperand(1));
2515
2516 const Instruction *CtxI =
2518 if (!OBO->hasNoUnsignedWrap() &&
2520 /* Signed */ false, LHS, RHS, CtxI)) {
2522 Deduced = true;
2523 }
2524
2525 if (!OBO->hasNoSignedWrap() &&
2527 /* Signed */ true, LHS, RHS, CtxI)) {
2529 Deduced = true;
2530 }
2531
2532 if (Deduced)
2533 return Flags;
2534 return std::nullopt;
2535}
2536
2537// We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2538// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2539// can't-overflow flags for the operation if possible.
// Body of StrengthenNoWrapFlags (see the comment above; the 2540-2542
// numbered signature lines are elided in this listing). Infers stronger
// nuw/nsw flags for an add/mul being built from Ops: nsw + all-nonnegative
// operands implies nuw; constant-region checks handle (C <op> A); plus two
// special cases for <0,+,nonnegative> addrecs and (udiv X, Y) * Y products.
2543 SCEV::NoWrapFlags Flags) {
2544 using namespace std::placeholders;
2545
2546 using OBO = OverflowingBinaryOperator;
2547
2548 bool CanAnalyze =
2550 (void)CanAnalyze;
2551 assert(CanAnalyze && "don't call from other places!");
2552
2553 SCEV::NoWrapFlags SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2554 SCEV::NoWrapFlags SignOrUnsignWrap =
2555 ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2556
2557 // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2558 auto IsKnownNonNegative = [&](SCEVUse U) {
2559 return SE->isKnownNonNegative(U);
2560 };
2561
2562 if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2563 Flags = ScalarEvolution::setFlags(Flags, SignOrUnsignMask);
2564
2565 SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2566
2567 if (SignOrUnsignWrap != SignOrUnsignMask &&
2568 (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2569 isa<SCEVConstant>(Ops[0])) {
2570
2571 auto Opcode = [&] {
2572 switch (Type) {
2573 case scAddExpr:
2574 return Instruction::Add;
2575 case scMulExpr:
2576 return Instruction::Mul;
2577 default:
2578 llvm_unreachable("Unexpected SCEV op.");
2579 }
2580 }();
2581
2582 const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2583
2584 // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2585 if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2587 Opcode, C, OBO::NoSignedWrap);
2588 if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2590 }
2591
2592 // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2593 if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2595 Opcode, C, OBO::NoUnsignedWrap);
2596 if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2598 }
2599 }
2600
2601 // <0,+,nonnegative><nw> is also nuw
2602 // TODO: Add corresponding nsw case
2604 !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2605 Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2607
2608 // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2610 Ops.size() == 2) {
2611 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2612 if (UDiv->getOperand(1) == Ops[1])
2614 if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2615 if (UDiv->getOperand(1) == Ops[0])
2617 }
2618
2619 return Flags;
2620}
2621
// Body of isAvailableAtLoopEntry (the 2622-numbered signature line is elided
// in this listing): S is usable at the loop's entry iff it is invariant in
// the loop and properly dominates the loop header.
2623 return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2624}
2625
2626/// Get a canonical add expression, or something simpler if possible.
2628 SCEV::NoWrapFlags OrigFlags,
2629 unsigned Depth) {
2630 assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2631 "only nuw or nsw allowed");
2632 assert(!Ops.empty() && "Cannot get empty add!");
2633 if (Ops.size() == 1) return Ops[0];
2634#ifndef NDEBUG
2635 Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2636 for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2637 assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2638 "SCEVAddExpr operand types don't match!");
2639 unsigned NumPtrs = count_if(
2640 Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2641 assert(NumPtrs <= 1 && "add has at most one pointer operand");
2642#endif
2643
2644 const SCEV *Folded = constantFoldAndGroupOps(
2645 *this, LI, DT, Ops,
2646 [](const APInt &C1, const APInt &C2) { return C1 + C2; },
2647 [](const APInt &C) { return C.isZero(); }, // identity
2648 [](const APInt &C) { return false; }); // absorber
2649 if (Folded)
2650 return Folded;
2651
2652 unsigned Idx = isa<SCEVConstant>(Ops[0]) ? 1 : 0;
2653
2654 // Delay expensive flag strengthening until necessary.
2655 auto ComputeFlags = [this, OrigFlags](ArrayRef<SCEVUse> Ops) {
2656 return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2657 };
2658
2659 // Limit recursion calls depth.
2661 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2662
2663 if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2664 // Don't strengthen flags if we have no new information.
2665 SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2666 if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2667 Add->setNoWrapFlags(ComputeFlags(Ops));
2668 return S;
2669 }
2670
2671 // Okay, check to see if the same value occurs in the operand list more than
2672 // once. If so, merge them together into an multiply expression. Since we
2673 // sorted the list, these values are required to be adjacent.
2674 Type *Ty = Ops[0]->getType();
2675 bool FoundMatch = false;
2676 for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2677 if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2678 // Scan ahead to count how many equal operands there are.
2679 unsigned Count = 2;
2680 while (i+Count != e && Ops[i+Count] == Ops[i])
2681 ++Count;
2682 // Merge the values into a multiply.
2683 SCEVUse Scale = getConstant(Ty, Count);
2684 const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2685 if (Ops.size() == Count)
2686 return Mul;
2687 Ops[i] = Mul;
2688 Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2689 --i; e -= Count - 1;
2690 FoundMatch = true;
2691 }
2692 if (FoundMatch)
2693 return getAddExpr(Ops, OrigFlags, Depth + 1);
2694
2695 // Check for truncates. If all the operands are truncated from the same
2696 // type, see if factoring out the truncate would permit the result to be
2697 // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
2698 // if the contents of the resulting outer trunc fold to something simple.
2699 auto FindTruncSrcType = [&]() -> Type * {
2700 // We're ultimately looking to fold an addrec of truncs and muls of only
2701 // constants and truncs, so if we find any other types of SCEV
2702 // as operands of the addrec then we bail and return nullptr here.
2703 // Otherwise, we return the type of the operand of a trunc that we find.
2704 if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2705 return T->getOperand()->getType();
2706 if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2707 SCEVUse LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2708 if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2709 return T->getOperand()->getType();
2710 }
2711 return nullptr;
2712 };
2713 if (auto *SrcType = FindTruncSrcType()) {
2714 SmallVector<SCEVUse, 8> LargeOps;
2715 bool Ok = true;
2716 // Check all the operands to see if they can be represented in the
2717 // source type of the truncate.
2718 for (const SCEV *Op : Ops) {
2720 if (T->getOperand()->getType() != SrcType) {
2721 Ok = false;
2722 break;
2723 }
2724 LargeOps.push_back(T->getOperand());
2725 } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Op)) {
2726 LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2727 } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Op)) {
2728 SmallVector<SCEVUse, 8> LargeMulOps;
2729 for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2730 if (const SCEVTruncateExpr *T =
2731 dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2732 if (T->getOperand()->getType() != SrcType) {
2733 Ok = false;
2734 break;
2735 }
2736 LargeMulOps.push_back(T->getOperand());
2737 } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2738 LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2739 } else {
2740 Ok = false;
2741 break;
2742 }
2743 }
2744 if (Ok)
2745 LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2746 } else {
2747 Ok = false;
2748 break;
2749 }
2750 }
2751 if (Ok) {
2752 // Evaluate the expression in the larger type.
2753 const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2754 // If it folds to something simple, use it. Otherwise, don't.
2755 if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2756 return getTruncateExpr(Fold, Ty);
2757 }
2758 }
2759
2760 if (Ops.size() == 2) {
2761 // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2762 // C2 can be folded in a way that allows retaining wrapping flags of (X +
2763 // C1).
2764 const SCEV *A = Ops[0];
2765 const SCEV *B = Ops[1];
2766 auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2767 auto *C = dyn_cast<SCEVConstant>(A);
2768 if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2769 auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2770 auto C2 = C->getAPInt();
2771 SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2772
2773 APInt ConstAdd = C1 + C2;
2774 auto AddFlags = AddExpr->getNoWrapFlags();
2775 // Adding a smaller constant is NUW if the original AddExpr was NUW.
2777 ConstAdd.ule(C1)) {
2778 PreservedFlags =
2780 }
2781
2782 // Adding a constant with the same sign and small magnitude is NSW, if the
2783 // original AddExpr was NSW.
2785 C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2786 ConstAdd.abs().ule(C1.abs())) {
2787 PreservedFlags =
2789 }
2790
2791 if (PreservedFlags != SCEV::FlagAnyWrap) {
2792 SmallVector<SCEVUse, 4> NewOps(AddExpr->operands());
2793 NewOps[0] = getConstant(ConstAdd);
2794 return getAddExpr(NewOps, PreservedFlags);
2795 }
2796 }
2797
2798 // Try to push the constant operand into a ZExt: A + zext (-A + B) -> zext
2799 // (B), if trunc (A) + -A + B does not unsigned-wrap.
2800 const SCEVAddExpr *InnerAdd;
2801 if (match(B, m_scev_ZExt(m_scev_Add(InnerAdd)))) {
2802 const SCEV *NarrowA = getTruncateExpr(A, InnerAdd->getType());
2803 if (NarrowA == getNegativeSCEV(InnerAdd->getOperand(0)) &&
2804 getZeroExtendExpr(NarrowA, B->getType()) == A &&
2805 hasFlags(StrengthenNoWrapFlags(this, scAddExpr, {NarrowA, InnerAdd},
2807 SCEV::FlagNUW)) {
2808 return getZeroExtendExpr(getAddExpr(NarrowA, InnerAdd), B->getType());
2809 }
2810 }
2811 }
2812
2813 // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
2814 const SCEV *Y;
2815 if (Ops.size() == 2 &&
2816 match(Ops[0],
2818 m_scev_URem(m_scev_Specific(Ops[1]), m_SCEV(Y), *this))))
2819 return getMulExpr(Y, getUDivExpr(Ops[1], Y));
2820
2821 // Skip past any other cast SCEVs.
2822 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2823 ++Idx;
2824
2825 // If there are add operands they would be next.
2826 if (Idx < Ops.size()) {
2827 bool DeletedAdd = false;
2828 // If the original flags and all inlined SCEVAddExprs are NUW, use the
2829 // common NUW flag for expression after inlining. Other flags cannot be
2830 // preserved, because they may depend on the original order of operations.
2831 SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2832 while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2833 if (Ops.size() > AddOpsInlineThreshold ||
2834 Add->getNumOperands() > AddOpsInlineThreshold)
2835 break;
2836 // If we have an add, expand the add operands onto the end of the operands
2837 // list.
2838 Ops.erase(Ops.begin()+Idx);
2839 append_range(Ops, Add->operands());
2840 DeletedAdd = true;
2841 CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2842 }
2843
2844 // If we deleted at least one add, we added operands to the end of the list,
2845 // and they are not necessarily sorted. Recurse to resort and resimplify
2846 // any operands we just acquired.
2847 if (DeletedAdd)
2848 return getAddExpr(Ops, CommonFlags, Depth + 1);
2849 }
2850
2851 // Skip over the add expression until we get to a multiply.
2852 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2853 ++Idx;
2854
2855 // Check to see if there are any folding opportunities present with
2856 // operands multiplied by constant values.
2857 if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2861 APInt AccumulatedConstant(BitWidth, 0);
2862 if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2863 Ops, APInt(BitWidth, 1), *this)) {
2864 struct APIntCompare {
2865 bool operator()(const APInt &LHS, const APInt &RHS) const {
2866 return LHS.ult(RHS);
2867 }
2868 };
2869
2870 // Some interesting folding opportunity is present, so its worthwhile to
2871 // re-generate the operands list. Group the operands by constant scale,
2872 // to avoid multiplying by the same constant scale multiple times.
2873 std::map<APInt, SmallVector<SCEVUse, 4>, APIntCompare> MulOpLists;
2874 for (const SCEV *NewOp : NewOps)
2875 MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2876 // Re-generate the operands list.
2877 Ops.clear();
2878 if (AccumulatedConstant != 0)
2879 Ops.push_back(getConstant(AccumulatedConstant));
2880 for (auto &MulOp : MulOpLists) {
2881 if (MulOp.first == 1) {
2882 Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2883 } else if (MulOp.first != 0) {
2884 Ops.push_back(getMulExpr(
2885 getConstant(MulOp.first),
2886 getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2887 SCEV::FlagAnyWrap, Depth + 1));
2888 }
2889 }
2890 if (Ops.empty())
2891 return getZero(Ty);
2892 if (Ops.size() == 1)
2893 return Ops[0];
2894 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2895 }
2896 }
2897
2898 // If we are adding something to a multiply expression, make sure the
2899 // something is not already an operand of the multiply. If so, merge it into
2900 // the multiply.
2901 for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2902 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2903 for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2904 const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2905 if (isa<SCEVConstant>(MulOpSCEV))
2906 continue;
2907 for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2908 if (MulOpSCEV == Ops[AddOp]) {
2909 // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2910 const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2911 if (Mul->getNumOperands() != 2) {
2912 // If the multiply has more than two operands, we must get the
2913 // Y*Z term.
2914 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2915 append_range(MulOps, Mul->operands().drop_front(MulOp + 1));
2916 InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2917 }
2918 const SCEV *AddOne =
2919 getAddExpr(getOne(Ty), InnerMul, SCEV::FlagAnyWrap, Depth + 1);
2920 const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2922 if (Ops.size() == 2) return OuterMul;
2923 if (AddOp < Idx) {
2924 Ops.erase(Ops.begin()+AddOp);
2925 Ops.erase(Ops.begin()+Idx-1);
2926 } else {
2927 Ops.erase(Ops.begin()+Idx);
2928 Ops.erase(Ops.begin()+AddOp-1);
2929 }
2930 Ops.push_back(OuterMul);
2931 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2932 }
2933
2934 // Check this multiply against other multiplies being added together.
2935 for (unsigned OtherMulIdx = Idx+1;
2936 OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2937 ++OtherMulIdx) {
2938 const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2939 // If MulOp occurs in OtherMul, we can fold the two multiplies
2940 // together.
2941 for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2942 OMulOp != e; ++OMulOp)
2943 if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2944 // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2945 const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2946 if (Mul->getNumOperands() != 2) {
2947 SmallVector<SCEVUse, 4> MulOps(Mul->operands().take_front(MulOp));
2948 append_range(MulOps, Mul->operands().drop_front(MulOp+1));
2949 InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2950 }
2951 const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2952 if (OtherMul->getNumOperands() != 2) {
2954 OtherMul->operands().take_front(OMulOp));
2955 append_range(MulOps, OtherMul->operands().drop_front(OMulOp+1));
2956 InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2957 }
2958 const SCEV *InnerMulSum =
2959 getAddExpr(InnerMul1, InnerMul2, SCEV::FlagAnyWrap, Depth + 1);
2960 const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2962 if (Ops.size() == 2) return OuterMul;
2963 Ops.erase(Ops.begin()+Idx);
2964 Ops.erase(Ops.begin()+OtherMulIdx-1);
2965 Ops.push_back(OuterMul);
2966 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2967 }
2968 }
2969 }
2970 }
2971
2972 // If there are any add recurrences in the operands list, see if any other
2973 // added values are loop invariant. If so, we can fold them into the
2974 // recurrence.
2975 while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2976 ++Idx;
2977
2978 // Scan over all recurrences, trying to fold loop invariants into them.
2979 for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2980 // Scan all of the other operands to this add and add them to the vector if
2981 // they are loop invariant w.r.t. the recurrence.
2983 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2984 const Loop *AddRecLoop = AddRec->getLoop();
2985 for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2986 if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2987 LIOps.push_back(Ops[i]);
2988 Ops.erase(Ops.begin()+i);
2989 --i; --e;
2990 }
2991
2992 // If we found some loop invariants, fold them into the recurrence.
2993 if (!LIOps.empty()) {
2994 // Compute nowrap flags for the addition of the loop-invariant ops and
2995 // the addrec. Temporarily push it as an operand for that purpose. These
2996 // flags are valid in the scope of the addrec only.
2997 LIOps.push_back(AddRec);
2998 SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2999 LIOps.pop_back();
3000
3001 // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
3002 LIOps.push_back(AddRec->getStart());
3003
3004 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3005
3006 // It is not in general safe to propagate flags valid on an add within
3007 // the addrec scope to one outside it. We must prove that the inner
3008 // scope is guaranteed to execute if the outer one does to be able to
3009 // safely propagate. We know the program is undefined if poison is
3010 // produced on the inner scoped addrec. We also know that *for this use*
3011 // the outer scoped add can't overflow (because of the flags we just
3012 // computed for the inner scoped add) without the program being undefined.
3013 // Proving that entry to the outer scope neccesitates entry to the inner
3014 // scope, thus proves the program undefined if the flags would be violated
3015 // in the outer scope.
3016 SCEV::NoWrapFlags AddFlags = Flags;
3017 if (AddFlags != SCEV::FlagAnyWrap) {
3018 auto *DefI = getDefiningScopeBound(LIOps);
3019 auto *ReachI = &*AddRecLoop->getHeader()->begin();
3020 if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
3021 AddFlags = SCEV::FlagAnyWrap;
3022 }
3023 AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
3024
3025 // Build the new addrec. Propagate the NUW and NSW flags if both the
3026 // outer add and the inner addrec are guaranteed to have no overflow.
3027 // Always propagate NW.
3028 Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
3029 const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
3030
3031 // If all of the other operands were loop invariant, we are done.
3032 if (Ops.size() == 1) return NewRec;
3033
3034 // Otherwise, add the folded AddRec by the non-invariant parts.
3035 for (unsigned i = 0;; ++i)
3036 if (Ops[i] == AddRec) {
3037 Ops[i] = NewRec;
3038 break;
3039 }
3040 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3041 }
3042
3043 // Okay, if there weren't any loop invariants to be folded, check to see if
3044 // there are multiple AddRec's with the same loop induction variable being
3045 // added together. If so, we can fold them.
3046 for (unsigned OtherIdx = Idx+1;
3047 OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3048 ++OtherIdx) {
3049 // We expect the AddRecExpr's to be sorted in reverse dominance order,
3050 // so that the 1st found AddRecExpr is dominated by all others.
3051 assert(DT.dominates(
3052 cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
3053 AddRec->getLoop()->getHeader()) &&
3054 "AddRecExprs are not sorted in reverse dominance order?");
3055 if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
3056 // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
3057 SmallVector<SCEVUse, 4> AddRecOps(AddRec->operands());
3058 for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3059 ++OtherIdx) {
3060 const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3061 if (OtherAddRec->getLoop() == AddRecLoop) {
3062 for (unsigned i = 0, e = OtherAddRec->getNumOperands();
3063 i != e; ++i) {
3064 if (i >= AddRecOps.size()) {
3065 append_range(AddRecOps, OtherAddRec->operands().drop_front(i));
3066 break;
3067 }
3068 AddRecOps[i] =
3069 getAddExpr(AddRecOps[i], OtherAddRec->getOperand(i),
3071 }
3072 Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3073 }
3074 }
3075 // Step size has changed, so we cannot guarantee no self-wraparound.
3076 Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
3077 return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3078 }
3079 }
3080
3081 // Otherwise couldn't fold anything into this recurrence. Move onto the
3082 // next one.
3083 }
3084
3085 // Okay, it looks like we really DO need an add expr. Check to see if we
3086 // already have one, otherwise create a new one.
3087 return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
3088}
3089
/// Look up an add expression with exactly these operands in the unique-SCEV
/// folding set, creating, canonicalizing, and registering a new node if none
/// exists yet. The requested no-wrap flags are applied to the node in either
/// case, so a cached node picks up the flags computed for this request.
const SCEV *ScalarEvolution::getOrCreateAddExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // Build the folding-set profile: the expression kind followed by each
  // operand pointer. Operand order is part of the key.
  ID.AddInteger(scAddExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVAddExpr *S =
      static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Allocate operand storage from the SCEV allocator so it lives as long
    // as the node itself.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator)
        SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Record operand->user edges so invalidation can reach this node.
    registerUser(S, Ops);
  }
  // (Re)apply the requested flags, whether the node is new or cached.
  S->setNoWrapFlags(Flags);
  return S;
}
3111
/// Look up an add recurrence over loop \p L with exactly these operands in
/// the unique-SCEV folding set, creating, canonicalizing, and registering a
/// new node if none exists yet. Flags are applied in either case.
const SCEV *ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<SCEVUse> Ops,
                                                   const Loop *L,
                                                   SCEV::NoWrapFlags Flags) {
  // Profile key: kind, each operand pointer, then the loop — two addrecs
  // with identical operands over different loops must not unify.
  FoldingSetNodeID ID;
  ID.AddInteger(scAddRecExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  ID.AddPointer(L);
  void *IP = nullptr;
  SCEVAddRecExpr *S =
      static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Allocator-owned operand storage keeps the node self-contained.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator)
        SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Track this expression as a user of L in addition to the usual
    // operand->user edges.
    LoopUsers[L].push_back(S);
    registerUser(S, Ops);
  }
  setNoWrapFlags(S, Flags);
  return S;
}
3136
/// Look up a multiply expression with exactly these operands in the
/// unique-SCEV folding set, creating, canonicalizing, and registering a new
/// node if none exists yet. Flags are applied in either case.
const SCEV *ScalarEvolution::getOrCreateMulExpr(ArrayRef<SCEVUse> Ops,
                                                SCEV::NoWrapFlags Flags) {
  // Profile key: the expression kind plus each operand pointer, in order.
  FoldingSetNodeID ID;
  ID.AddInteger(scMulExpr);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  SCEVMulExpr *S =
      static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
  if (!S) {
    // Allocator-owned operand storage keeps the node self-contained.
    SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                        O, Ops.size());
    UniqueSCEVs.InsertNode(S, IP);
    S->computeAndSetCanonical(*this);
    // Record operand->user edges so invalidation can reach this node.
    registerUser(S, Ops);
  }
  // (Re)apply the requested flags, whether the node is new or cached.
  S->setNoWrapFlags(Flags);
  return S;
}
3158
/// Multiply \p i by \p j, setting \p Overflow if the 64-bit product wrapped.
/// The flag is only ever set, never cleared, so callers can accumulate it
/// across a chain of multiplications.
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
  uint64_t Product = i * j;
  // The division round-trips exactly iff no wraparound occurred. A factor
  // j <= 1 can never overflow (and j == 0 must not be divided by).
  if (j > 1)
    if (Product / j != i)
      Overflow = true;
  return Product;
}
3164
3165/// Compute the result of "n choose k", the binomial coefficient. If an
3166/// intermediate computation overflows, Overflow will be set and the return will
3167/// be garbage. Overflow is not cleared on absence of overflow.
3168static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3169 // We use the multiplicative formula:
3170 // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3171 // At each iteration, we take the n-th term of the numeral and divide by the
3172 // (k-n)th term of the denominator. This division will always produce an
3173 // integral result, and helps reduce the chance of overflow in the
3174 // intermediate computations. However, we can still overflow even when the
3175 // final result would fit.
3176
3177 if (n == 0 || n == k) return 1;
3178 if (k > n) return 0;
3179
3180 if (k > n/2)
3181 k = n-k;
3182
3183 uint64_t r = 1;
3184 for (uint64_t i = 1; i <= k; ++i) {
3185 r = umul_ov(r, n-(i-1), Overflow);
3186 r /= i;
3187 }
3188 return r;
3189}
3190
/// Determine if any of the operands in this SCEV are a constant or if
/// any of the add or multiply expressions in this SCEV contain a constant.
static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
  // Traversal visitor: records whether a SCEVConstant has been seen, and
  // only descends through add and mul nodes, so the walk stays inside the
  // outermost add/mul chain of the expression.
  struct FindConstantInAddMulChain {
    bool FoundConstant = false;

    bool follow(const SCEV *S) {
      FoundConstant |= isa<SCEVConstant>(S);
      // Continue only into adds and muls; any other node kind terminates
      // this branch of the traversal.
      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
    }

    // The traversal may stop as soon as one constant has been found.
    bool isDone() const {
      return FoundConstant;
    }
  };

  FindConstantInAddMulChain F;
  ST.visitAll(StartExpr);
  return F.FoundConstant;
}
3212
/// Get a canonical multiply expression, or something simpler if possible.
                                        SCEV::NoWrapFlags OrigFlags,
                                        unsigned Depth) {
  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
         "only nuw or nsw allowed");
  assert(!Ops.empty() && "Cannot get empty mul!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = Ops[0]->getType();
  assert(!ETy->isPointerTy());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(Ops[i]->getType() == ETy &&
           "SCEVMulExpr operand types don't match!");
#endif

  // Fold all constant operands into one and canonicalize the operand order.
  // This can already fully resolve the multiply (all-constant product,
  // identity 1, absorbing 0), in which case we are done.
  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [](const APInt &C1, const APInt &C2) { return C1 * C2; },
      [](const APInt &C) { return C.isOne(); }, // identity
      [](const APInt &C) { return C.isZero(); }); // absorber
  if (Folded)
    return Folded;

  // Delay expensive flag strengthening until necessary.
  auto ComputeFlags = [this, OrigFlags](const ArrayRef<SCEVUse> Ops) {
    return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
  };

  // Limit recursion calls depth.
    return getOrCreateMulExpr(Ops, ComputeFlags(Ops));

  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
    // Don't strengthen flags if we have no new information.
    SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
    if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
      Mul->setNoWrapFlags(ComputeFlags(Ops));
    return S;
  }

  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    if (Ops.size() == 2) {
      // C1*(C2+V) -> C1*C2 + C1*V
      // If any of Add's ops are Adds or Muls with a constant, apply this
      // transformation as well.
      //
      // TODO: There are some cases where this transformation is not
      // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
      // this transformation should be narrowed down.
      const SCEV *Op0, *Op1;
      if (match(Ops[1], m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))) &&
        const SCEV *LHS = getMulExpr(LHSC, Op0, SCEV::FlagAnyWrap, Depth + 1);
        const SCEV *RHS = getMulExpr(LHSC, Op1, SCEV::FlagAnyWrap, Depth + 1);
        return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
      }

      if (Ops[0]->isAllOnesValue()) {
        // If we have a mul by -1 of an add, try distributing the -1 among the
        // add operands.
        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
          bool AnyFolded = false;
          for (const SCEV *AddOp : Add->operands()) {
            const SCEV *Mul = getMulExpr(Ops[0], SCEVUse(AddOp),
            // Only worthwhile if at least one -1 * op actually simplified
            // (i.e. did not just become another mul).
            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
            NewOps.push_back(Mul);
          }
          if (AnyFolded)
            return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
        } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
          // Negation preserves a recurrence's no self-wrap property.
          SmallVector<SCEVUse, 4> Operands;
          for (const SCEV *AddRecOp : AddRec->operands())
            Operands.push_back(getMulExpr(Ops[0], SCEVUse(AddRecOp),
                                          SCEV::FlagAnyWrap, Depth + 1));
          // Let M be the minimum representable signed value. AddRec with nsw
          // multiplied by -1 can have signed overflow if and only if it takes a
          // value of M: M * (-1) would stay M and (M + 1) * (-1) would be the
          // maximum signed value. In all other cases signed overflow is
          // impossible.
          auto FlagsMask = SCEV::FlagNW;
          if (AddRec->hasNoSignedWrap()) {
            auto MinInt =
                APInt::getSignedMinValue(getTypeSizeInBits(AddRec->getType()));
            if (getSignedRangeMin(AddRec) != MinInt)
              FlagsMask = setFlags(FlagsMask, SCEV::FlagNSW);
          }
          return getAddRecExpr(Operands, AddRec->getLoop(),
                               AddRec->getNoWrapFlags(FlagsMask));
        }
      }

      // Try to push the constant operand into a ZExt: C * zext (A + B) ->
      // zext (C*A + C*B) if trunc (C) * (A + B) does not unsigned-wrap.
      const SCEVAddExpr *InnerAdd;
      if (match(Ops[1], m_scev_ZExt(m_scev_Add(InnerAdd)))) {
        const SCEV *NarrowC = getTruncateExpr(LHSC, InnerAdd->getType());
        if (isa<SCEVConstant>(InnerAdd->getOperand(0)) &&
            getZeroExtendExpr(NarrowC, Ops[1]->getType()) == LHSC &&
            hasFlags(StrengthenNoWrapFlags(this, scMulExpr, {NarrowC, InnerAdd},
                     SCEV::FlagNUW)) {
          auto *Res = getMulExpr(NarrowC, InnerAdd, SCEV::FlagNUW, Depth + 1);
          return getZeroExtendExpr(Res, Ops[1]->getType(), Depth + 1);
        };
      }

      // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
      // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
      // of C1, fold to (D /u (C2 /u C1)).
      const SCEV *D;
      APInt C1V = LHSC->getAPInt();
      // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN. Don't treat -1
      // as -1 * 1, as it won't enable additional folds.
      if (C1V.isNegative() && !C1V.isMinSignedValue() && !C1V.isAllOnes())
        C1V = C1V.abs();
      const SCEVConstant *C2;
      if (C1V.isPowerOf2() &&
          C2->getAPInt().isPowerOf2() &&
          C1V.logBase2() <= getMinTrailingZeros(D)) {
        const SCEV *NewMul = nullptr;
        if (C1V.uge(C2->getAPInt())) {
          NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
        } else if (C2->getAPInt().logBase2() <= getMinTrailingZeros(D)) {
          assert(C1V.ugt(1) && "C1 <= 1 should have been folded earlier");
          NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
        }
        if (NewMul)
          // If C1 was negated above, negate the folded result back.
          return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
      }
    }
  }

  // Skip over the add expression until we get to a multiply.
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // If there are mul operands inline them all into this expression.
  if (Idx < Ops.size()) {
    bool DeletedMul = false;
    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
      if (Ops.size() > MulOpsInlineThreshold)
        break;
      // If we have an mul, expand the mul operands onto the end of the
      // operands list.
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, Mul->operands());
      DeletedMul = true;
    }

    // If we deleted at least one mul, we added operands to the end of the
    // list, and they are not necessarily sorted. Recurse to resort and
    // resimplify any operands we just acquired.
    if (DeletedMul)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
  }

  // If there are any add recurrences in the operands list, see if any other
  // added values are loop invariant. If so, we can fold them into the
  // recurrence.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
    ++Idx;

  // Scan over all recurrences, trying to fold loop invariants into them.
  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
    // Scan all of the other operands to this mul and add them to the vector
    // if they are loop invariant w.r.t. the recurrence.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
      if (isAvailableAtLoopEntry(Ops[i], AddRec->getLoop())) {
        LIOps.push_back(Ops[i]);
        Ops.erase(Ops.begin()+i);
        --i; --e;
      }

    // If we found some loop invariants, fold them into the recurrence.
    if (!LIOps.empty()) {
      // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
      NewOps.reserve(AddRec->getNumOperands());
      const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);

      // If both the mul and addrec are nuw, we can preserve nuw.
      // If both the mul and addrec are nsw, we can only preserve nsw if either
      // a) they are also nuw, or
      // b) all multiplications of addrec operands with scale are nsw.
      SCEV::NoWrapFlags Flags =
          AddRec->getNoWrapFlags(ComputeFlags({Scale, AddRec}));

      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
                                    SCEV::FlagAnyWrap, Depth + 1));

        if (hasFlags(Flags, SCEV::FlagNSW) && !hasFlags(Flags, SCEV::FlagNUW)) {
            Instruction::Mul, getSignedRange(Scale),
          if (!NSWRegion.contains(getSignedRange(AddRec->getOperand(i))))
            Flags = clearFlags(Flags, SCEV::FlagNSW);
        }
      }

      const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(), Flags);

      // If all of the other operands were loop invariant, we are done.
      if (Ops.size() == 1) return NewRec;

      // Otherwise, multiply the folded AddRec by the non-invariant parts.
      for (unsigned i = 0;; ++i)
        if (Ops[i] == AddRec) {
          Ops[i] = NewRec;
          break;
        }
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
    }

    // Okay, if there weren't any loop invariants to be folded, check to see
    // if there are multiple AddRec's with the same loop induction variable
    // being multiplied together. If so, we can fold them.

    // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
    // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
    //       choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
    //   ]]],+,...up to x=2n}.
    // Note that the arguments to choose() are always integers with values
    // known at compile time, never SCEV objects.
    //
    // The implementation avoids pointless extra computations when the two
    // addrec's are of different length (mathematically, it's equivalent to
    // an infinite stream of zeros on the right).
    bool OpsModified = false;
    for (unsigned OtherIdx = Idx+1;
         OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
         ++OtherIdx) {
      const SCEVAddRecExpr *OtherAddRec =
        dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
      if (!OtherAddRec || OtherAddRec->getLoop() != AddRec->getLoop())
        continue;

      // Limit max number of arguments to avoid creation of unreasonably big
      // SCEVAddRecs with very complex operands.
      if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
          MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
        continue;

      bool Overflow = false;
      Type *Ty = AddRec->getType();
      // If any Choose()/umul_ov() intermediate overflows, the whole fold is
      // abandoned below rather than producing a wrong coefficient.
      bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
      SmallVector<SCEVUse, 7> AddRecOps;
      for (int x = 0, xe = AddRec->getNumOperands() +
             OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
        for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
          uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
          for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
                 ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
               z < ze && !Overflow; ++z) {
            uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
            uint64_t Coeff;
            if (LargerThan64Bits)
              Coeff = umul_ov(Coeff1, Coeff2, Overflow);
            else
              Coeff = Coeff1*Coeff2;
            const SCEV *CoeffTerm = getConstant(Ty, Coeff);
            const SCEV *Term1 = AddRec->getOperand(y-z);
            const SCEV *Term2 = OtherAddRec->getOperand(z);
            SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
                                        SCEV::FlagAnyWrap, Depth + 1));
          }
        }
        if (SumOps.empty())
          SumOps.push_back(getZero(Ty));
        AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
      }
      if (!Overflow) {
        const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
        if (Ops.size() == 2) return NewAddRec;
        Ops[Idx] = NewAddRec;
        Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
        OpsModified = true;
        AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
        if (!AddRec)
          break;
      }
    }
    if (OpsModified)
      return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);

    // Otherwise couldn't fold anything into this recurrence. Move onto the
    // next one.
  }

  // Okay, it looks like we really DO need an mul expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
3516
/// Represents an unsigned remainder expression based on unsigned division.
  assert(getEffectiveSCEVType(LHS->getType()) ==
         getEffectiveSCEVType(RHS->getType()) &&
         "SCEVURemExpr operand types don't match!");

  // Short-circuit easy cases
  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    // If constant is one, the result is trivial
    if (RHSC->getValue()->isOne())
      return getZero(LHS->getType()); // X urem 1 --> 0

    // If constant is a power of two, fold into a zext(trunc(LHS)).
    // X urem 2^k keeps exactly the low k bits of X, which is what the
    // truncate-to-k-bits-then-zero-extend round trip computes.
    if (RHSC->getAPInt().isPowerOf2()) {
      Type *FullTy = LHS->getType();
      Type *TruncTy =
          IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
    }
  }

  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
  const SCEV *UDiv = getUDivExpr(LHS, RHS);
  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
}
3543
/// Get a canonical unsigned division expression, or something simpler if
/// possible.
  assert(!LHS->getType()->isPointerTy() &&
         "SCEVUDivExpr operand can't be pointer!");
  assert(LHS->getType() == RHS->getType() &&
         "SCEVUDivExpr operand types don't match!");

  // Probe the uniquing table first: structurally identical divisions must
  // share a single SCEV node so pointer equality implies structural equality.
  ID.AddInteger(scUDivExpr);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  // 0 udiv Y == 0
  if (match(LHS, m_scev_Zero()))
    return LHS;

  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
    if (RHSC->getValue()->isOne())
      return LHS; // X udiv 1 --> x
    // If the denominator is zero, the result of the udiv is undefined. Don't
    // try to analyze it, because the resolution chosen here may differ from
    // the resolution chosen in other parts of the compiler.
    if (!RHSC->getValue()->isZero()) {
      // Determine if the division can be folded into the operands of
      // its operands.
      // TODO: Generalize this to non-constants by using known-bits information.
      Type *Ty = LHS->getType();
      unsigned LZ = RHSC->getAPInt().countl_zero();
      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
      // For non-power-of-two values, effectively round the value up to the
      // nearest power of two.
      if (!RHSC->getAPInt().isPowerOf2())
        ++MaxShiftAmt;
      // ExtTy is wide enough that zero-extended arithmetic below cannot wrap;
      // comparing zext'd forms proves the narrow fold is wrap-free.
      IntegerType *ExtTy =
          IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
        if (const SCEVConstant *Step =
                dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
          // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
          const APInt &StepInt = Step->getAPInt();
          const APInt &DivInt = RHSC->getAPInt();
          if (!StepInt.urem(DivInt) &&
              getZeroExtendExpr(AR, ExtTy) ==
                  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                                getZeroExtendExpr(Step, ExtTy),
                                AR->getLoop(), SCEV::FlagAnyWrap)) {
            SmallVector<SCEVUse, 4> Operands;
            for (const SCEV *Op : AR->operands())
              Operands.push_back(getUDivExpr(Op, RHS));
            return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
          }
          /// Get a canonical UDivExpr for a recurrence.
          /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
          const APInt *StartRem;
          if (!DivInt.urem(StepInt) && match(getURemExpr(AR->getStart(), Step),
                                             m_scev_APInt(StartRem))) {
            bool NoWrap =
                getZeroExtendExpr(AR, ExtTy) ==
                getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
                              getZeroExtendExpr(Step, ExtTy), AR->getLoop(),

            // With N <= C and both N, C as powers-of-2, the transformation
            // {X,+,N}/C => {(X - X%N),+,N}/C preserves division results even
            // if wrapping occurs, as the division results remain equivalent for
            // all offsets in [[(X - X%N), X).
            bool CanFoldWithWrap = StepInt.ule(DivInt) && // N <= C
                                   StepInt.isPowerOf2() && DivInt.isPowerOf2();
            // Only fold if the subtraction can be folded in the start
            // expression.
            const SCEV *NewStart =
                getMinusSCEV(AR->getStart(), getConstant(*StartRem));
            if (*StartRem != 0 && (NoWrap || CanFoldWithWrap) &&
                !isa<SCEVAddExpr>(NewStart)) {
              const SCEV *NewLHS =
                  getAddRecExpr(NewStart, Step, AR->getLoop(),
                                NoWrap ? SCEV::FlagNW : SCEV::FlagAnyWrap);
              if (LHS != NewLHS) {
                LHS = NewLHS;

                // Reset the ID to include the new LHS, and check if it is
                // already cached.
                ID.clear();
                ID.AddInteger(scUDivExpr);
                ID.AddPointer(LHS);
                ID.AddPointer(RHS);
                IP = nullptr;
                if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
                  return S;
              }
            }
          }
        }
      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
        SmallVector<SCEVUse, 4> Operands;
        for (const SCEV *Op : M->operands())
          Operands.push_back(getZeroExtendExpr(Op, ExtTy));
        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands)) {
          // Find an operand that's safely divisible.
          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
            const SCEV *Op = M->getOperand(i);
            const SCEV *Div = getUDivExpr(Op, RHSC);
            // Exact division check: Div * RHSC must reconstruct Op.
            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
              Operands = SmallVector<SCEVUse, 4>(M->operands());
              Operands[i] = Div;
              return getMulExpr(Operands);
            }
          }

          // Even if it's not divisible, try to remove a common factor.
          if (const auto *LHSC = dyn_cast<SCEVConstant>(M->getOperand(0))) {
            APInt Factor = APIntOps::GreatestCommonDivisor(LHSC->getAPInt(),
                                                           RHSC->getAPInt());
            // !isIntN(1): factor does not fit in 1 bit, i.e. Factor > 1.
            if (!Factor.isIntN(1)) {
              SmallVector<SCEVUse, 2> NewOperands;
              NewOperands.push_back(getConstant(LHSC->getAPInt().udiv(Factor)));
              append_range(NewOperands, M->operands().drop_front());
              const SCEV *NewMul = getMulExpr(NewOperands);
              return getUDivExpr(NewMul,
                                 getConstant(RHSC->getAPInt().udiv(Factor)));
            }
          }
        }
      }

      // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
      if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
        if (auto *DivisorConstant =
                dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
          bool Overflow = false;
          APInt NewRHS =
              DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
          if (Overflow) {
            // B*C overflows the bit width, so the combined divisor exceeds
            // any representable dividend: the result is 0.
            return getConstant(RHSC->getType(), 0, false);
          }
          return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
        }
      }

      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
      if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
        SmallVector<SCEVUse, 4> Operands;
        for (const SCEV *Op : A->operands())
          Operands.push_back(getZeroExtendExpr(Op, ExtTy));
        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
          Operands.clear();
          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
            const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
            // Every addend must divide exactly, otherwise abandon the fold.
            if (isa<SCEVUDivExpr>(Op) ||
                getMulExpr(Op, RHS) != A->getOperand(i))
              break;
            Operands.push_back(Op);
          }
          if (Operands.size() == A->getNumOperands())
            return getAddExpr(Operands);
        }
      }

      // Fold if both operands are constant.
      if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
        return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
    }
  }

  // ((-C + (C smax %x)) /u %x) evaluates to zero, for any positive constant C.
  const APInt *NegC, *C;
  if (match(LHS,
      NegC->isNegative() && !NegC->isMinSignedValue() && *C == -*NegC)
    return getZero(LHS->getType());

  // (%a * %b)<nuw> / %b -> %a
  const auto *Mul = dyn_cast<SCEVMulExpr>(LHS);
  if (Mul && Mul->hasNoUnsignedWrap()) {
    for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
      if (Mul->getOperand(i) == RHS) {
        // Drop the matching factor and keep the rest of the product.
        SmallVector<SCEVUse, 2> Operands;
        append_range(Operands, Mul->operands().take_front(i));
        append_range(Operands, Mul->operands().drop_front(i + 1));
        return getMulExpr(Operands);
      }
    }
  }

  // TODO: Generalize to handle any common factors.
  // udiv (mul nuw a, vscale), (mul nuw b, vscale) --> udiv a, b
  const SCEV *NewLHS, *NewRHS;
  if (match(LHS, m_scev_c_NUWMul(m_SCEV(NewLHS), m_SCEVVScale())) &&
      match(RHS, m_scev_c_NUWMul(m_SCEV(NewRHS), m_SCEVVScale())))
    return getUDivExpr(NewLHS, NewRHS);

  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
  // changes). Make sure we get a new one.
  IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                             LHS, RHS);
  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  registerUser(S, ArrayRef<SCEVUse>({LHS, RHS}));
  return S;
}
3752
3753APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3754 APInt A = C1->getAPInt().abs();
3755 APInt B = C2->getAPInt().abs();
3756 uint32_t ABW = A.getBitWidth();
3757 uint32_t BBW = B.getBitWidth();
3758
3759 if (ABW > BBW)
3760 B = B.zext(ABW);
3761 else if (ABW < BBW)
3762 A = A.zext(BBW);
3763
3764 return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3765}
3766
/// Get a canonical unsigned division expression, or something simpler if
/// possible. There is no representation for an exact udiv in SCEV IR, but we
/// can attempt to optimize it prior to construction.
  // Currently there is no exact specific logic.

  // The exactness hint is presently unused; delegate to the generic udiv
  // construction, which performs all available simplifications.
  return getUDivExpr(LHS, RHS);
}
3775
/// Get an add recurrence expression for the specified loop. Simplify the
/// expression as much as possible.
                                           const Loop *L,
                                           SCEV::NoWrapFlags Flags) {
  SmallVector<SCEVUse, 4> Operands;
  Operands.push_back(Start);
  // If the step is itself an addrec over the same loop, flatten
  // {Start,+,{B,+,C}} into the n-ary form {Start,+,B,+,C}. Only the NW flag
  // is kept across the flattening (see maskFlags below).
  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
    if (StepChrec->getLoop() == L) {
      append_range(Operands, StepChrec->operands());
      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
    }

  Operands.push_back(Step);
  return getAddRecExpr(Operands, L, Flags);
}
3792
/// Get an add recurrence expression for the specified loop. Simplify the
/// expression as much as possible.
                                               const Loop *L,
                                               SCEV::NoWrapFlags Flags) {
  // A one-operand recurrence is just its start value.
  if (Operands.size() == 1) return Operands[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
  for (const SCEV *Op : llvm::drop_begin(Operands)) {
    assert(getEffectiveSCEVType(Op->getType()) == ETy &&
           "SCEVAddRecExpr operand types don't match!");
    assert(!Op->getType()->isPointerTy() && "Step must be integer");
  }
  for (const SCEV *Op : Operands)
           "SCEVAddRecExpr operand is not available at loop entry!");
#endif

  if (Operands.back()->isZero()) {
    Operands.pop_back();
    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
  }

  // It's tempting to want to call getConstantMaxBackedgeTakenCount count here and
  // use that information to infer NUW and NSW flags. However, computing a
  // BE count requires calling getAddRecExpr, so we may not yet have a
  // meaningful BE count at this point (and if we don't, we'd be stuck
  // with a SCEVCouldNotCompute as the cached BE count).

  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);

  // Canonicalize nested AddRecs in by nesting them in order of loop depth.
  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
    const Loop *NestedLoop = NestedAR->getLoop();
    // Swap inner/outer only when the nesting order is provably wrong:
    // either L encloses NestedLoop but is shallower, or the loops are
    // unrelated and L's header dominates NestedLoop's header.
    if (L->contains(NestedLoop)
            ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
            : (!NestedLoop->contains(L) &&
               DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
      SmallVector<SCEVUse, 4> NestedOperands(NestedAR->operands());
      Operands[0] = NestedAR->getStart();
      // AddRecs require their operands be loop-invariant with respect to their
      // loops. Don't perform this transformation if it would break this
      // requirement.
      bool AllInvariant = all_of(
          Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });

      if (AllInvariant) {
        // Create a recurrence for the outer loop with the same step size.
        //
        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
        // inner recurrence has the same property.
        SCEV::NoWrapFlags OuterFlags =
            maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());

        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
        AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
          return isLoopInvariant(Op, NestedLoop);
        });

        if (AllInvariant) {
          // Ok, both add recurrences are valid after the transformation.
          //
          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
          // the outer recurrence has the same property.
          SCEV::NoWrapFlags InnerFlags =
              maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
        }
      }
      // Reset Operands to its original state.
      Operands[0] = NestedAR;
    }
  }

  // Okay, it looks like we really DO need an addrec expr. Check to see if we
  // already have one, otherwise create a new one.
  return getOrCreateAddRecExpr(Operands, L, Flags);
}
3871
                                    ArrayRef<SCEVUse> IndexExprs) {
  // Translate a GEP operator into a SCEV over its base pointer, carrying
  // over the IR no-wrap flags only when they are provably valid.
  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
  // getSCEV(Base)->getType() has the same address space as Base->getType()
  // because SCEV::getType() preserves the address space.
  GEPNoWrapFlags NW = GEP->getNoWrapFlags();
  if (NW != GEPNoWrapFlags::none()) {
    // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
    // but to do that, we have to ensure that said flag is valid in the entire
    // defined scope of the SCEV.
    // TODO: non-instructions have global scope. We might be able to prove
    // some global scope cases
    auto *GEPI = dyn_cast<Instruction>(GEP);
    if (!GEPI || !isSCEVExprNeverPoison(GEPI))
      NW = GEPNoWrapFlags::none();
  }

  return getGEPExpr(BaseExpr, IndexExprs, GEP->getSourceElementType(), NW);
}
3891
                                    ArrayRef<SCEVUse> IndexExprs,
                                    Type *SrcElementTy, GEPNoWrapFlags NW) {
  // Map the GEP no-wrap flags onto SCEV add/mul flags for the offset math.
  if (NW.hasNoUnsignedSignedWrap())
    OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNSW);
  if (NW.hasNoUnsignedWrap())
    OffsetWrap = setFlags(OffsetWrap, SCEV::FlagNUW);

  Type *CurTy = BaseExpr->getType();
  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
  bool FirstIter = true;
  for (SCEVUse IndexExpr : IndexExprs) {
    // Compute the (potentially symbolic) offset in bytes for this index.
    if (StructType *STy = dyn_cast<StructType>(CurTy)) {
      // For a struct, add the member offset.
      ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
      unsigned FieldNo = Index->getZExtValue();
      const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
      Offsets.push_back(FieldOffset);

      // Update CurTy to the type of the field at Index.
      CurTy = STy->getTypeAtIndex(Index);
    } else {
      // Update CurTy to its element type.
      if (FirstIter) {
        assert(isa<PointerType>(CurTy) &&
               "The first index of a GEP indexes a pointer");
        CurTy = SrcElementTy;
        FirstIter = false;
      } else {
      }
      // For an array, add the element offset, explicitly scaled.
      const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
      // Getelementptr indices are signed.
      IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);

      // Multiply the index by the element size to compute the element offset.
      const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
      Offsets.push_back(LocalOffset);
    }
  }

  // Handle degenerate case of GEP without offsets.
  if (Offsets.empty())
    return BaseExpr;

  // Add the offsets together, assuming nsw if inbounds.
  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
  // Add the base address and the offset. We cannot use the nsw flag, as the
  // base address is unsigned. However, if we know that the offset is
  // non-negative, we can use nuw.
  bool NUW = NW.hasNoUnsignedWrap() ||
  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
  assert(BaseExpr->getType() == GEPExpr->getType() &&
         "GEP should not change type mid-flight.");
  return GEPExpr;
}
3954
SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
  // Build a profile of (kind, operand pointers) and look it up in the
  // uniquing table; returns the cached node, or null if none exists.
  ID.AddInteger(SCEVType);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
}
3964
SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
  // Same lookup as the sibling overload: profile the (kind, operands) pair
  // and query the uniquing table; null means no cached node.
  ID.AddInteger(SCEVType);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
}
3974
const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
  // abs(Op) is expressed as smax(Op, -Op).
  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
}
3979
  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "Operand types don't match!");
    assert(Ops[0]->getType()->isPointerTy() ==
               Ops[i]->getType()->isPointerTy() &&
           "min/max should be consistently pointerish");
  }
#endif

  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;

  // Fold constant operands together, and drop identity / collapse on
  // absorber elements for this particular min/max kind.
  const SCEV *Folded = constantFoldAndGroupOps(
      *this, LI, DT, Ops,
      [&](const APInt &C1, const APInt &C2) {
        switch (Kind) {
        case scSMaxExpr:
          return APIntOps::smax(C1, C2);
        case scSMinExpr:
          return APIntOps::smin(C1, C2);
        case scUMaxExpr:
          return APIntOps::umax(C1, C2);
        case scUMinExpr:
          return APIntOps::umin(C1, C2);
        default:
          llvm_unreachable("Unknown SCEV min/max opcode");
        }
      },
      [&](const APInt &C) {
        // identity
        if (IsMax)
          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
        else
          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
      },
      [&](const APInt &C) {
        // absorber
        if (IsMax)
          return IsSigned ? C.isMaxSignedValue() : C.isMaxValue();
        else
          return IsSigned ? C.isMinSignedValue() : C.isMinValue();
      });
  if (Folded)
    return Folded;

  // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
    return S;
  }

  // Find the first operation of the same kind
  unsigned Idx = 0;
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
    ++Idx;

  // Check to see if one of the operands is of the same kind. If so, expand its
  // operands onto our operand list, and recurse to simplify.
  if (Idx < Ops.size()) {
    bool DeletedAny = false;
    while (Ops[Idx]->getSCEVType() == Kind) {
      const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
      Ops.erase(Ops.begin()+Idx);
      append_range(Ops, SMME->operands());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getMinMaxExpr(Kind, Ops);
  }

  // Okay, check to see if the same value occurs in the operand list twice. If
  // so, delete one. Since we sorted the list, these values are required to
  // be adjacent.
  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
    if (Ops[i] == Ops[i + 1] ||
        isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
      //  X op Y op Y  -->  X op Y
      //  X op Y       -->  X, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
      --i;
      --e;
    } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
                                               Ops[i + 1])) {
      //  X op Y       -->  Y, if we know X, Y are ordered appropriately
      Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
      --i;
      --e;
    }
  }

  if (Ops.size() == 1) return Ops[0];

  assert(!Ops.empty() && "Reduced smax down to nothing!");

  // Okay, it looks like we really DO need an expr. Check to see if we
  // already have one, otherwise create a new one.
  ID.AddInteger(Kind);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;
  SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
  SCEV *S = new (SCEVAllocator)
      SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());

  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  registerUser(S, Ops);
  return S;
}
4107
namespace {

/// Visitor that removes duplicate operands from a sequential min/max
/// expression tree. It recurses only into min/max nodes of the same
/// effective kind as the root and keeps just the first occurrence of each
/// operand; all other SCEV node kinds are returned untouched.
class SCEVSequentialMinMaxDeduplicatingVisitor final
    : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
                         std::optional<const SCEV *>> {
  // std::nullopt means "operand already seen; drop it".
  using RetVal = std::optional<const SCEV *>;

  ScalarEvolution &SE;
  const SCEVTypes RootKind; // Must be a sequential min/max expression.
  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.

  bool canRecurseInto(SCEVTypes Kind) const {
    // We can only recurse into the SCEV expression of the same effective type
    // as the type of our root SCEV expression.
    return RootKind == Kind || NonSequentialRootKind == Kind;
  };

  RetVal visitAnyMinMaxExpr(const SCEV *S) {
           "Only for min/max expressions.");
    SCEVTypes Kind = S->getSCEVType();

    if (!canRecurseInto(Kind))
      return S;

    auto *NAry = cast<SCEVNAryExpr>(S);
    SmallVector<SCEVUse> NewOps;
    bool Changed = visit(Kind, NAry->operands(), NewOps);

    if (!Changed)
      return S;
    // All operands were duplicates; the whole node disappears.
    if (NewOps.empty())
      return std::nullopt;

               ? SE.getSequentialMinMaxExpr(Kind, NewOps)
               : SE.getMinMaxExpr(Kind, NewOps);
  }

  RetVal visit(const SCEV *S) {
    // Has the whole operand been seen already?
    if (!SeenOps.insert(S).second)
      return std::nullopt;
    return Base::visit(S);
  }

public:
  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
                                           SCEVTypes RootKind)
      : SE(SE), RootKind(RootKind),
        NonSequentialRootKind(
            SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
                RootKind)) {}

  /// Deduplicate OrigOps into NewOps; returns whether anything changed.
  /// NewOps is only written when a change occurred.
  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<SCEVUse> OrigOps,
                         SmallVectorImpl<SCEVUse> &NewOps) {
    bool Changed = false;
    Ops.reserve(OrigOps.size());

    for (const SCEV *Op : OrigOps) {
      RetVal NewOp = visit(Op);
      if (NewOp != Op)
        Changed = true;
      if (NewOp)
        Ops.emplace_back(*NewOp);
    }

    if (Changed)
      NewOps = std::move(Ops);
    return Changed;
  }

  // Leaf and non-min/max nodes are returned unchanged.
  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }

  RetVal visitVScale(const SCEVVScale *VScale) { return VScale; }

  RetVal visitPtrToAddrExpr(const SCEVPtrToAddrExpr *Expr) { return Expr; }

  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }

  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }

  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }

  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }

  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }

  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }

  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }

  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }

  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
    return visitAnyMinMaxExpr(Expr);
  }

  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
    return visitAnyMinMaxExpr(Expr);
  }

  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
    return visitAnyMinMaxExpr(Expr);
  }

  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
    return visitAnyMinMaxExpr(Expr);
  }

  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
    return visitAnyMinMaxExpr(Expr);
  }

  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }

  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
};

} // namespace
4231
  // Classify whether a SCEV node kind unconditionally propagates poison
  // from its operands to its result.
  switch (Kind) {
  case scConstant:
  case scVScale:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scAddRecExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scUnknown:
    // If any operand is poison, the whole expression is poison.
    return true;
    // FIXME: if the *first* operand is poison, the whole expression is poison.
    return false; // Pessimistically, say that it does not propagate poison.
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
4260
namespace {
// The only way poison may be introduced in a SCEV expression is from a
// poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
// not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
// introduce poison -- they encode guaranteed, non-speculated knowledge.
//
// Additionally, all SCEV nodes propagate poison from inputs to outputs,
// with the notable exception of umin_seq, where only poison from the first
// operand is (unconditionally) propagated.
struct SCEVPoisonCollector {
  // When false, the traversal stops at nodes that may block poison
  // propagation; when true, it descends through them anyway.
  bool LookThroughMaybePoisonBlocking;
  // SCEVUnknowns whose underlying IR value is not known to be non-poison.
  SmallPtrSet<const SCEVUnknown *, 4> MaybePoison;
  SCEVPoisonCollector(bool LookThroughMaybePoisonBlocking)
      : LookThroughMaybePoisonBlocking(LookThroughMaybePoisonBlocking) {}

  bool follow(const SCEV *S) {
    if (!LookThroughMaybePoisonBlocking &&
      return false;

    if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
      if (!isGuaranteedNotToBePoison(SU->getValue()))
        MaybePoison.insert(SU);
    }
    return true;
  }
  // Always walk the entire expression.
  bool isDone() const { return false; }
};
} // namespace
4290
4291/// Return true if V is poison given that AssumedPoison is already poison.
4292static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4293 // First collect all SCEVs that might result in AssumedPoison to be poison.
4294 // We need to look through potentially poison-blocking operations here,
4295 // because we want to find all SCEVs that *might* result in poison, not only
4296 // those that are *required* to.
4297 SCEVPoisonCollector PC1(/* LookThroughMaybePoisonBlocking */ true);
4298 visitAll(AssumedPoison, PC1);
4299
4300 // AssumedPoison is never poison. As the assumption is false, the implication
4301 // is true. Don't bother walking the other SCEV in this case.
4302 if (PC1.MaybePoison.empty())
4303 return true;
4304
4305 // Collect all SCEVs in S that, if poison, *will* result in S being poison
4306 // as well. We cannot look through potentially poison-blocking operations
4307 // here, as their arguments only *may* make the result poison.
4308 SCEVPoisonCollector PC2(/* LookThroughMaybePoisonBlocking */ false);
4309 visitAll(S, PC2);
4310
4311 // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4312 // it will also make S poison by being part of PC2.MaybePoison.
4313 return llvm::set_is_subset(PC1.MaybePoison, PC2.MaybePoison);
4314}
4315
    SmallPtrSetImpl<const Value *> &Result, const SCEV *S) {
  // Collect the IR values whose poison unconditionally makes S poison
  // (no looking through poison-blocking operations).
  SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ false);
  visitAll(S, PC);
  for (const SCEVUnknown *SU : PC.MaybePoison)
    Result.insert(SU->getValue());
}
4323
    const SCEV *S, Instruction *I,
    SmallVectorImpl<Instruction *> &DropPoisonGeneratingInsts) {
  // If the instruction cannot be poison, it's always safe to reuse.
    return true;

  // Otherwise, it is possible that I is more poisonous that S. Collect the
  // poison-contributors of S, and then check whether I has any additional
  // poison-contributors. Poison that is contributed through poison-generating
  // flags is handled by dropping those flags instead.
  getPoisonGeneratingValues(PoisonVals, S);

  SmallVector<Value *> Worklist;
  Worklist.push_back(I);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    // Avoid walking large instruction graphs.
    if (Visited.size() > 16)
      return false;

    // Either the value can't be poison, or the S would also be poison if it
    // is.
    if (PoisonVals.contains(V) || ::isGuaranteedNotToBePoison(V))
      continue;

    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return false;

    // Disjoint or instructions are interpreted as adds by SCEV. However, we
    // can't replace an arbitrary add with disjoint or, even if we drop the
    // flag. We would need to convert the or into an add.
    if (auto *PDI = dyn_cast<PossiblyDisjointInst>(I))
      if (PDI->isDisjoint())
        return false;

    // FIXME: Ignore vscale, even though it technically could be poison. Do this
    // because SCEV currently assumes it can't be poison. Remove this special
    // case once we proper model when vscale can be poison.
    if (auto *II = dyn_cast<IntrinsicInst>(I);
        II && II->getIntrinsicID() == Intrinsic::vscale)
      continue;

    if (canCreatePoison(cast<Operator>(I), /*ConsiderFlagsAndMetadata*/ false))
      return false;

    // If the instruction can't create poison, we can recurse to its operands.
    // Record instructions whose flags/metadata would need dropping.
    if (I->hasPoisonGeneratingAnnotations())
      DropPoisonGeneratingInsts.push_back(I);

    llvm::append_range(Worklist, I->operands());
  }
  return true;
}
4384
const SCEV *
  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
         "Not a SCEVSequentialMinMaxExpr!");
  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
  if (Ops.size() == 1)
    return Ops[0];
#ifndef NDEBUG
  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
           "Operand types don't match!");
    assert(Ops[0]->getType()->isPointerTy() ==
               Ops[i]->getType()->isPointerTy() &&
           "min/max should be consistently pointerish");
  }
#endif

  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
  // so we can *NOT* do any kind of sorting of the expressions!

  // Check if we have created the same expression before.
  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
    return S;

  // FIXME: there are *some* simplifications that we can do here.

  // Keep only the first instance of an operand.
  {
    SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
    bool Changed = Deduplicator.visit(Kind, Ops, Ops);
    if (Changed)
      return getSequentialMinMaxExpr(Kind, Ops);
  }

  // Check to see if one of the operands is of the same kind. If so, expand its
  // operands onto our operand list, and recurse to simplify.
  {
    unsigned Idx = 0;
    bool DeletedAny = false;
    while (Idx < Ops.size()) {
      if (Ops[Idx]->getSCEVType() != Kind) {
        ++Idx;
        continue;
      }
      // Splice the nested operands in place, preserving order (sequential
      // min/max is order-sensitive).
      const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
      Ops.erase(Ops.begin() + Idx);
      Ops.insert(Ops.begin() + Idx, SMME->operands().begin(),
                 SMME->operands().end());
      DeletedAny = true;
    }

    if (DeletedAny)
      return getSequentialMinMaxExpr(Kind, Ops);
  }

  const SCEV *SaturationPoint;
  switch (Kind) {
    SaturationPoint = getZero(Ops[0]->getType());
    Pred = ICmpInst::ICMP_ULE;
    break;
  default:
    llvm_unreachable("Not a sequential min/max type.");
  }

  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
    if (!isGuaranteedNotToCauseUB(Ops[i]))
      continue;
    // We can replace %x umin_seq %y with %x umin %y if either:
    //  * %y being poison implies %x is also poison.
    //  * %x cannot be the saturating value (e.g. zero for umin).
    if (::impliesPoison(Ops[i], Ops[i - 1]) ||
        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
                                        SaturationPoint)) {
      SmallVector<SCEVUse, 2> SeqOps = {Ops[i - 1], Ops[i]};
      Ops[i - 1] = getMinMaxExpr(
          SeqOps);
      Ops.erase(Ops.begin() + i);
      return getSequentialMinMaxExpr(Kind, Ops);
    }
    // Fold %x umin_seq %y to %x if %x ule %y.
    // TODO: We might be able to prove the predicate for a later operand.
    if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
      Ops.erase(Ops.begin() + i);
      return getSequentialMinMaxExpr(Kind, Ops);
    }
  }

  // Okay, it looks like we really DO need an expr. Check to see if we
  // already have one, otherwise create a new one.
  ID.AddInteger(Kind);
  for (const SCEV *Op : Ops)
    ID.AddPointer(Op);
  void *IP = nullptr;
  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
  if (ExistingSCEV)
    return ExistingSCEV;

  SCEVUse *O = SCEVAllocator.Allocate<SCEVUse>(Ops.size());
  SCEV *S = new (SCEVAllocator)
      SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());

  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  registerUser(S, Ops);
  return S;
}
4498
4503
4507
4512
4516
4521
4525
4527 bool Sequential) {
4528 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4529 return getUMinExpr(Ops, Sequential);
4530}
4531
4537
4538const SCEV *
4540 const SCEV *Res = getConstant(IntTy, Size.getKnownMinValue());
4541 if (Size.isScalable())
4542 Res = getMulExpr(Res, getVScale(IntTy));
4543 return Res;
4544}
4545
4547 return getSizeOfExpr(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4548}
4549
4551 return getSizeOfExpr(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4552}
4553
                                             StructType *STy,
                                             unsigned FieldNo) {
  // We can bypass creating a target-independent constant expression and then
  // folding it back into a ConstantInt. This is just a compile-time
  // optimization.
  const StructLayout *SL = getDataLayout().getStructLayout(STy);
  // A scalable struct size cannot be expressed as a fixed constant offset.
  assert(!SL->getSizeInBits().isScalable() &&
         "Cannot get offset for structure containing scalable vector types");
  return getConstant(IntTy, SL->getElementOffset(FieldNo));
}
4565
  // Don't attempt to do anything other than create a SCEVUnknown object
  // here. createSCEV only calls getUnknown after checking for all other
  // interesting possibilities, and any other code that calls getUnknown
  // is doing so in order to hide a value from SCEV canonicalization.

  ID.AddInteger(scUnknown);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
    // SCEVUnknowns are uniqued on the underlying Value*; a mismatch means a
    // deleted value was not purged from the uniquing map.
    assert(cast<SCEVUnknown>(S)->getValue() == V &&
           "Stale SCEVUnknown in uniquing map!");
    return S;
  }
  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
                                            FirstUnknown);
  // Thread the new node onto the per-SE list of all SCEVUnknowns.
  FirstUnknown = cast<SCEVUnknown>(S);
  UniqueSCEVs.InsertNode(S, IP);
  S->computeAndSetCanonical(*this);
  return S;
}
4588
4589//===----------------------------------------------------------------------===//
4590// Basic SCEV Analysis and PHI Idiom Recognition Code
4591//
4592
4593/// Test if values of the given type are analyzable within the SCEV
4594/// framework. This primarily includes integer types, and it can optionally
4595/// include pointer types if the ScalarEvolution class has access to
4596/// target-specific information.
4598 // Integers and pointers are always SCEVable.
4599 return Ty->isIntOrPtrTy();
4600}
4601
4602/// Return the size in bits of the specified type, for which isSCEVable must
4603/// return true.
4605 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4606 if (Ty->isPointerTy())
4608 return getDataLayout().getTypeSizeInBits(Ty);
4609}
4610
4611/// Return a type with the same bitwidth as the given type and which represents
4612/// how SCEV will treat the given type, for which isSCEVable must return
4613/// true. For pointer types, this is the pointer index sized integer type.
4615 assert(isSCEVable(Ty) && "Type is not SCEVable!");
4616
4617 if (Ty->isIntegerTy())
4618 return Ty;
4619
4620 // The only other support type is pointer.
4621 assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4622 return getDataLayout().getIndexType(Ty);
4623}
4624
4626 return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4627}
4628
4630 const SCEV *B) {
4631 /// For a valid use point to exist, the defining scope of one operand
4632 /// must dominate the other.
4633 bool PreciseA, PreciseB;
4634 auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4635 auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4636 if (!PreciseA || !PreciseB)
4637 // Can't tell.
4638 return false;
4639 return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4640 DT.dominates(ScopeB, ScopeA);
4641}
4642
4644 return CouldNotCompute.get();
4645}
4646
4647bool ScalarEvolution::checkValidity(const SCEV *S) const {
4648 bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4649 auto *SU = dyn_cast<SCEVUnknown>(S);
4650 return SU && SU->getValue() == nullptr;
4651 });
4652
4653 return !ContainsNulls;
4654}
4655
  // Memoized: scanning large expressions repeatedly would be costly.
  HasRecMapType::iterator I = HasRecMap.find(S);
  if (I != HasRecMap.end())
    return I->second;

  bool FoundAddRec =
      SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
  // Cache the answer for subsequent queries on the same expression.
  HasRecMap.insert({S, FoundAddRec});
  return FoundAddRec;
}
4666
4667/// Return the ValueOffsetPair set for \p S. \p S can be represented
4668/// by the value and offset from any ValueOffsetPair in the set.
4669ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4670 ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4671 if (SI == ExprValueMap.end())
4672 return {};
4673 return SI->second.getArrayRef();
4674}
4675
4676/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4677/// cannot be used separately. eraseValueFromMap should be used to remove
4678/// V from ValueExprMap and ExprValueMap at the same time.
4679void ScalarEvolution::eraseValueFromMap(Value *V) {
4680 ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4681 if (I != ValueExprMap.end()) {
4682 auto EVIt = ExprValueMap.find(I->second);
4683 bool Removed = EVIt->second.remove(V);
4684 (void) Removed;
4685 assert(Removed && "Value not in ExprValueMap?");
4686 ValueExprMap.erase(I);
4687 }
4688}
4689
4690void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4691 // A recursive query may have already computed the SCEV. It should be
4692 // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4693 // inferred nowrap flags.
4694 auto It = ValueExprMap.find_as(V);
4695 if (It == ValueExprMap.end()) {
4696 ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4697 ExprValueMap[S].insert(V);
4698 }
4699}
4700
/// Return an existing SCEV if it exists, otherwise analyze the expression and
/// create a new one.
  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

  // Fast path: reuse a previously computed expression for V.
  if (const SCEV *S = getExistingSCEV(V))
    return S;
  return createSCEVIter(V);
}
4710
  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");

  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
  if (I != ValueExprMap.end()) {
    const SCEV *S = I->second;
    // Cached expressions must never reference deleted values; see
    // checkValidity().
    assert(checkValidity(S) &&
           "existing SCEV has not been properly invalidated");
    return S;
  }
  return nullptr;
}
4723
/// Return a SCEV corresponding to -V = -1*V
                                             SCEV::NoWrapFlags Flags) {
  // Constants can be negated directly without building a multiply.
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));

  // Otherwise canonicalize -V as (-1 * V), carrying the requested
  // no-wrap flags through to the multiply.
  Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMulExpr(V, getMinusOne(Ty), Flags);
}
4735
4736/// If Expr computes ~A, return A else return nullptr
4737static const SCEV *MatchNotExpr(const SCEV *Expr) {
4738 const SCEV *MulOp;
4739 if (match(Expr, m_scev_Add(m_scev_AllOnes(),
4740 m_scev_Mul(m_scev_AllOnes(), m_SCEV(MulOp)))))
4741 return MulOp;
4742 return nullptr;
4743}
4744
/// Return a SCEV corresponding to ~V = -1-V
  assert(!V->getType()->isPointerTy() && "Can't negate pointer");

  // Fold constants directly.
  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));

  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
    auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
      SmallVector<SCEVUse, 2> MatchedOperands;
      for (const SCEV *Operand : MME->operands()) {
        const SCEV *Matched = MatchNotExpr(Operand);
        // Every operand must itself be a "not" for the fold to apply.
        if (!Matched)
          return (const SCEV *)nullptr;
        MatchedOperands.push_back(Matched);
      }
      return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
                           MatchedOperands);
    };
    if (const SCEV *Replaced = MatchMinMaxNegation(MME))
      return Replaced;
  }

  // Default: represent ~V as (-1 - V).
  Type *Ty = V->getType();
  Ty = getEffectiveSCEVType(Ty);
  return getMinusSCEV(getMinusOne(Ty), V);
}
4774
  assert(P->getType()->isPointerTy());

  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
    // The base of an AddRec is the first operand.
    SmallVector<SCEVUse> Ops{AddRec->operands()};
    Ops[0] = removePointerBase(Ops[0]);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if pointer operand of the AddRec is a SCEVUnknown).
    return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }
  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
    // The base of an Add is the pointer operand.
    SmallVector<SCEVUse> Ops{Add->operands()};
    SCEVUse *PtrOp = nullptr;
    for (SCEVUse &AddOp : Ops) {
      if (AddOp->getType()->isPointerTy()) {
        assert(!PtrOp && "Cannot have multiple pointer ops");
        PtrOp = &AddOp;
      }
    }
    // NOTE(review): since P has pointer type, the Add is assumed to contain
    // exactly one pointer operand, so PtrOp is non-null here.
    *PtrOp = removePointerBase(*PtrOp);
    // Don't try to transfer nowrap flags for now. We could in some cases
    // (for example, if the pointer operand of the Add is a SCEVUnknown).
    return getAddExpr(Ops);
  }
  // Any other expression must be a pointer base.
  return getZero(P->getType());
}
4804
                                          SCEV::NoWrapFlags Flags,
                                          unsigned Depth) {
  // Fast path: X - X --> 0.
  if (LHS == RHS)
    return getZero(LHS->getType());

  // If we subtract two pointers with different pointer bases, bail.
  // Eventually, we're going to add an assertion to getMulExpr that we
  // can't multiply by a pointer.
  if (RHS->getType()->isPointerTy()) {
    if (!LHS->getType()->isPointerTy() ||
        getPointerBase(LHS) != getPointerBase(RHS))
      return getCouldNotCompute();
    // Same base: subtract the integer offsets instead.
    LHS = removePointerBase(LHS);
    RHS = removePointerBase(RHS);
  }

  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
  // makes it so that we cannot make much use of NUW.
  auto AddFlags = SCEV::FlagAnyWrap;
  const bool RHSIsNotMinSigned =
  if (hasFlags(Flags, SCEV::FlagNSW)) {
    // Let M be the minimum representable signed value. Then (-1)*RHS
    // signed-wraps if and only if RHS is M. That can happen even for
    // a NSW subtraction because e.g. (-1)*M signed-wraps even though
    // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
    // (-1)*RHS, we need to prove that RHS != M.
    //
    // If LHS is non-negative and we know that LHS - RHS does not
    // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
    // either by proving that RHS > M or that LHS >= 0.
    if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
      AddFlags = SCEV::FlagNSW;
    }
  }

  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
  // RHS is NSW and LHS >= 0.
  //
  // The difficulty here is that the NSW flag may have been proven
  // relative to a loop that is to be found in a recurrence in LHS and
  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
  // larger scope than intended.
  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;

  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
}
4854
4856 unsigned Depth) {
4857 Type *SrcTy = V->getType();
4858 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4859 "Cannot truncate or zero extend with non-integer arguments!");
4860 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4861 return V; // No conversion
4862 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4863 return getTruncateExpr(V, Ty, Depth);
4864 return getZeroExtendExpr(V, Ty, Depth);
4865}
4866
4868 unsigned Depth) {
4869 Type *SrcTy = V->getType();
4870 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4871 "Cannot truncate or zero extend with non-integer arguments!");
4872 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4873 return V; // No conversion
4874 if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4875 return getTruncateExpr(V, Ty, Depth);
4876 return getSignExtendExpr(V, Ty, Depth);
4877}
4878
4879const SCEV *
4881 Type *SrcTy = V->getType();
4882 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4883 "Cannot noop or zero extend with non-integer arguments!");
4885 "getNoopOrZeroExtend cannot truncate!");
4886 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4887 return V; // No conversion
4888 return getZeroExtendExpr(V, Ty);
4889}
4890
4891const SCEV *
4893 Type *SrcTy = V->getType();
4894 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4895 "Cannot noop or sign extend with non-integer arguments!");
4897 "getNoopOrSignExtend cannot truncate!");
4898 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4899 return V; // No conversion
4900 return getSignExtendExpr(V, Ty);
4901}
4902
4903const SCEV *
4905 Type *SrcTy = V->getType();
4906 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4907 "Cannot noop or any extend with non-integer arguments!");
4909 "getNoopOrAnyExtend cannot truncate!");
4910 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4911 return V; // No conversion
4912 return getAnyExtendExpr(V, Ty);
4913}
4914
4915const SCEV *
4917 Type *SrcTy = V->getType();
4918 assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4919 "Cannot truncate or noop with non-integer arguments!");
4921 "getTruncateOrNoop cannot extend!");
4922 if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4923 return V; // No conversion
4924 return getTruncateExpr(V, Ty);
4925}
4926
4928 const SCEV *RHS) {
4929 const SCEV *PromotedLHS = LHS;
4930 const SCEV *PromotedRHS = RHS;
4931
4932 if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4933 PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4934 else
4935 PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4936
4937 return getUMaxExpr(PromotedLHS, PromotedRHS);
4938}
4939
4941 const SCEV *RHS,
4942 bool Sequential) {
4943 SmallVector<SCEVUse, 2> Ops = {LHS, RHS};
4944 return getUMinFromMismatchedTypes(Ops, Sequential);
4945}
4946
const SCEV *
                                                       bool Sequential) {
  assert(!Ops.empty() && "At least one operand must be!");
  // Trivial case.
  if (Ops.size() == 1)
    return Ops[0];

  // Find the max type first.
  Type *MaxType = nullptr;
  for (SCEVUse S : Ops)
    if (MaxType)
      MaxType = getWiderType(MaxType, S->getType());
    else
      MaxType = S->getType();
  assert(MaxType && "Failed to find maximum type!");

  // Extend all ops to max type. Zero-extension is the correct widening
  // for an unsigned min.
  SmallVector<SCEVUse, 2> PromotedOps;
  for (SCEVUse S : Ops)
    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));

  // Generate umin.
  return getUMinExpr(PromotedOps, Sequential);
}
4972
  // A pointer operand may evaluate to a nonpointer expression, such as null.
  if (!V->getType()->isPointerTy())
    return V;

  // Walk down through AddRecs (base is in the start) and Adds (base is in
  // the unique pointer operand) until an irreducible pointer expression.
  while (true) {
    if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
      V = AddRec->getStart();
    } else if (auto *Add = dyn_cast<SCEVAddExpr>(V)) {
      const SCEV *PtrOp = nullptr;
      for (const SCEV *AddOp : Add->operands()) {
        if (AddOp->getType()->isPointerTy()) {
          assert(!PtrOp && "Cannot have multiple pointer ops");
          PtrOp = AddOp;
        }
      }
      assert(PtrOp && "Must have pointer op");
      V = PtrOp;
    } else // Not something we can look further into.
      return V;
  }
}
4995
4996/// Push users of the given Instruction onto the given Worklist.
5000 // Push the def-use children onto the Worklist stack.
5001 for (User *U : I->users()) {
5002 auto *UserInsn = cast<Instruction>(U);
5003 if (Visited.insert(UserInsn).second)
5004 Worklist.push_back(UserInsn);
5005 }
5006}
5007
5008namespace {
5009
5010/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
5011/// expression in case its Loop is L. If it is not L then
5012/// if IgnoreOtherLoops is true then use AddRec itself
5013/// otherwise rewrite cannot be done.
5014/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5015class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
5016public:
5017 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
5018 bool IgnoreOtherLoops = true) {
5019 SCEVInitRewriter Rewriter(L, SE);
5020 const SCEV *Result = Rewriter.visit(S);
5021 if (Rewriter.hasSeenLoopVariantSCEVUnknown())
5022 return SE.getCouldNotCompute();
5023 return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
5024 ? SE.getCouldNotCompute()
5025 : Result;
5026 }
5027
5028 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5029 if (!SE.isLoopInvariant(Expr, L))
5030 SeenLoopVariantSCEVUnknown = true;
5031 return Expr;
5032 }
5033
5034 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5035 // Only re-write AddRecExprs for this loop.
5036 if (Expr->getLoop() == L)
5037 return Expr->getStart();
5038 SeenOtherLoops = true;
5039 return Expr;
5040 }
5041
5042 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5043
5044 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5045
5046private:
5047 explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
5048 : SCEVRewriteVisitor(SE), L(L) {}
5049
5050 const Loop *L;
5051 bool SeenLoopVariantSCEVUnknown = false;
5052 bool SeenOtherLoops = false;
5053};
5054
5055/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
5056/// increment expression in case its Loop is L. If it is not L then
5057/// use AddRec itself.
5058/// If SCEV contains non-invariant unknown SCEV rewrite cannot be done.
5059class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
5060public:
5061 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
5062 SCEVPostIncRewriter Rewriter(L, SE);
5063 const SCEV *Result = Rewriter.visit(S);
5064 return Rewriter.hasSeenLoopVariantSCEVUnknown()
5065 ? SE.getCouldNotCompute()
5066 : Result;
5067 }
5068
5069 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5070 if (!SE.isLoopInvariant(Expr, L))
5071 SeenLoopVariantSCEVUnknown = true;
5072 return Expr;
5073 }
5074
5075 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5076 // Only re-write AddRecExprs for this loop.
5077 if (Expr->getLoop() == L)
5078 return Expr->getPostIncExpr(SE);
5079 SeenOtherLoops = true;
5080 return Expr;
5081 }
5082
5083 bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
5084
5085 bool hasSeenOtherLoops() { return SeenOtherLoops; }
5086
5087private:
5088 explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
5089 : SCEVRewriteVisitor(SE), L(L) {}
5090
5091 const Loop *L;
5092 bool SeenLoopVariantSCEVUnknown = false;
5093 bool SeenOtherLoops = false;
5094};
5095
/// This class evaluates the compare condition by matching it against the
/// condition of loop latch. If there is a match we assume a true value
/// for the condition while building SCEV nodes.
class SCEVBackedgeConditionFolder
    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
public:
  static const SCEV *rewrite(const SCEV *S, const Loop *L,
                             ScalarEvolution &SE) {
    bool IsPosBECond = false;
    Value *BECond = nullptr;
    if (BasicBlock *Latch = L->getLoopLatch()) {
      if (CondBrInst *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
        assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
               "Both outgoing branches should not target same header!");
        BECond = BI->getCondition();
        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
      } else {
        // Latch not terminated by a conditional branch; nothing to fold.
        return S;
      }
    }
    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
    return Rewriter.visit(S);
  }

  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
    const SCEV *Result = Expr;
    bool InvariantF = SE.isLoopInvariant(Expr, L);

    if (!InvariantF) {
      // NOTE(review): `I` is presumably the Instruction backing this
      // SCEVUnknown (Expr->getValue()) — confirm its declaration upstream.
      switch (I->getOpcode()) {
      case Instruction::Select: {
        SelectInst *SI = cast<SelectInst>(I);
        std::optional<const SCEV *> Res =
            compareWithBackedgeCondition(SI->getCondition());
        if (Res) {
          // Select condition folded to a constant i1; take the chosen arm.
          bool IsOne = cast<SCEVConstant>(*Res)->getValue()->isOne();
          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
        }
        break;
      }
      default: {
        std::optional<const SCEV *> Res = compareWithBackedgeCondition(I);
        if (Res)
          Result = *Res;
        break;
      }
      }
    }
    return Result;
  }

private:
  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
                                       bool IsPosBECond, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
        IsPositiveBECond(IsPosBECond) {}

  std::optional<const SCEV *> compareWithBackedgeCondition(Value *IC);

  const Loop *L;
  /// Loop back condition.
  Value *BackedgeCond = nullptr;
  /// Set to true if loop back is on positive branch condition.
  bool IsPositiveBECond;
};
5162
std::optional<const SCEV *>
SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {

  // If value matches the backedge condition for loop latch,
  // then return a constant evolution node based on loopback
  // branch taken: i1 1 when the latch branches to the header on true.
  // NOTE(review): the false arm of this ternary (presumably SE.getZero)
  // appears to be missing from this excerpt — verify against upstream.
  if (BackedgeCond == IC)
    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
  return std::nullopt;
}
5174
5175class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
5176public:
5177 static const SCEV *rewrite(const SCEV *S, const Loop *L,
5178 ScalarEvolution &SE) {
5179 SCEVShiftRewriter Rewriter(L, SE);
5180 const SCEV *Result = Rewriter.visit(S);
5181 return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
5182 }
5183
5184 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
5185 // Only allow AddRecExprs for this loop.
5186 if (!SE.isLoopInvariant(Expr, L))
5187 Valid = false;
5188 return Expr;
5189 }
5190
5191 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
5192 if (Expr->getLoop() == L && Expr->isAffine())
5193 return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
5194 Valid = false;
5195 return Expr;
5196 }
5197
5198 bool isValid() { return Valid; }
5199
5200private:
5201 explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
5202 : SCEVRewriteVisitor(SE), L(L) {}
5203
5204 const Loop *L;
5205 bool Valid = true;
5206};
5207
5208} // end anonymous namespace
5209
ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
  // Only affine (single-step) recurrences are handled here.
  if (!AR->isAffine())
    return SCEV::FlagAnyWrap;

  using OBO = OverflowingBinaryOperator;


  if (!AR->hasNoSelfWrap()) {
    // If Step * MaxBECount fits in the value's bitwidth, the recurrence
    // can never return to its starting value, i.e. it cannot self-wrap.
    const SCEV *BECount = getConstantMaxBackedgeTakenCount(AR->getLoop());
    if (const SCEVConstant *BECountMax = dyn_cast<SCEVConstant>(BECount)) {
      ConstantRange StepCR = getSignedRange(AR->getStepRecurrence(*this));
      const APInt &BECountAP = BECountMax->getAPInt();
      unsigned NoOverflowBitWidth =
          BECountAP.getActiveBits() + StepCR.getMinSignedBits();
      if (NoOverflowBitWidth <= getTypeSizeInBits(AR->getType()))
    }
  }

  if (!AR->hasNoSignedWrap()) {
    // NSW holds if the AddRec's signed range lies in the guaranteed
    // no-signed-wrap region for an add of the step's range.
    ConstantRange AddRecRange = getSignedRange(AR);
    ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));

        Instruction::Add, IncRange, OBO::NoSignedWrap);
    if (NSWRegion.contains(AddRecRange))
  }

  if (!AR->hasNoUnsignedWrap()) {
    // Same argument as above, using unsigned ranges for NUW.
    ConstantRange AddRecRange = getUnsignedRange(AR);
    ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));

        Instruction::Add, IncRange, OBO::NoUnsignedWrap);
    if (NUWRegion.contains(AddRecRange))
  }

  return Result;
}
5253
ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {

  if (AR->hasNoSignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive, only try to prove NSW once per AddRec.
  if (!SignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the later case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
  // OverflowLimit is the bound the IV must be compared against for the
  // step never to signed-wrap.
  const SCEV *OverflowLimit =
      getSignedOverflowLimitForStep(Step, &Pred, this);
  if (OverflowLimit &&
      (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
       isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
    Result = setFlags(Result, SCEV::FlagNSW);
  }
  return Result;
}
ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {

  if (AR->hasNoUnsignedWrap())
    return Result;

  if (!AR->isAffine())
    return Result;

  // This function can be expensive, only try to prove NUW once per AddRec.
  if (!UnsignedWrapViaInductionTried.insert(AR).second)
    return Result;

  const SCEV *Step = AR->getStepRecurrence(*this);
  unsigned BitWidth = getTypeSizeInBits(AR->getType());
  const Loop *L = AR->getLoop();

  // Check whether the backedge-taken count is SCEVCouldNotCompute.
  // Note that this serves two purposes: It filters out loops that are
  // simply not analyzable, and it covers the case where this code is
  // being called from within backedge-taken count analysis, such that
  // attempting to ask for the backedge-taken count would likely result
  // in infinite recursion. In the later case, the analysis code will
  // cope with a conservative value, and it will take care to purge
  // that value once it has finished.
  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);

  // Normally, in the cases we can prove no-overflow via a
  // backedge guarding condition, we can also compute a backedge
  // taken count for the loop. The exceptions are assumptions and
  // guards present in the loop -- SCEV is not great at exploiting
  // these to compute max backedge taken counts, but can still use
  // these to prove lack of overflow. Use this fact to avoid
  // doing extra work that may not pay off.

  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
      AC.assumptions().empty())
    return Result;

  // If the backedge is guarded by a comparison with the pre-inc value the
  // addrec is safe. Also, if the entry is guarded by a comparison with the
  // start value and the backedge is guarded by a comparison with the post-inc
  // value, the addrec is safe.
  if (isKnownPositive(Step)) {
    // N bounds the values an IV may take without unsigned-wrapping when
    // incremented by at most getUnsignedRangeMax(Step).
    const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                getUnsignedRangeMax(Step));
      Result = setFlags(Result, SCEV::FlagNUW);
    }
  }

  return Result;
}
5362
5363namespace {
5364
/// Represents an abstract binary operation. This may exist as a
/// normal instruction or constant expression, or may have been
/// derived from an expression tree.
struct BinaryOp {
  unsigned Opcode;     // Instruction opcode (e.g. Instruction::Add).
  Value *LHS;          // Left operand.
  Value *RHS;          // Right operand.
  bool IsNSW = false;  // True if the operation is known not to signed-wrap.
  bool IsNUW = false;  // True if the operation is known not to unsigned-wrap.

  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
  /// constant expression.
  Operator *Op = nullptr;

  /// Wrap a concrete Operator, harvesting its nsw/nuw flags when present.
  explicit BinaryOp(Operator *Op)
      : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
        Op(Op) {
    if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
      IsNSW = OBO->hasNoSignedWrap();
      IsNUW = OBO->hasNoUnsignedWrap();
    }
  }

  /// Build a synthetic operation with no backing IR operator (Op stays null).
  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
                    bool IsNUW = false)
      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
};
5392
5393} // end anonymous namespace
5394
/// Try to map \p V into a BinaryOp, and return \c std::nullopt on failure.
/// The returned BinaryOp may be synthetic (e.g. a disjoint `or` reported as
/// an `add nuw nsw`); nsw/nuw are only set when justified by the IR.
static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
                                             AssumptionCache &AC,
                                             const DominatorTree &DT,
                                             const Instruction *CxtI) {
  auto *Op = dyn_cast<Operator>(V);
  if (!Op)
    return std::nullopt;

  // Implementation detail: all the cleverness here should happen without
  // creating new SCEV expressions -- our caller knowns tricks to avoid creating
  // SCEV expressions when possible, and we should not break that.

  switch (Op->getOpcode()) {
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::And:
  case Instruction::AShr:
  case Instruction::Shl:
    // These map directly; flags are harvested by the BinaryOp constructor.
    return BinaryOp(Op);

  case Instruction::Or: {
    // Convert or disjoint into add nuw nsw.
    if (cast<PossiblyDisjointInst>(Op)->isDisjoint())
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
                      /*IsNSW=*/true, /*IsNUW=*/true);
    return BinaryOp(Op);
  }

  case Instruction::Xor:
    if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
      // If the RHS of the xor is a signmask, then this is just an add.
      // Instcombine turns add of signmask into xor as a strength reduction step.
      if (RHSC->getValue().isSignMask())
        return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    // Binary `xor` is a bit-wise `add`.
    if (V->getType()->isIntegerTy(1))
      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
    return BinaryOp(Op);

  case Instruction::LShr:
    // Turn logical shift right of a constant into a unsigned divide.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
      uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();

      // If the shift count is not less than the bitwidth, the result of
      // the shift is undefined. Don't try to analyze it, because the
      // resolution chosen here may differ from the resolution chosen in
      // other parts of the compiler.
      if (SA->getValue().ult(BitWidth)) {
        // X lshr C  ==  X udiv (1 << C).
        Constant *X =
            ConstantInt::get(SA->getContext(),
                             APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
        return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
      }
    }
    return BinaryOp(Op);

  case Instruction::ExtractValue: {
    // Recognize extractvalue {result} of a *.with.overflow intrinsic.
    auto *EVI = cast<ExtractValueInst>(Op);
    if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
      break;

    auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
    if (!WO)
      break;

    Instruction::BinaryOps BinOp = WO->getBinaryOp();
    bool Signed = WO->isSigned();
    // TODO: Should add nuw/nsw flags for mul as well.
    if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
      return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());

    // Now that we know that all uses of the arithmetic-result component of
    // CI are guarded by the overflow check, we can go ahead and pretend
    // that the arithmetic is non-overflowing.
    return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
                    /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
  }

  default:
    break;
  }

  // Recognise intrinsic loop.decrement.reg, and as this has exactly the same
  // semantics as a Sub, return a binary sub expression.
  if (auto *II = dyn_cast<IntrinsicInst>(V))
    if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
      return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));

  return std::nullopt;
}
5490
5491/// Helper function to createAddRecFromPHIWithCasts. We have a phi
5492/// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
5493/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
5494/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
5495/// follows one of the following patterns:
5496/// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5497/// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
5498/// If the SCEV expression of \p Op conforms with one of the expected patterns
5499/// we return the type of the truncation operation, and indicate whether the
5500/// truncated type should be treated as signed/unsigned by setting
5501/// \p Signed to true/false, respectively.
5502static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
5503 bool &Signed, ScalarEvolution &SE) {
5504 // The case where Op == SymbolicPHI (that is, with no type conversions on
5505 // the way) is handled by the regular add recurrence creating logic and
5506 // would have already been triggered in createAddRecForPHI. Reaching it here
5507 // means that createAddRecFromPHI had failed for this PHI before (e.g.,
5508 // because one of the other operands of the SCEVAddExpr updating this PHI is
5509 // not invariant).
5510 //
5511 // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
5512 // this case predicates that allow us to prove that Op == SymbolicPHI will
5513 // be added.
5514 if (Op == SymbolicPHI)
5515 return nullptr;
5516
5517 unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
5518 unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
5519 if (SourceBits != NewBits)
5520 return nullptr;
5521
5522 if (match(Op, m_scev_SExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5523 Signed = true;
5524 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5525 }
5526 if (match(Op, m_scev_ZExt(m_scev_Trunc(m_scev_Specific(SymbolicPHI))))) {
5527 Signed = false;
5528 return cast<SCEVCastExpr>(Op)->getOperand()->getType();
5529 }
5530 return nullptr;
5531}
5532
5533static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
5534 if (!PN->getType()->isIntegerTy())
5535 return nullptr;
5536 const Loop *L = LI.getLoopFor(PN->getParent());
5537 if (!L || L->getHeader() != PN->getParent())
5538 return nullptr;
5539 return L;
5540}
5541
5542// Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
5543// computation that updates the phi follows the following pattern:
5544// (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
5545// which correspond to a phi->trunc->sext/zext->add->phi update chain.
5546// If so, try to see if it can be rewritten as an AddRecExpr under some
5547// Predicates. If successful, return them as a pair. Also cache the results
5548// of the analysis.
5549//
5550// Example usage scenario:
5551// Say the Rewriter is called for the following SCEV:
5552// 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5553// where:
5554// %X = phi i64 (%Start, %BEValue)
5555// It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
5556// and call this function with %SymbolicPHI = %X.
5557//
5558// The analysis will find that the value coming around the backedge has
5559// the following SCEV:
5560// BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
5561// Upon concluding that this matches the desired pattern, the function
5562// will return the pair {NewAddRec, SmallPredsVec} where:
5563// NewAddRec = {%Start,+,%Step}
5564// SmallPredsVec = {P1, P2, P3} as follows:
5565// P1(WrapPred): AR: {trunc(%Start),+,(trunc %Step)}<nsw> Flags: <nssw>
5566// P2(EqualPred): %Start == (sext i32 (trunc i64 %Start to i32) to i64)
5567// P3(EqualPred): %Step == (sext i32 (trunc i64 %Step to i32) to i64)
5568// The returned pair means that SymbolicPHI can be rewritten into NewAddRec
5569// under the predicates {P1,P2,P3}.
5570// This predicated rewrite will be cached in PredicatedSCEVRewrites:
// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3}}
5572//
5573// TODO's:
5574//
5575// 1) Extend the Induction descriptor to also support inductions that involve
5576// casts: When needed (namely, when we are called in the context of the
5577// vectorizer induction analysis), a Set of cast instructions will be
5578// populated by this method, and provided back to isInductionPHI. This is
5579// needed to allow the vectorizer to properly record them to be ignored by
5580// the cost model and to avoid vectorizing them (otherwise these casts,
5581// which are redundant under the runtime overflow checks, will be
5582// vectorized, which can be costly).
5583//
5584// 2) Support additional induction/PHISCEV patterns: We also want to support
5585// inductions where the sext-trunc / zext-trunc operations (partly) occur
5586// after the induction update operation (the induction increment):
5587//
5588// (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
5589// which correspond to a phi->add->trunc->sext/zext->phi update chain.
5590//
5591// (Trunc iy ((SExt/ZExt ix (%SymbolicPhi) to iy) + InvariantAccum) to ix)
5592// which correspond to a phi->trunc->add->sext/zext->phi update chain.
5593//
5594// 3) Outline common code with createAddRecFromPHI to avoid duplication.
std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
  // NOTE(review): the declaration of the result-predicate list, e.g.
  // `SmallVector<const SCEVPredicate *, 3> Predicates;`, appears to be
  // missing here (`Predicates` is populated below) -- verify against upstream.

  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
  // return an AddRec expression under some predicate.

  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
  assert(L && "Expecting an integer loop header phi");

  // The loop may have multiple entrances or multiple exits; we can analyze
  // this phi as an addrec if it has a unique entry value and a unique
  // backedge value.
  Value *BEValueV = nullptr, *StartValueV = nullptr;
  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    Value *V = PN->getIncomingValue(i);
    if (L->contains(PN->getIncomingBlock(i))) {
      // Incoming edge from inside the loop: candidate backedge value.
      if (!BEValueV) {
        BEValueV = V;
      } else if (BEValueV != V) {
        BEValueV = nullptr;
        break;
      }
    } else if (!StartValueV) {
      // Incoming edge from outside the loop: candidate start value.
      StartValueV = V;
    } else if (StartValueV != V) {
      StartValueV = nullptr;
      break;
    }
  }
  if (!BEValueV || !StartValueV)
    return std::nullopt;

  const SCEV *BEValue = getSCEV(BEValueV);

  // If the value coming around the backedge is an add with the symbolic
  // value we just inserted, possibly with casts that we can ignore under
  // an appropriate runtime guard, then we found a simple induction variable!
  const auto *Add = dyn_cast<SCEVAddExpr>(BEValue);
  if (!Add)
    return std::nullopt;

  // If there is a single occurrence of the symbolic value, possibly
  // casted, replace it with a recurrence.
  unsigned FoundIndex = Add->getNumOperands();
  Type *TruncTy = nullptr;
  bool Signed;
  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
    if ((TruncTy =
             isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
      if (FoundIndex == e) {
        FoundIndex = i;
        break;
      }

  if (FoundIndex == Add->getNumOperands())
    return std::nullopt;

  // Create an add with everything but the specified operand.
  // NOTE(review): the declaration of the operand list, e.g.
  // `SmallVector<const SCEV *, 8> Ops;`, appears to be missing here --
  // verify against upstream.
  for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
    if (i != FoundIndex)
      Ops.push_back(Add->getOperand(i));
  const SCEV *Accum = getAddExpr(Ops);

  // The runtime checks will not be valid if the step amount is
  // varying inside the loop.
  if (!isLoopInvariant(Accum, L))
    return std::nullopt;

  // *** Part2: Create the predicates

  // Analysis was successful: we have a phi-with-cast pattern for which we
  // can return an AddRec expression under the following predicates:
  //
  // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
  //     fits within the truncated type (does not overflow) for i = 0 to n-1.
  // P2: An Equal predicate that guarantees that
  //     Start = (Ext ix (Trunc iy (Start) to ix) to iy)
  // P3: An Equal predicate that guarantees that
  //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
  //
  // As we next prove, the above predicates guarantee that:
  //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
  //
  //
  // More formally, we want to prove that:
  //     Expr(i+1) = Start + (i+1) * Accum
  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
  //
  // Given that:
  // 1) Expr(0) = Start
  // 2) Expr(1) = Start + Accum
  //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
  // 3) Induction hypothesis (step i):
  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
  //
  // Proof:
  //  Expr(i+1) =
  //   = Start + (i+1)*Accum
  //   = (Start + i*Accum) + Accum
  //   = Expr(i) + Accum
  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
  //                                                             :: from step i
  //
  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
  //
  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
  //     + (Ext ix (Trunc iy (Accum) to ix) to iy)
  //     + Accum                                                 :: from P3
  //
  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
  //     + Accum                        :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
  //
  //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
  //
  // By induction, the same applies to all iterations 1<=i<n:
  //

  // Create a truncated addrec for which we will add a no overflow check (P1).
  const SCEV *StartVal = getSCEV(StartValueV);
  const SCEV *PHISCEV =
      getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);

  // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
  // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
  // will be constant.
  //
  // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
  // add P1.
  if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
    // NOTE(review): the computation of `AddedFlags` (the NSSW increment flag
    // when Signed, NUSW otherwise) appears to be missing here -- verify
    // against upstream.
    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
    Predicates.push_back(AddRecPred);
  }

  // Create the Equal Predicates P2,P3:

  // It is possible that the predicates P2 and/or P3 are computable at
  // compile time due to StartVal and/or Accum being constants.
  // If either one is, then we can check that now and escape if either P2
  // or P3 is false.

  // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
  // for each of StartVal and Accum
  auto getExtendedExpr = [&](const SCEV *Expr,
                             bool CreateSignExtend) -> const SCEV * {
    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
    const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
    const SCEV *ExtendedExpr =
        CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
                         : getZeroExtendExpr(TruncatedExpr, Expr->getType());
    return ExtendedExpr;
  };

  // Given:
  //  ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
  //               = getExtendedExpr(Expr)
  // Determine whether the predicate P: Expr == ExtendedExpr
  // is known to be false at compile time
  auto PredIsKnownFalse = [&](const SCEV *Expr,
                              const SCEV *ExtendedExpr) -> bool {
    return Expr != ExtendedExpr &&
           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
  };

  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
  if (PredIsKnownFalse(StartVal, StartExtended)) {
    LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
    return std::nullopt;
  }

  // The Step is always Signed (because the overflow checks are either
  // NSSW or NUSW)
  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
  if (PredIsKnownFalse(Accum, AccumExtended)) {
    LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
    return std::nullopt;
  }

  // Record an Equal predicate only if it is not already provably true.
  auto AppendPredicate = [&](const SCEV *Expr,
                             const SCEV *ExtendedExpr) -> void {
    if (Expr != ExtendedExpr &&
        !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
      const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
      Predicates.push_back(Pred);
    }
  };

  AppendPredicate(StartVal, StartExtended);
  AppendPredicate(Accum, AccumExtended);

  // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
  // which the casts had been folded away. The caller can rewrite SymbolicPHI
  // into NewAR if it will also add the runtime overflow checks specified in
  // Predicates.
  auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);

  std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
      std::make_pair(NewAR, Predicates);
  // Remember the result of the analysis for this SCEV at this location.
  PredicatedSCEVRewrites[{SymbolicPHI, L}] = PredRewrite;
  return PredRewrite;
}
5805
/// Caching wrapper: consult (and populate) PredicatedSCEVRewrites before
/// delegating to createAddRecFromPHIWithCastsImpl; failures are cached as a
/// rewrite of the PHI to itself.
std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
// NOTE(review): the declarator line, i.e.
// `ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {`,
// appears to be missing here -- verify against upstream.
  auto *PN = cast<PHINode>(SymbolicPHI->getValue());
  const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
  if (!L)
    return std::nullopt;

  // Check to see if we already analyzed this PHI.
  auto I = PredicatedSCEVRewrites.find({SymbolicPHI, L});
  if (I != PredicatedSCEVRewrites.end()) {
    std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
        I->second;
    // Analysis was done before and failed to create an AddRec:
    if (Rewrite.first == SymbolicPHI)
      return std::nullopt;
    // Analysis was done before and succeeded to create an AddRec under
    // a predicate:
    assert(isa<SCEVAddRecExpr>(Rewrite.first) && "Expected an AddRec");
    assert(!(Rewrite.second).empty() && "Expected to find Predicates");
    return Rewrite;
  }

  std::optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
      Rewrite = createAddRecFromPHIWithCastsImpl(SymbolicPHI);

  // Record in the cache that the analysis failed
  if (!Rewrite) {
    // NOTE(review): the declaration of an empty `Predicates` vector appears
    // to be missing here -- verify against upstream.
    PredicatedSCEVRewrites[{SymbolicPHI, L}] = {SymbolicPHI, Predicates};
    return std::nullopt;
  }

  return Rewrite;
}
5840
// FIXME: This utility is currently required because the Rewriter currently
// does not rewrite this expression:
// {0, +, (sext ix (trunc iy to ix) to iy)}
// into {0, +, %step},
// even when the following Equal predicate exists:
// "%step == (sext ix (trunc iy to ix) to iy)".
// NOTE(review): the declarator line, i.e.
// `bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(`, appears to be
// missing here -- verify against upstream.
    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
  // SCEVs are uniqued, so identical addrecs are pointer-identical.
  if (AR1 == AR2)
    return true;

  // Two sub-expressions are considered equal if they are identical or if the
  // current predicate set implies their equality (in either direction).
  auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
    if (Expr1 != Expr2 &&
        !Preds->implies(SE.getEqualPredicate(Expr1, Expr2), SE) &&
        !Preds->implies(SE.getEqualPredicate(Expr2, Expr1), SE))
      return false;
    return true;
  };

  // Addrecs are equal (under the predicates) iff both their start values and
  // their step recurrences are.
  if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
      !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
    return false;
  return true;
}
5865
/// A helper function for createAddRecFromPHI to handle simple cases.
///
/// This function tries to find an AddRec expression for the simplest (yet most
/// common) cases: PN = PHI(Start, OP(Self, LoopInvariant)).
/// If it fails, createAddRecFromPHI will use a more general, but slow,
/// technique for finding the AddRec expression.
const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
                                                      Value *BEValueV,
                                                      Value *StartValueV) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  assert(L && L->getHeader() == PN->getParent());
  assert(BEValueV && StartValueV);

  // The backedge value must be a binary add of the phi itself and a
  // loop-invariant step; anything else is left to the general path.
  auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
  if (!BO)
    return nullptr;

  if (BO->Opcode != Instruction::Add)
    return nullptr;

  const SCEV *Accum = nullptr;
  if (BO->LHS == PN && L->isLoopInvariant(BO->RHS))
    Accum = getSCEV(BO->RHS);
  else if (BO->RHS == PN && L->isLoopInvariant(BO->LHS))
    Accum = getSCEV(BO->LHS);

  if (!Accum)
    return nullptr;

  // Transfer nuw/nsw from the IR add onto the addrec.
  // NOTE(review): the declaration `SCEV::NoWrapFlags Flags =
  // SCEV::FlagAnyWrap;` appears to be missing here -- verify against upstream.
  if (BO->IsNUW)
    Flags = setFlags(Flags, SCEV::FlagNUW);
  if (BO->IsNSW)
    Flags = setFlags(Flags, SCEV::FlagNSW);

  const SCEV *StartVal = getSCEV(StartValueV);
  const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
  // Cache the mapping before further queries can re-enter this phi.
  insertValueToMap(PN, PHISCEV);

  // Tighten the addrec's wrap flags using constant-range reasoning.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
    setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR),
                   (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
  }

  // We can add Flags to the post-inc expression only if we
  // know that it is *undefined behavior* for BEValueV to
  // overflow.
  if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) {
    assert(isLoopInvariant(Accum, L) &&
           "Accum is defined outside L, but is not invariant?");
    if (isAddRecNeverPoison(BEInst, L))
      (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);
  }

  return PHISCEV;
}
5922
/// Try to model the loop-header phi \p PN as an add recurrence, first via the
/// fast affine path, then via symbolic analysis of the backedge value.
const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
  const Loop *L = LI.getLoopFor(PN->getParent());
  if (!L || L->getHeader() != PN->getParent())
    return nullptr;

  // The loop may have multiple entrances or multiple exits; we can analyze
  // this phi as an addrec if it has a unique entry value and a unique
  // backedge value.
  Value *BEValueV = nullptr, *StartValueV = nullptr;
  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    Value *V = PN->getIncomingValue(i);
    if (L->contains(PN->getIncomingBlock(i))) {
      // Edge from inside the loop: candidate backedge value.
      if (!BEValueV) {
        BEValueV = V;
      } else if (BEValueV != V) {
        BEValueV = nullptr;
        break;
      }
    } else if (!StartValueV) {
      // Edge from outside the loop: candidate start value.
      StartValueV = V;
    } else if (StartValueV != V) {
      StartValueV = nullptr;
      break;
    }
  }
  if (!BEValueV || !StartValueV)
    return nullptr;

  assert(ValueExprMap.find_as(PN) == ValueExprMap.end() &&
         "PHI node already processed?");

  // First, try to find AddRec expression without creating a fictitious
  // symbolic value for PN.
  if (auto *S = createSimpleAffineAddRec(PN, BEValueV, StartValueV))
    return S;

  // Handle PHI node value symbolically.
  const SCEV *SymbolicName = getUnknown(PN);
  insertValueToMap(PN, SymbolicName);

  // Using this symbolic name for the PHI, analyze the value coming around
  // the back-edge.
  const SCEV *BEValue = getSCEV(BEValueV);

  // NOTE: If BEValue is loop invariant, we know that the PHI node just
  // has a special value for the first iteration of the loop.

  // If the value coming around the backedge is an add with the symbolic
  // value we just inserted, then we found a simple induction variable!
  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
    // If there is a single occurrence of the symbolic value, replace it
    // with a recurrence.
    unsigned FoundIndex = Add->getNumOperands();
    for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
      if (Add->getOperand(i) == SymbolicName)
        if (FoundIndex == e) {
          FoundIndex = i;
          break;
        }

    if (FoundIndex != Add->getNumOperands()) {
      // Create an add with everything but the specified operand.
      // NOTE(review): the declaration of `Ops` (a SmallVector of SCEVs)
      // appears to be missing here -- verify against upstream.
      for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
        if (i != FoundIndex)
          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
                                                             L, *this));
      const SCEV *Accum = getAddExpr(Ops);

      // This is not a valid addrec if the step amount is varying each
      // loop iteration, but is not itself an addrec in this loop.
      if (isLoopInvariant(Accum, L) ||
          (isa<SCEVAddRecExpr>(Accum) &&
           cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
        // NOTE(review): the declaration `SCEV::NoWrapFlags Flags =
        // SCEV::FlagAnyWrap;` appears to be missing here -- verify upstream.

        if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
          if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
            if (BO->IsNUW)
              Flags = setFlags(Flags, SCEV::FlagNUW);
            if (BO->IsNSW)
              Flags = setFlags(Flags, SCEV::FlagNSW);
          }
        } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) {
          if (GEP->getOperand(0) == PN) {
            GEPNoWrapFlags NW = GEP->getNoWrapFlags();
            // If the increment has any nowrap flags, then we know the address
            // space cannot be wrapped around.
            if (NW != GEPNoWrapFlags::none())
              Flags = setFlags(Flags, SCEV::FlagNW);
            // If the GEP is nuw or nusw with non-negative offset, we know that
            // no unsigned wrap occurs. We cannot set the nsw flag as only the
            // offset is treated as signed, while the base is unsigned.
            if (NW.hasNoUnsignedWrap() ||
                // NOTE(review): the second disjunct of this condition (a
                // nusw-with-non-negative-offset check) appears to be missing
                // here -- verify against upstream.
              Flags = setFlags(Flags, SCEV::FlagNUW);
          }

          // We cannot transfer nuw and nsw flags from subtraction
          // operations -- sub nuw X, Y is not the same as add nuw X, -Y
          // for instance.
        }

        const SCEV *StartVal = getSCEV(StartValueV);
        const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);

        // Okay, for the entire analysis of this edge we assumed the PHI
        // to be symbolic. We now need to go back and purge all of the
        // entries for the scalars that use the symbolic expression.
        forgetMemoizedResults({SymbolicName});
        insertValueToMap(PN, PHISCEV);

        if (auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
          // NOTE(review): the `setNoWrapFlags(` call line appears to be
          // truncated here -- verify against upstream.
              const_cast<SCEVAddRecExpr *>(AR),
              (AR->getNoWrapFlags() | proveNoWrapViaConstantRanges(AR)));
        }

        // We can add Flags to the post-inc expression only if we
        // know that it is *undefined behavior* for BEValueV to
        // overflow.
        if (auto *BEInst = dyn_cast<Instruction>(BEValueV))
          if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L))
            (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags);

        return PHISCEV;
      }
    }
  } else {
    // Otherwise, this could be a loop like this:
    //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
    // In this case, j = {1,+,1} and BEValue is j.
    // Because the other in-value of i (0) fits the evolution of BEValue
    // i really is an addrec evolution.
    //
    // We can generalize this saying that i is the shifted value of BEValue
    // by one iteration:
    //   PHI(f(0), f({1,+,1})) --> f({0,+,1})

    // Do not allow refinement in rewriting of BEValue.
    const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
    const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
    if (Shifted != getCouldNotCompute() && Start != getCouldNotCompute() &&
        isGuaranteedNotToCauseUB(Shifted) && ::impliesPoison(Shifted, Start)) {
      const SCEV *StartVal = getSCEV(StartValueV);
      if (Start == StartVal) {
        // Okay, for the entire analysis of this edge we assumed the PHI
        // to be symbolic. We now need to go back and purge all of the
        // entries for the scalars that use the symbolic expression.
        forgetMemoizedResults({SymbolicName});
        insertValueToMap(PN, Shifted);
        return Shifted;
      }
    }
  }

  // Remove the temporary PHI node SCEV that has been inserted while intending
  // to create an AddRecExpr for this PHI node. We can not keep this temporary
  // as it will prevent later (possibly simpler) SCEV expressions to be added
  // to the ValueExprMap.
  eraseValueFromMap(PN);

  return nullptr;
}
6087
// Try to match a control flow sequence that branches out at BI and merges back
// at Merge into a "C ? LHS : RHS" select pattern. Return true on a successful
// match.
// NOTE(review): the declarator line, i.e.
// `static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,`,
// appears to be missing here -- verify against upstream.
                          Value *&C, Value *&LHS, Value *&RHS) {
  C = BI->getCondition();

  BasicBlockEdge LeftEdge(BI->getParent(), BI->getSuccessor(0));
  BasicBlockEdge RightEdge(BI->getParent(), BI->getSuccessor(1));

  Use &LeftUse = Merge->getOperandUse(0);
  Use &RightUse = Merge->getOperandUse(1);

  // Use edge dominance to decide which phi operand flows in from which
  // successor: the phi's operands may appear in either order.
  if (DT.dominates(LeftEdge, LeftUse) && DT.dominates(RightEdge, RightUse)) {
    LHS = LeftUse;
    RHS = RightUse;
    return true;
  }

  if (DT.dominates(LeftEdge, RightUse) && DT.dominates(RightEdge, LeftUse)) {
    LHS = RightUse;
    RHS = LeftUse;
    return true;
  }

  // Neither assignment of operands to edges is dominated: not a select.
  return false;
}
6115
6117 Value *&Cond, Value *&LHS,
6118 Value *&RHS) {
6119 auto IsReachable =
6120 [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); };
6121 if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {
6122 // Try to match
6123 //
6124 // br %cond, label %left, label %right
6125 // left:
6126 // br label %merge
6127 // right:
6128 // br label %merge
6129 // merge:
6130 // V = phi [ %x, %left ], [ %y, %right ]
6131 //
6132 // as "select %cond, %x, %y"
6133
6134 BasicBlock *IDom = DT[PN->getParent()]->getIDom()->getBlock();
6135 assert(IDom && "At least the entry block should dominate PN");
6136
6137 auto *BI = dyn_cast<CondBrInst>(IDom->getTerminator());
6138 return BI && BrPHIToSelect(DT, BI, PN, Cond, LHS, RHS);
6139 }
6140 return false;
6141}
6142
/// If \p PN merges the two arms of a diamond rooted at a conditional branch,
/// model it as the equivalent select expression.
const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) {
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  if (getOperandsForSelectLikePHI(DT, PN, Cond, LHS, RHS) &&
      // NOTE(review): additional condition line(s) appear to be missing here
      // -- verify against upstream.
    return createNodeForSelectOrPHI(PN, Cond, LHS, RHS);

  return nullptr;
}
6152
// Return the single BinaryOperator that every incoming value of the phi is
// syntactically identical to, or nullptr if the incoming values differ or
// are not binary operators.
// NOTE(review): the declarator line, i.e.
// `static BinaryOperator *getCommonInstForPHI(PHINode *PN) {`, appears to be
// missing here -- verify against upstream.
  BinaryOperator *CommonInst = nullptr;
  // Check if instructions are identical.
  for (Value *Incoming : PN->incoming_values()) {
    auto *IncomingInst = dyn_cast<BinaryOperator>(Incoming);
    // Only binary operators are candidates for this fold.
    if (!IncomingInst)
      return nullptr;
    if (CommonInst) {
      if (!CommonInst->isIdenticalToWhenDefined(IncomingInst))
        return nullptr; // Not identical, give up
    } else {
      // Remember binary operator
      CommonInst = IncomingInst;
    }
  }
  return CommonInst;
}
6170
/// Returns SCEV for the first operand of a phi if all phi operands have
/// identical opcodes and operands
/// eg.
/// a: %add = %a + %b
///    br %c
/// b: %add1 = %a + %b
///    br %c
/// c: %phi = phi [%add, a], [%add1, b]
/// scev(%phi) => scev(%add)
const SCEV *
ScalarEvolution::createNodeForPHIWithIdenticalOperands(PHINode *PN) {
  BinaryOperator *CommonInst = getCommonInstForPHI(PN);
  if (!CommonInst)
    return nullptr;

  // Check if SCEV exprs for instructions are identical.
  bool SCEVExprsIdentical =
      // NOTE(review): the `all_of(PN->incoming_values(),` line appears to be
      // missing here -- verify against upstream.
      [this, CommonSCEV](Value *V) { return CommonSCEV == getSCEV(V); });
  return SCEVExprsIdentical ? CommonSCEV : nullptr;
}
6193
6194const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
6195 if (const SCEV *S = createAddRecFromPHI(PN))
6196 return S;
6197
6198 // We do not allow simplifying phi (undef, X) to X here, to avoid reusing the
6199 // phi node for X.
6200 if (Value *V = simplifyInstruction(
6201 PN, {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
6202 /*UseInstrInfo=*/true, /*CanUseUndef=*/false}))
6203 return getSCEV(V);
6204
6205 if (const SCEV *S = createNodeForPHIWithIdenticalOperands(PN))
6206 return S;
6207
6208 if (const SCEV *S = createNodeFromSelectLikePHI(PN))
6209 return S;
6210
6211 // If it's not a loop phi, we can't handle it yet.
6212 return getUnknown(PN);
6213}
6214
/// Return true if \p Root contains \p OperandToFind, recursing only through
/// min/max nodes of the same effective kind as \p RootKind and through
/// zero-extensions.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
                            SCEVTypes RootKind) {
  struct FindClosure {
    const SCEV *OperandToFind;
    const SCEVTypes RootKind; // Must be a sequential min/max expression.
    const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.

    bool Found = false;

    bool canRecurseInto(SCEVTypes Kind) const {
      // We can only recurse into the SCEV expression of the same effective type
      // as the type of our root SCEV expression, and into zero-extensions.
      return RootKind == Kind || NonSequentialRootKind == Kind ||
             scZeroExtend == Kind;
    };

    FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
        : OperandToFind(OperandToFind), RootKind(RootKind),
          NonSequentialRootKind(
              // NOTE(review): the call computing the non-sequential variant
              // of RootKind appears to be truncated here -- verify against
              // upstream.
              RootKind)) {}

    bool follow(const SCEV *S) {
      Found = S == OperandToFind;

      // Stop the walk once found; otherwise only descend into allowed kinds.
      return !isDone() && canRecurseInto(S->getSCEVType());
    }

    bool isDone() const { return Found; }
  };

  FindClosure FC(OperandToFind, RootKind);
  visitAll(Root, FC);
  return FC.Found;
}
6250
/// Try to model "Cond ? TrueVal : FalseVal", where \p Cond is an integer
/// compare, as a SCEV min/max (or umin_seq) expression of type \p Ty.
/// Returns std::nullopt when no pattern matches.
std::optional<const SCEV *>
ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(Type *Ty,
                                                              ICmpInst *Cond,
                                                              Value *TrueVal,
                                                              Value *FalseVal) {
  // Try to match some simple smax or umax patterns.
  auto *ICI = Cond;

  Value *LHS = ICI->getOperand(0);
  Value *RHS = ICI->getOperand(1);

  switch (ICI->getPredicate()) {
  case ICmpInst::ICMP_SLT:
  case ICmpInst::ICMP_SLE:
  case ICmpInst::ICMP_ULT:
  case ICmpInst::ICMP_ULE:
    // Canonicalize less-than forms into greater-than forms by swapping the
    // compared operands, then share the code below.
    std::swap(LHS, RHS);
    [[fallthrough]];
  case ICmpInst::ICMP_SGT:
  case ICmpInst::ICMP_SGE:
  case ICmpInst::ICMP_UGT:
  case ICmpInst::ICMP_UGE:
    // a > b ? a+x : b+x  ->  max(a, b)+x
    // a > b ? b+x : a+x  ->  min(a, b)+x
    // NOTE(review): an opening `{` for this case's local scope appears to be
    // missing here -- verify against upstream.
    bool Signed = ICI->isSigned();
    const SCEV *LA = getSCEV(TrueVal);
    const SCEV *RA = getSCEV(FalseVal);
    const SCEV *LS = getSCEV(LHS);
    const SCEV *RS = getSCEV(RHS);
    if (LA->getType()->isPointerTy()) {
      // FIXME: Handle cases where LS/RS are pointers not equal to LA/RA.
      // Need to make sure we can't produce weird expressions involving
      // negated pointers.
      if (LA == LS && RA == RS)
        return Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS);
      if (LA == RS && RA == LS)
        return Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS);
    }
    // Bring both compare operands to the select's type, extending according
    // to the compare's signedness.
    auto CoerceOperand = [&](const SCEV *Op) -> const SCEV * {
      if (Op->getType()->isPointerTy()) {
        // NOTE(review): the pointer-to-integer conversion (and its
        // CouldNotCompute check) appears to be missing here -- verify
        // against upstream.
        return Op;
      }
      if (Signed)
        Op = getNoopOrSignExtend(Op, Ty);
      else
        Op = getNoopOrZeroExtend(Op, Ty);
      return Op;
    };
    LS = CoerceOperand(LS);
    RS = CoerceOperand(RS);
    // NOTE(review): a guard `if (isa<SCEVCouldNotCompute>(LS) || ...)`
    // appears to be missing before this break -- verify against upstream.
      break;
    // If both select arms differ from the (min/max of the) compared values
    // by the same amount, the select is that min/max plus the difference.
    const SCEV *LDiff = getMinusSCEV(LA, LS);
    const SCEV *RDiff = getMinusSCEV(RA, RS);
    if (LDiff == RDiff)
      return getAddExpr(Signed ? getSMaxExpr(LS, RS) : getUMaxExpr(LS, RS),
                        LDiff);
    LDiff = getMinusSCEV(LA, RS);
    RDiff = getMinusSCEV(RA, LS);
    if (LDiff == RDiff)
      return getAddExpr(Signed ? getSMinExpr(LS, RS) : getUMinExpr(LS, RS),
                        LDiff);
    }
    break;
  case ICmpInst::ICMP_NE:
    // x != 0 ? x+y : C+y  ->  x == 0 ? C+y : x+y
    std::swap(TrueVal, FalseVal);
    [[fallthrough]];
  case ICmpInst::ICMP_EQ:
    // x == 0 ? C+y : x+y  ->  umax(x, C)+y  iff C u<= 1
    // NOTE(review): the guard lines for this block (a bit-width check and a
    // check that RHS is the constant zero) appear to be missing here --
    // verify against upstream.
      const SCEV *X = getNoopOrZeroExtend(getSCEV(LHS), Ty);
      const SCEV *TrueValExpr = getSCEV(TrueVal);    // C+y
      const SCEV *FalseValExpr = getSCEV(FalseVal);  // x+y
      const SCEV *Y = getMinusSCEV(FalseValExpr, X); // y = (x+y)-x
      const SCEV *C = getMinusSCEV(TrueValExpr, Y);  // C = (C+y)-y
      if (isa<SCEVConstant>(C) && cast<SCEVConstant>(C)->getAPInt().ule(1))
        return getAddExpr(getUMaxExpr(X, C), Y);
    }
    // x == 0 ? 0 : umin (..., x, ...)  ->  umin_seq(x, umin (...))
    // x == 0 ? 0 : umin_seq(..., x, ...)  ->  umin_seq(x, umin_seq(...))
    // x == 0 ? 0 : umin (..., umin_seq(..., x, ...), ...)
    //                    ->  umin_seq(x, umin (..., umin_seq(...), ...))
    // NOTE(review): the first condition line of this guard (checking RHS is
    // zero) appears to be missing here -- verify against upstream.
        isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
      const SCEV *X = getSCEV(LHS);
      // Look through zero-extensions of x.
      while (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(X))
        X = ZExt->getOperand();
      if (getTypeSizeInBits(X->getType()) <= getTypeSizeInBits(Ty)) {
        const SCEV *FalseValExpr = getSCEV(FalseVal);
        if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
          return getUMinExpr(getNoopOrZeroExtend(X, Ty), FalseValExpr,
                             /*Sequential=*/true);
      }
    }
    break;
  default:
    break;
  }

  return std::nullopt;
}
6357
6358static std::optional<const SCEV *>
6360 const SCEV *TrueExpr, const SCEV *FalseExpr) {
6361 assert(CondExpr->getType()->isIntegerTy(1) &&
6362 TrueExpr->getType() == FalseExpr->getType() &&
6363 TrueExpr->getType()->isIntegerTy(1) &&
6364 "Unexpected operands of a select.");
6365
6366 // i1 cond ? i1 x : i1 C --> C + (i1 cond ? (i1 x - i1 C) : i1 0)
6367 // --> C + (umin_seq cond, x - C)
6368 //
6369 // i1 cond ? i1 C : i1 x --> C + (i1 cond ? i1 0 : (i1 x - i1 C))
6370 // --> C + (i1 ~cond ? (i1 x - i1 C) : i1 0)
6371 // --> C + (umin_seq ~cond, x - C)
6372
6373 // FIXME: while we can't legally model the case where both of the hands
6374 // are fully variable, we only require that the *difference* is constant.
6375 if (!isa<SCEVConstant>(TrueExpr) && !isa<SCEVConstant>(FalseExpr))
6376 return std::nullopt;
6377
6378 const SCEV *X, *C;
6379 if (isa<SCEVConstant>(TrueExpr)) {
6380 CondExpr = SE->getNotSCEV(CondExpr);
6381 X = FalseExpr;
6382 C = TrueExpr;
6383 } else {
6384 X = TrueExpr;
6385 C = FalseExpr;
6386 }
6387 return SE->getAddExpr(C, SE->getUMinExpr(CondExpr, SE->getMinusSCEV(X, C),
6388 /*Sequential=*/true));
6389}
6390
6391static std::optional<const SCEV *>
6393 Value *FalseVal) {
6394 if (!isa<ConstantInt>(TrueVal) && !isa<ConstantInt>(FalseVal))
6395 return std::nullopt;
6396
6397 const auto *SECond = SE->getSCEV(Cond);
6398 const auto *SETrue = SE->getSCEV(TrueVal);
6399 const auto *SEFalse = SE->getSCEV(FalseVal);
6400 return createNodeForSelectViaUMinSeq(SE, SECond, SETrue, SEFalse);
6401}
6402
6403const SCEV *ScalarEvolution::createNodeForSelectOrPHIViaUMinSeq(
6404 Value *V, Value *Cond, Value *TrueVal, Value *FalseVal) {
6405 assert(Cond->getType()->isIntegerTy(1) && "Select condition is not an i1?");
6406 assert(TrueVal->getType() == FalseVal->getType() &&
6407 V->getType() == TrueVal->getType() &&
6408 "Types of select hands and of the result must match.");
6409
6410 // For now, only deal with i1-typed `select`s.
6411 if (!V->getType()->isIntegerTy(1))
6412 return getUnknown(V);
6413
6414 if (std::optional<const SCEV *> S =
6415 createNodeForSelectViaUMinSeq(this, Cond, TrueVal, FalseVal))
6416 return *S;
6417
6418 return getUnknown(V);
6419}
6420
6421const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Value *V, Value *Cond,
6422 Value *TrueVal,
6423 Value *FalseVal) {
6424 // Handle "constant" branch or select. This can occur for instance when a
6425 // loop pass transforms an inner loop and moves on to process the outer loop.
6426 if (auto *CI = dyn_cast<ConstantInt>(Cond))
6427 return getSCEV(CI->isOne() ? TrueVal : FalseVal);
6428
6429 if (auto *I = dyn_cast<Instruction>(V)) {
6430 if (auto *ICI = dyn_cast<ICmpInst>(Cond)) {
6431 if (std::optional<const SCEV *> S =
6432 createNodeForSelectOrPHIInstWithICmpInstCond(I->getType(), ICI,
6433 TrueVal, FalseVal))
6434 return *S;
6435 }
6436 }
6437
6438 return createNodeForSelectOrPHIViaUMinSeq(V, Cond, TrueVal, FalseVal);
6439}
6440
6441/// Expand GEP instructions into add and multiply operations. This allows them
6442/// to be analyzed by regular SCEV code.
6443const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
6444 assert(GEP->getSourceElementType()->isSized() &&
6445 "GEP source element type must be sized");
6446
6447 SmallVector<SCEVUse, 4> IndexExprs;
6448 for (Value *Index : GEP->indices())
6449 IndexExprs.push_back(getSCEV(Index));
6450 return getGEPExpr(GEP, IndexExprs);
6451}
6452
6453APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S,
6454 const Instruction *CtxI) {
6455 uint64_t BitWidth = getTypeSizeInBits(S->getType());
6456 auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
6457 return TrailingZeros >= BitWidth
6459 : APInt::getOneBitSet(BitWidth, TrailingZeros);
6460 };
6461 auto GetGCDMultiple = [this, CtxI](const SCEVNAryExpr *N) {
6462 // The result is GCD of all operands results.
6463 APInt Res = getConstantMultiple(N->getOperand(0), CtxI);
6464 for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
6466 Res, getConstantMultiple(N->getOperand(I), CtxI));
6467 return Res;
6468 };
6469
6470 switch (S->getSCEVType()) {
6471 case scConstant:
6472 return cast<SCEVConstant>(S)->getAPInt();
6473 case scPtrToAddr:
6474 case scPtrToInt:
6475 return getConstantMultiple(cast<SCEVCastExpr>(S)->getOperand());
6476 case scUDivExpr:
6477 case scVScale:
6478 return APInt(BitWidth, 1);
6479 case scTruncate: {
6480 // Only multiples that are a power of 2 will hold after truncation.
6481 const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
6482 uint32_t TZ = getMinTrailingZeros(T->getOperand(), CtxI);
6483 return GetShiftedByZeros(TZ);
6484 }
6485 case scZeroExtend: {
6486 const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
6487 return getConstantMultiple(Z->getOperand(), CtxI).zext(BitWidth);
6488 }
6489 case scSignExtend: {
6490 // Only multiples that are a power of 2 will hold after sext.
6491 const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
6492 uint32_t TZ = getMinTrailingZeros(E->getOperand(), CtxI);
6493 return GetShiftedByZeros(TZ);
6494 }
6495 case scMulExpr: {
6496 const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
6497 if (M->hasNoUnsignedWrap()) {
6498 // The result is the product of all operand results.
6499 APInt Res = getConstantMultiple(M->getOperand(0), CtxI);
6500 for (const SCEV *Operand : M->operands().drop_front())
6501 Res = Res * getConstantMultiple(Operand, CtxI);
6502 return Res;
6503 }
6504
6505 // If there are no wrap guarentees, find the trailing zeros, which is the
6506 // sum of trailing zeros for all its operands.
6507 uint32_t TZ = 0;
6508 for (const SCEV *Operand : M->operands())
6509 TZ += getMinTrailingZeros(Operand, CtxI);
6510 return GetShiftedByZeros(TZ);
6511 }
6512 case scAddExpr:
6513 case scAddRecExpr: {
6514 const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
6515 if (N->hasNoUnsignedWrap())
6516 return GetGCDMultiple(N);
6517 // Find the trailing bits, which is the minimum of its operands.
6518 uint32_t TZ = getMinTrailingZeros(N->getOperand(0), CtxI);
6519 for (const SCEV *Operand : N->operands().drop_front())
6520 TZ = std::min(TZ, getMinTrailingZeros(Operand, CtxI));
6521 return GetShiftedByZeros(TZ);
6522 }
6523 case scUMaxExpr:
6524 case scSMaxExpr:
6525 case scUMinExpr:
6526 case scSMinExpr:
6528 return GetGCDMultiple(cast<SCEVNAryExpr>(S));
6529 case scUnknown: {
6530 // Ask ValueTracking for known bits. SCEVUnknown only become available at
6531 // the point their underlying IR instruction has been defined. If CtxI was
6532 // not provided, use:
6533 // * the first instruction in the entry block if it is an argument
6534 // * the instruction itself otherwise.
6535 const SCEVUnknown *U = cast<SCEVUnknown>(S);
6536 if (!CtxI) {
6537 if (isa<Argument>(U->getValue()))
6538 CtxI = &*F.getEntryBlock().begin();
6539 else if (auto *I = dyn_cast<Instruction>(U->getValue()))
6540 CtxI = I;
6541 }
6542 unsigned Known =
6543 computeKnownBits(U->getValue(),
6544 SimplifyQuery(getDataLayout(), &DT, &AC, CtxI)
6545 .allowEphemerals(true))
6546 .countMinTrailingZeros();
6547 return GetShiftedByZeros(Known);
6548 }
6549 case scCouldNotCompute:
6550 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6551 }
6552 llvm_unreachable("Unknown SCEV kind!");
6553}
6554
6556 const Instruction *CtxI) {
6557 // Skip looking up and updating the cache if there is a context instruction,
6558 // as the result will only be valid in the specified context.
6559 if (CtxI)
6560 return getConstantMultipleImpl(S, CtxI);
6561
6562 auto I = ConstantMultipleCache.find(S);
6563 if (I != ConstantMultipleCache.end())
6564 return I->second;
6565
6566 APInt Result = getConstantMultipleImpl(S, CtxI);
6567 auto InsertPair = ConstantMultipleCache.insert({S, Result});
6568 assert(InsertPair.second && "Should insert a new key");
6569 return InsertPair.first->second;
6570}
6571
6573 APInt Multiple = getConstantMultiple(S);
6574 return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
6575}
6576
6578 const Instruction *CtxI) {
6579 return std::min(getConstantMultiple(S, CtxI).countTrailingZeros(),
6580 (unsigned)getTypeSizeInBits(S->getType()));
6581}
6582
6583/// Helper method to assign a range to V from metadata present in the IR.
6584static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
6586 if (MDNode *MD = I->getMetadata(LLVMContext::MD_range))
6587 return getConstantRangeFromMetadata(*MD);
6588 if (const auto *CB = dyn_cast<CallBase>(V))
6589 if (std::optional<ConstantRange> Range = CB->getRange())
6590 return Range;
6591 }
6592 if (auto *A = dyn_cast<Argument>(V))
6593 if (std::optional<ConstantRange> Range = A->getRange())
6594 return Range;
6595
6596 return std::nullopt;
6597}
6598
6600 SCEV::NoWrapFlags Flags) {
6601 if (AddRec->getNoWrapFlags(Flags) != Flags) {
6602 AddRec->setNoWrapFlags(Flags);
6603 UnsignedRanges.erase(AddRec);
6604 SignedRanges.erase(AddRec);
6605 ConstantMultipleCache.erase(AddRec);
6606 }
6607}
6608
/// Compute a range for a SCEVUnknown PHI that forms a simple shift recurrence
/// (ashr/lshr/shl whose step may even be loop varying), using the loop's
/// small constant max trip count to bound the total shift amount. Returns the
/// full set whenever no useful bound can be derived.
ConstantRange ScalarEvolution::
getRangeForUnknownRecurrence(const SCEVUnknown *U) {
  const DataLayout &DL = getDataLayout();

  unsigned BitWidth = getTypeSizeInBits(U->getType());
  const ConstantRange FullSet(BitWidth, /*isFullSet=*/true);

  // Match a simple recurrence of the form: <start, ShiftOp, Step>, and then
  // use information about the trip count to improve our available range.  Note
  // that the trip count independent cases are already handled by known bits.
  // WARNING: The definition of recurrence used here is subtly different than
  // the one used by AddRec (and thus most of this file).  Step is allowed to
  // be arbitrarily loop varying here, where AddRec allows only loop invariant
  // and other addrecs in the same loop (for non-affine addrecs).  The code
  // below intentionally handles the case where step is not loop invariant.
  auto *P = dyn_cast<PHINode>(U->getValue());
  if (!P)
    return FullSet;

  // Make sure that no Phi input comes from an unreachable block. Otherwise,
  // even the values that are not available in these blocks may come from them,
  // and this leads to false-positive recurrence test.
  for (auto *Pred : predecessors(P->getParent()))
    if (!DT.isReachableFromEntry(Pred))
      return FullSet;

  BinaryOperator *BO;
  Value *Start, *Step;
  if (!matchSimpleRecurrence(P, BO, Start, Step))
    return FullSet;

  // If we found a recurrence in reachable code, we must be in a loop. Note
  // that BO might be in some subloop of L, and that's completely okay.
  auto *L = LI.getLoopFor(P->getParent());
  assert(L && L->getHeader() == P->getParent());
  if (!L->contains(BO->getParent()))
    // NOTE: This bailout should be an assert instead.  However, asserting
    // the condition here exposes a case where LoopFusion is querying SCEV
    // with malformed loop information during the midst of the transform.
    // There doesn't appear to be an obvious fix, so for the moment bailout
    // until the caller issue can be fixed.  PR49566 tracks the bug.
    return FullSet;

  // TODO: Extend to other opcodes such as mul, and div
  switch (BO->getOpcode()) {
  default:
    return FullSet;
  case Instruction::AShr:
  case Instruction::LShr:
  case Instruction::Shl:
    break;
  };

  if (BO->getOperand(0) != P)
    // TODO: Handle the power function forms some day.
    return FullSet;

  // Need the trip count to be both known and small enough that the total
  // shift amount fits in the bit width.
  unsigned TC = getSmallConstantMaxTripCount(L);
  if (!TC || TC >= BitWidth)
    return FullSet;

  auto KnownStart = computeKnownBits(Start, DL, &AC, nullptr, &DT);
  auto KnownStep = computeKnownBits(Step, DL, &AC, nullptr, &DT);
  assert(KnownStart.getBitWidth() == BitWidth &&
         KnownStep.getBitWidth() == BitWidth);

  // Compute total shift amount, being careful of overflow and bitwidths.
  auto MaxShiftAmt = KnownStep.getMaxValue();
  // TC counts loop iterations, so at most TC-1 shifts are applied.
  APInt TCAP(BitWidth, TC-1);
  bool Overflow = false;
  auto TotalShift = MaxShiftAmt.umul_ov(TCAP, Overflow);
  if (Overflow)
    return FullSet;

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("filtered out above");
  case Instruction::AShr: {
    // For each ashr, three cases:
    //   shift = 0 => unchanged value
    //   saturation => 0 or -1
    //   other => a value closer to zero (of the same sign)
    // Thus, the end value is closer to zero than the start.
    auto KnownEnd = KnownBits::ashr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    if (KnownStart.isNonNegative())
      // Analogous to lshr (simply not yet canonicalized)
      return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                        KnownStart.getMaxValue() + 1);
    if (KnownStart.isNegative())
      // End >=u Start && End <=s Start
      return ConstantRange::getNonEmpty(KnownStart.getMinValue(),
                                        KnownEnd.getMaxValue() + 1);
    break;
  }
  case Instruction::LShr: {
    // For each lshr, three cases:
    //   shift = 0 => unchanged value
    //   saturation => 0
    //   other => a smaller positive number
    // Thus, the low end of the unsigned range is the last value produced.
    auto KnownEnd = KnownBits::lshr(KnownStart,
                                    KnownBits::makeConstant(TotalShift));
    return ConstantRange::getNonEmpty(KnownEnd.getMinValue(),
                                      KnownStart.getMaxValue() + 1);
  }
  case Instruction::Shl: {
    // Iff no bits are shifted out, value increases on every shift.
    auto KnownEnd = KnownBits::shl(KnownStart,
                                   KnownBits::makeConstant(TotalShift));
    if (TotalShift.ult(KnownStart.countMinLeadingZeros()))
      return ConstantRange(KnownStart.getMinValue(),
                           KnownEnd.getMaxValue() + 1);
    break;
  }
  };
  return FullSet;
}
6727
// The goal of this function is to check if recursively visiting the operands
// of this PHI might lead to an infinite loop. If we do see such a loop,
// there's no good way to break it, so we avoid analyzing such cases.
//
// getRangeRef previously used a visited set to avoid infinite loops, but this
// caused other issues: the result was dependent on the order of getRangeRef
// calls, and the interaction with createSCEVIter could cause a stack overflow
// in some cases (see issue #148253).
//
// FIXME: The way this is implemented is overly conservative; this checks
// for a few obviously safe patterns, but anything that doesn't lead to
// recursion is fine.
// NOTE(review): this listing appears to be missing the function signature
// (callers invoke it as RangeRefPHIAllowedOperands(DT, PHI)) and the first
// condition, which presumably matches a select-like pattern binding
// Cond/LHS/RHS below -- confirm against the upstream source.
  Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
  // First safe pattern (condition line not visible in this listing).
    return true;

  // Safe if every incoming value is defined above the PHI (dominates it):
  // computing operand ranges then cannot recurse back into this PHI.
  if (all_of(PHI->operands(),
             [&](Value *Operand) { return DT.dominates(Operand, PHI); }))
    return true;

  return false;
}
6751
6752const ConstantRange &
6753ScalarEvolution::getRangeRefIter(const SCEV *S,
6754 ScalarEvolution::RangeSignHint SignHint) {
6755 DenseMap<const SCEV *, ConstantRange> &Cache =
6756 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6757 : SignedRanges;
6758 SmallVector<SCEVUse> WorkList;
6759 SmallPtrSet<const SCEV *, 8> Seen;
6760
6761 // Add Expr to the worklist, if Expr is either an N-ary expression or a
6762 // SCEVUnknown PHI node.
6763 auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) {
6764 if (!Seen.insert(Expr).second)
6765 return;
6766 if (Cache.contains(Expr))
6767 return;
6768 switch (Expr->getSCEVType()) {
6769 case scUnknown:
6770 if (!isa<PHINode>(cast<SCEVUnknown>(Expr)->getValue()))
6771 break;
6772 [[fallthrough]];
6773 case scConstant:
6774 case scVScale:
6775 case scTruncate:
6776 case scZeroExtend:
6777 case scSignExtend:
6778 case scPtrToAddr:
6779 case scPtrToInt:
6780 case scAddExpr:
6781 case scMulExpr:
6782 case scUDivExpr:
6783 case scAddRecExpr:
6784 case scUMaxExpr:
6785 case scSMaxExpr:
6786 case scUMinExpr:
6787 case scSMinExpr:
6789 WorkList.push_back(Expr);
6790 break;
6791 case scCouldNotCompute:
6792 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6793 }
6794 };
6795 AddToWorklist(S);
6796
6797 // Build worklist by queuing operands of N-ary expressions and phi nodes.
6798 for (unsigned I = 0; I != WorkList.size(); ++I) {
6799 const SCEV *P = WorkList[I];
6800 auto *UnknownS = dyn_cast<SCEVUnknown>(P);
6801 // If it is not a `SCEVUnknown`, just recurse into operands.
6802 if (!UnknownS) {
6803 for (const SCEV *Op : P->operands())
6804 AddToWorklist(Op);
6805 continue;
6806 }
6807 // `SCEVUnknown`'s require special treatment.
6808 if (PHINode *P = dyn_cast<PHINode>(UnknownS->getValue())) {
6809 if (!RangeRefPHIAllowedOperands(DT, P))
6810 continue;
6811 for (auto &Op : reverse(P->operands()))
6812 AddToWorklist(getSCEV(Op));
6813 }
6814 }
6815
6816 if (!WorkList.empty()) {
6817 // Use getRangeRef to compute ranges for items in the worklist in reverse
6818 // order. This will force ranges for earlier operands to be computed before
6819 // their users in most cases.
6820 for (const SCEV *P : reverse(drop_begin(WorkList))) {
6821 getRangeRef(P, SignHint);
6822 }
6823 }
6824
6825 return getRangeRef(S, SignHint, 0);
6826}
6827
6828/// Determine the range for a particular SCEV. If SignHint is
6829/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
6830/// with a "cleaner" unsigned (resp. signed) representation.
6831const ConstantRange &ScalarEvolution::getRangeRef(
6832 const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) {
6833 DenseMap<const SCEV *, ConstantRange> &Cache =
6834 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges
6835 : SignedRanges;
6837 SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned
6839
6840 // See if we've computed this range already.
6841 auto I = Cache.find(S);
6842 if (I != Cache.end())
6843 return I->second;
6844
6845 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
6846 return setRange(C, SignHint, ConstantRange(C->getAPInt()));
6847
6848 // Switch to iteratively computing the range for S, if it is part of a deeply
6849 // nested expression.
6851 return getRangeRefIter(S, SignHint);
6852
6853 unsigned BitWidth = getTypeSizeInBits(S->getType());
6854 ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
6855 using OBO = OverflowingBinaryOperator;
6856
6857 // If the value has known zeros, the maximum value will have those known zeros
6858 // as well.
6859 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
6860 APInt Multiple = getNonZeroConstantMultiple(S);
6861 APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
6862 if (!Remainder.isZero())
6863 ConservativeResult =
6864 ConstantRange(APInt::getMinValue(BitWidth),
6865 APInt::getMaxValue(BitWidth) - Remainder + 1);
6866 }
6867 else {
6868 uint32_t TZ = getMinTrailingZeros(S);
6869 if (TZ != 0) {
6870 ConservativeResult = ConstantRange(
6872 APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
6873 }
6874 }
6875
6876 switch (S->getSCEVType()) {
6877 case scConstant:
6878 llvm_unreachable("Already handled above.");
6879 case scVScale:
6880 return setRange(S, SignHint, getVScaleRange(&F, BitWidth));
6881 case scTruncate: {
6882 const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(S);
6883 ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1);
6884 return setRange(
6885 Trunc, SignHint,
6886 ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType));
6887 }
6888 case scZeroExtend: {
6889 const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(S);
6890 ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1);
6891 return setRange(
6892 ZExt, SignHint,
6893 ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType));
6894 }
6895 case scSignExtend: {
6896 const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(S);
6897 ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1);
6898 return setRange(
6899 SExt, SignHint,
6900 ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType));
6901 }
6902 case scPtrToAddr:
6903 case scPtrToInt: {
6904 const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6905 ConstantRange X = getRangeRef(Cast->getOperand(), SignHint, Depth + 1);
6906 return setRange(Cast, SignHint, X);
6907 }
6908 case scAddExpr: {
6909 const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
6910 // Check if this is a URem pattern: A - (A / B) * B, which is always < B.
6911 const SCEV *URemLHS = nullptr, *URemRHS = nullptr;
6912 if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED &&
6913 match(S, m_scev_URem(m_SCEV(URemLHS), m_SCEV(URemRHS), *this))) {
6914 ConstantRange LHSRange = getRangeRef(URemLHS, SignHint, Depth + 1);
6915 ConstantRange RHSRange = getRangeRef(URemRHS, SignHint, Depth + 1);
6916 ConservativeResult =
6917 ConservativeResult.intersectWith(LHSRange.urem(RHSRange), RangeType);
6918 }
6919 ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1);
6920 unsigned WrapType = OBO::AnyWrap;
6921 if (Add->hasNoSignedWrap())
6922 WrapType |= OBO::NoSignedWrap;
6923 if (Add->hasNoUnsignedWrap())
6924 WrapType |= OBO::NoUnsignedWrap;
6925 for (const SCEV *Op : drop_begin(Add->operands()))
6926 X = X.addWithNoWrap(getRangeRef(Op, SignHint, Depth + 1), WrapType,
6927 RangeType);
6928 return setRange(Add, SignHint,
6929 ConservativeResult.intersectWith(X, RangeType));
6930 }
6931 case scMulExpr: {
6932 const SCEVMulExpr *Mul = cast<SCEVMulExpr>(S);
6933 ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1);
6934 for (const SCEV *Op : drop_begin(Mul->operands()))
6935 X = X.multiply(getRangeRef(Op, SignHint, Depth + 1));
6936 return setRange(Mul, SignHint,
6937 ConservativeResult.intersectWith(X, RangeType));
6938 }
6939 case scUDivExpr: {
6940 const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6941 ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1);
6942 ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1);
6943 return setRange(UDiv, SignHint,
6944 ConservativeResult.intersectWith(X.udiv(Y), RangeType));
6945 }
6946 case scAddRecExpr: {
6947 const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(S);
6948 // If there's no unsigned wrap, the value will never be less than its
6949 // initial value.
6950 if (AddRec->hasNoUnsignedWrap()) {
6951 APInt UnsignedMinValue = getUnsignedRangeMin(AddRec->getStart());
6952 if (!UnsignedMinValue.isZero())
6953 ConservativeResult = ConservativeResult.intersectWith(
6954 ConstantRange(UnsignedMinValue, APInt(BitWidth, 0)), RangeType);
6955 }
6956
6957 // If there's no signed wrap, and all the operands except initial value have
6958 // the same sign or zero, the value won't ever be:
6959 // 1: smaller than initial value if operands are non negative,
6960 // 2: bigger than initial value if operands are non positive.
6961 // For both cases, value can not cross signed min/max boundary.
6962 if (AddRec->hasNoSignedWrap()) {
6963 bool AllNonNeg = true;
6964 bool AllNonPos = true;
6965 for (unsigned i = 1, e = AddRec->getNumOperands(); i != e; ++i) {
6966 if (!isKnownNonNegative(AddRec->getOperand(i)))
6967 AllNonNeg = false;
6968 if (!isKnownNonPositive(AddRec->getOperand(i)))
6969 AllNonPos = false;
6970 }
6971 if (AllNonNeg)
6972 ConservativeResult = ConservativeResult.intersectWith(
6975 RangeType);
6976 else if (AllNonPos)
6977 ConservativeResult = ConservativeResult.intersectWith(
6979 getSignedRangeMax(AddRec->getStart()) +
6980 1),
6981 RangeType);
6982 }
6983
6984 // TODO: non-affine addrec
6985 if (AddRec->isAffine()) {
6986 const SCEV *MaxBEScev =
6988 if (!isa<SCEVCouldNotCompute>(MaxBEScev)) {
6989 APInt MaxBECount = cast<SCEVConstant>(MaxBEScev)->getAPInt();
6990
6991 // Adjust MaxBECount to the same bitwidth as AddRec. We can truncate if
6992 // MaxBECount's active bits are all <= AddRec's bit width.
6993 if (MaxBECount.getBitWidth() > BitWidth &&
6994 MaxBECount.getActiveBits() <= BitWidth)
6995 MaxBECount = MaxBECount.trunc(BitWidth);
6996 else if (MaxBECount.getBitWidth() < BitWidth)
6997 MaxBECount = MaxBECount.zext(BitWidth);
6998
6999 if (MaxBECount.getBitWidth() == BitWidth) {
7000 auto RangeFromAffine = getRangeForAffineAR(
7001 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7002 ConservativeResult =
7003 ConservativeResult.intersectWith(RangeFromAffine, RangeType);
7004
7005 auto RangeFromFactoring = getRangeViaFactoring(
7006 AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount);
7007 ConservativeResult =
7008 ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
7009 }
7010 }
7011
7012 // Now try symbolic BE count and more powerful methods.
7014 const SCEV *SymbolicMaxBECount =
7016 if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
7017 getTypeSizeInBits(MaxBEScev->getType()) <= BitWidth &&
7018 AddRec->hasNoSelfWrap()) {
7019 auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
7020 AddRec, SymbolicMaxBECount, BitWidth, SignHint);
7021 ConservativeResult =
7022 ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
7023 }
7024 }
7025 }
7026
7027 return setRange(AddRec, SignHint, std::move(ConservativeResult));
7028 }
7029 case scUMaxExpr:
7030 case scSMaxExpr:
7031 case scUMinExpr:
7032 case scSMinExpr:
7033 case scSequentialUMinExpr: {
7035 switch (S->getSCEVType()) {
7036 case scUMaxExpr:
7037 ID = Intrinsic::umax;
7038 break;
7039 case scSMaxExpr:
7040 ID = Intrinsic::smax;
7041 break;
7042 case scUMinExpr:
7044 ID = Intrinsic::umin;
7045 break;
7046 case scSMinExpr:
7047 ID = Intrinsic::smin;
7048 break;
7049 default:
7050 llvm_unreachable("Unknown SCEVMinMaxExpr/SCEVSequentialMinMaxExpr.");
7051 }
7052
7053 const auto *NAry = cast<SCEVNAryExpr>(S);
7054 ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1);
7055 for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i)
7056 X = X.intrinsic(
7057 ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)});
7058 return setRange(S, SignHint,
7059 ConservativeResult.intersectWith(X, RangeType));
7060 }
7061 case scUnknown: {
7062 const SCEVUnknown *U = cast<SCEVUnknown>(S);
7063 Value *V = U->getValue();
7064
7065 // Check if the IR explicitly contains !range metadata.
7066 std::optional<ConstantRange> MDRange = GetRangeFromMetadata(V);
7067 if (MDRange)
7068 ConservativeResult =
7069 ConservativeResult.intersectWith(*MDRange, RangeType);
7070
7071 // Use facts about recurrences in the underlying IR. Note that add
7072 // recurrences are AddRecExprs and thus don't hit this path. This
7073 // primarily handles shift recurrences.
7074 auto CR = getRangeForUnknownRecurrence(U);
7075 ConservativeResult = ConservativeResult.intersectWith(CR);
7076
7077 // See if ValueTracking can give us a useful range.
7078 const DataLayout &DL = getDataLayout();
7079 KnownBits Known = computeKnownBits(V, DL, &AC, nullptr, &DT);
7080 if (Known.getBitWidth() != BitWidth)
7081 Known = Known.zextOrTrunc(BitWidth);
7082
7083 // ValueTracking may be able to compute a tighter result for the number of
7084 // sign bits than for the value of those sign bits.
7085 unsigned NS = ComputeNumSignBits(V, DL, &AC, nullptr, &DT);
7086 if (U->getType()->isPointerTy()) {
7087 // If the pointer size is larger than the index size type, this can cause
7088 // NS to be larger than BitWidth. So compensate for this.
7089 unsigned ptrSize = DL.getPointerTypeSizeInBits(U->getType());
7090 int ptrIdxDiff = ptrSize - BitWidth;
7091 if (ptrIdxDiff > 0 && ptrSize > BitWidth && NS > (unsigned)ptrIdxDiff)
7092 NS -= ptrIdxDiff;
7093 }
7094
7095 if (NS > 1) {
7096 // If we know any of the sign bits, we know all of the sign bits.
7097 if (!Known.Zero.getHiBits(NS).isZero())
7098 Known.Zero.setHighBits(NS);
7099 if (!Known.One.getHiBits(NS).isZero())
7100 Known.One.setHighBits(NS);
7101 }
7102
7103 if (Known.getMinValue() != Known.getMaxValue() + 1)
7104 ConservativeResult = ConservativeResult.intersectWith(
7105 ConstantRange(Known.getMinValue(), Known.getMaxValue() + 1),
7106 RangeType);
7107 if (NS > 1)
7108 ConservativeResult = ConservativeResult.intersectWith(
7109 ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
7110 APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1),
7111 RangeType);
7112
7113 if (U->getType()->isPointerTy() && SignHint == HINT_RANGE_UNSIGNED) {
7114 // Strengthen the range if the underlying IR value is a
7115 // global/alloca/heap allocation using the size of the object.
7116 bool CanBeNull, CanBeFreed;
7117 uint64_t DerefBytes =
7118 V->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed);
7119 if (DerefBytes > 1 && isUIntN(BitWidth, DerefBytes)) {
7120 // The highest address the object can start is DerefBytes bytes before
7121 // the end (unsigned max value). If this value is not a multiple of the
7122 // alignment, the last possible start value is the next lowest multiple
7123 // of the alignment. Note: The computations below cannot overflow,
7124 // because if they would there's no possible start address for the
7125 // object.
7126 APInt MaxVal =
7127 APInt::getMaxValue(BitWidth) - APInt(BitWidth, DerefBytes);
7128 uint64_t Align = U->getValue()->getPointerAlignment(DL).value();
7129 uint64_t Rem = MaxVal.urem(Align);
7130 MaxVal -= APInt(BitWidth, Rem);
7131 APInt MinVal = APInt::getZero(BitWidth);
7132 if (llvm::isKnownNonZero(V, DL))
7133 MinVal = Align;
7134 ConservativeResult = ConservativeResult.intersectWith(
7135 ConstantRange::getNonEmpty(MinVal, MaxVal + 1), RangeType);
7136 }
7137 }
7138
7139 // A range of Phi is a subset of union of all ranges of its input.
7140 if (PHINode *Phi = dyn_cast<PHINode>(V)) {
7141 // SCEVExpander sometimes creates SCEVUnknowns that are secretly
7142 // AddRecs; return the range for the corresponding AddRec.
7143 if (auto *AR = dyn_cast<SCEVAddRecExpr>(getSCEV(V)))
7144 return getRangeRef(AR, SignHint, Depth + 1);
7145
7146 // Make sure that we do not run over cycled Phis.
7147 if (RangeRefPHIAllowedOperands(DT, Phi)) {
7148 ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
7149
7150 for (const auto &Op : Phi->operands()) {
7151 auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1);
7152 RangeFromOps = RangeFromOps.unionWith(OpRange);
7153 // No point to continue if we already have a full set.
7154 if (RangeFromOps.isFullSet())
7155 break;
7156 }
7157 ConservativeResult =
7158 ConservativeResult.intersectWith(RangeFromOps, RangeType);
7159 }
7160 }
7161
7162 // vscale can't be equal to zero
7163 if (const auto *II = dyn_cast<IntrinsicInst>(V))
7164 if (II->getIntrinsicID() == Intrinsic::vscale) {
7165 ConstantRange Disallowed = APInt::getZero(BitWidth);
7166 ConservativeResult = ConservativeResult.difference(Disallowed);
7167 }
7168
7169 return setRange(U, SignHint, std::move(ConservativeResult));
7170 }
7171 case scCouldNotCompute:
7172 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
7173 }
7174
7175 return setRange(S, SignHint, std::move(ConservativeResult));
7176}
7177
7178// Given a StartRange, Step and MaxBECount for an expression compute a range of
7179// values that the expression can take. Initially, the expression has a value
7180// from StartRange and then is changed by Step up to MaxBECount times. Signed
7181// argument defines if we treat Step as signed or unsigned.
7183 const ConstantRange &StartRange,
7184 const APInt &MaxBECount,
7185 bool Signed) {
7186 unsigned BitWidth = Step.getBitWidth();
7187 assert(BitWidth == StartRange.getBitWidth() &&
7188 BitWidth == MaxBECount.getBitWidth() && "mismatched bit widths");
7189 // If either Step or MaxBECount is 0, then the expression won't change, and we
7190 // just need to return the initial range.
7191 if (Step == 0 || MaxBECount == 0)
7192 return StartRange;
7193
7194 // If we don't know anything about the initial value (i.e. StartRange is
7195 // FullRange), then we don't know anything about the final range either.
7196 // Return FullRange.
7197 if (StartRange.isFullSet())
7198 return ConstantRange::getFull(BitWidth);
7199
7200 // If Step is signed and negative, then we use its absolute value, but we also
7201 // note that we're moving in the opposite direction.
7202 bool Descending = Signed && Step.isNegative();
7203
7204 if (Signed)
7205 // This is correct even for INT_SMIN. Let's look at i8 to illustrate this:
7206 // abs(INT_SMIN) = abs(-128) = abs(0x80) = -0x80 = 0x80 = 128.
7207 // This equations hold true due to the well-defined wrap-around behavior of
7208 // APInt.
7209 Step = Step.abs();
7210
7211 // Check if Offset is more than full span of BitWidth. If it is, the
7212 // expression is guaranteed to overflow.
7213 if (APInt::getMaxValue(StartRange.getBitWidth()).udiv(Step).ult(MaxBECount))
7214 return ConstantRange::getFull(BitWidth);
7215
7216 // Offset is by how much the expression can change. Checks above guarantee no
7217 // overflow here.
7218 APInt Offset = Step * MaxBECount;
7219
7220 // Minimum value of the final range will match the minimal value of StartRange
7221 // if the expression is increasing and will be decreased by Offset otherwise.
7222 // Maximum value of the final range will match the maximal value of StartRange
7223 // if the expression is decreasing and will be increased by Offset otherwise.
7224 APInt StartLower = StartRange.getLower();
7225 APInt StartUpper = StartRange.getUpper() - 1;
7226 APInt MovedBoundary = Descending ? (StartLower - std::move(Offset))
7227 : (StartUpper + std::move(Offset));
7228
7229 // It's possible that the new minimum/maximum value will fall into the initial
7230 // range (due to wrap around). This means that the expression can take any
7231 // value in this bitwidth, and we have to return full range.
7232 if (StartRange.contains(MovedBoundary))
7233 return ConstantRange::getFull(BitWidth);
7234
7235 APInt NewLower =
7236 Descending ? std::move(MovedBoundary) : std::move(StartLower);
7237 APInt NewUpper =
7238 Descending ? std::move(StartUpper) : std::move(MovedBoundary);
7239 NewUpper += 1;
7240
7241 // No overflow detected, return [StartLower, StartUpper + Offset + 1) range.
7242 return ConstantRange::getNonEmpty(std::move(NewLower), std::move(NewUpper));
7243}
7244
7245ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
7246 const SCEV *Step,
7247 const APInt &MaxBECount) {
7248 assert(getTypeSizeInBits(Start->getType()) ==
7249 getTypeSizeInBits(Step->getType()) &&
7250 getTypeSizeInBits(Start->getType()) == MaxBECount.getBitWidth() &&
7251 "mismatched bit widths");
7252
7253 // First, consider step signed.
7254 ConstantRange StartSRange = getSignedRange(Start);
7255 ConstantRange StepSRange = getSignedRange(Step);
7256
7257 // If Step can be both positive and negative, we need to find ranges for the
7258 // maximum absolute step values in both directions and union them.
7259 ConstantRange SR = getRangeForAffineARHelper(
7260 StepSRange.getSignedMin(), StartSRange, MaxBECount, /* Signed = */ true);
7262 StartSRange, MaxBECount,
7263 /* Signed = */ true));
7264
7265 // Next, consider step unsigned.
7266 ConstantRange UR = getRangeForAffineARHelper(
7267 getUnsignedRangeMax(Step), getUnsignedRange(Start), MaxBECount,
7268 /* Signed = */ false);
7269
7270 // Finally, intersect signed and unsigned ranges.
7272}
7273
7274ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
7275 const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
7276 ScalarEvolution::RangeSignHint SignHint) {
7277 assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
7278 assert(AddRec->hasNoSelfWrap() &&
7279 "This only works for non-self-wrapping AddRecs!");
7280 const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
7281 const SCEV *Step = AddRec->getStepRecurrence(*this);
7282 // Only deal with constant step to save compile time.
7283 if (!isa<SCEVConstant>(Step))
7284 return ConstantRange::getFull(BitWidth);
7285 // Let's make sure that we can prove that we do not self-wrap during
7286 // MaxBECount iterations. We need this because MaxBECount is a maximum
7287 // iteration count estimate, and we might infer nw from some exit for which we
7288 // do not know max exit count (or any other side reasoning).
7289 // TODO: Turn into assert at some point.
7290 if (getTypeSizeInBits(MaxBECount->getType()) >
7291 getTypeSizeInBits(AddRec->getType()))
7292 return ConstantRange::getFull(BitWidth);
7293 MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
7294 const SCEV *RangeWidth = getMinusOne(AddRec->getType());
7295 const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
7296 const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
7297 if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
7298 MaxItersWithoutWrap))
7299 return ConstantRange::getFull(BitWidth);
7300
7301 ICmpInst::Predicate LEPred =
7303 ICmpInst::Predicate GEPred =
7305 const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
7306
7307 // We know that there is no self-wrap. Let's take Start and End values and
7308 // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
7309 // the iteration. They either lie inside the range [Min(Start, End),
7310 // Max(Start, End)] or outside it:
7311 //
7312 // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
7313 // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
7314 //
7315 // No self wrap flag guarantees that the intermediate values cannot be BOTH
7316 // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
7317 // knowledge, let's try to prove that we are dealing with Case 1. It is so if
7318 // Start <= End and step is positive, or Start >= End and step is negative.
7319 const SCEV *Start = applyLoopGuards(AddRec->getStart(), AddRec->getLoop());
7320 ConstantRange StartRange = getRangeRef(Start, SignHint);
7321 ConstantRange EndRange = getRangeRef(End, SignHint);
7322 ConstantRange RangeBetween = StartRange.unionWith(EndRange);
7323 // If they already cover full iteration space, we will know nothing useful
7324 // even if we prove what we want to prove.
7325 if (RangeBetween.isFullSet())
7326 return RangeBetween;
7327 // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
7328 bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
7329 : RangeBetween.isWrappedSet();
7330 if (IsWrappedSet)
7331 return ConstantRange::getFull(BitWidth);
7332
7333 if (isKnownPositive(Step) &&
7334 isKnownPredicateViaConstantRanges(LEPred, Start, End))
7335 return RangeBetween;
7336 if (isKnownNegative(Step) &&
7337 isKnownPredicateViaConstantRanges(GEPred, Start, End))
7338 return RangeBetween;
7339 return ConstantRange::getFull(BitWidth);
7340}
7341
7342ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
7343 const SCEV *Step,
7344 const APInt &MaxBECount) {
7345 // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q})
7346 // == RangeOf({A,+,P}) union RangeOf({B,+,Q})
7347
7348 unsigned BitWidth = MaxBECount.getBitWidth();
7349 assert(getTypeSizeInBits(Start->getType()) == BitWidth &&
7350 getTypeSizeInBits(Step->getType()) == BitWidth &&
7351 "mismatched bit widths");
7352
7353 struct SelectPattern {
7354 Value *Condition = nullptr;
7355 APInt TrueValue;
7356 APInt FalseValue;
7357
7358 explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth,
7359 const SCEV *S) {
7360 std::optional<unsigned> CastOp;
7361 APInt Offset(BitWidth, 0);
7362
7364 "Should be!");
7365
7366 // Peel off a constant offset. In the future we could consider being
7367 // smarter here and handle {Start+Step,+,Step} too.
7368 const APInt *Off;
7369 if (match(S, m_scev_Add(m_scev_APInt(Off), m_SCEV(S))))
7370 Offset = *Off;
7371
7372 // Peel off a cast operation
7373 if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
7374 CastOp = SCast->getSCEVType();
7375 S = SCast->getOperand();
7376 }
7377
7378 using namespace llvm::PatternMatch;
7379
7380 auto *SU = dyn_cast<SCEVUnknown>(S);
7381 const APInt *TrueVal, *FalseVal;
7382 if (!SU ||
7383 !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal),
7384 m_APInt(FalseVal)))) {
7385 Condition = nullptr;
7386 return;
7387 }
7388
7389 TrueValue = *TrueVal;
7390 FalseValue = *FalseVal;
7391
7392 // Re-apply the cast we peeled off earlier
7393 if (CastOp)
7394 switch (*CastOp) {
7395 default:
7396 llvm_unreachable("Unknown SCEV cast type!");
7397
7398 case scTruncate:
7399 TrueValue = TrueValue.trunc(BitWidth);
7400 FalseValue = FalseValue.trunc(BitWidth);
7401 break;
7402 case scZeroExtend:
7403 TrueValue = TrueValue.zext(BitWidth);
7404 FalseValue = FalseValue.zext(BitWidth);
7405 break;
7406 case scSignExtend:
7407 TrueValue = TrueValue.sext(BitWidth);
7408 FalseValue = FalseValue.sext(BitWidth);
7409 break;
7410 }
7411
7412 // Re-apply the constant offset we peeled off earlier
7413 TrueValue += Offset;
7414 FalseValue += Offset;
7415 }
7416
7417 bool isRecognized() { return Condition != nullptr; }
7418 };
7419
7420 SelectPattern StartPattern(*this, BitWidth, Start);
7421 if (!StartPattern.isRecognized())
7422 return ConstantRange::getFull(BitWidth);
7423
7424 SelectPattern StepPattern(*this, BitWidth, Step);
7425 if (!StepPattern.isRecognized())
7426 return ConstantRange::getFull(BitWidth);
7427
7428 if (StartPattern.Condition != StepPattern.Condition) {
7429 // We don't handle this case today; but we could, by considering four
7430 // possibilities below instead of two. I'm not sure if there are cases where
7431 // that will help over what getRange already does, though.
7432 return ConstantRange::getFull(BitWidth);
7433 }
7434
7435 // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to
7436 // construct arbitrary general SCEV expressions here. This function is called
7437 // from deep in the call stack, and calling getSCEV (on a sext instruction,
7438 // say) can end up caching a suboptimal value.
7439
7440 // FIXME: without the explicit `this` receiver below, MSVC errors out with
7441 // C2352 and C2512 (otherwise it isn't needed).
7442
7443 const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue);
7444 const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue);
7445 const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue);
7446 const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue);
7447
7448 ConstantRange TrueRange =
7449 this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount);
7450 ConstantRange FalseRange =
7451 this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount);
7452
7453 return TrueRange.unionWith(FalseRange);
7454}
7455
7456SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) {
7457 if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap;
7458 const BinaryOperator *BinOp = cast<BinaryOperator>(V);
7459
7460 // Return early if there are no flags to propagate to the SCEV.
7462 if (BinOp->hasNoUnsignedWrap())
7464 if (BinOp->hasNoSignedWrap())
7466 if (Flags == SCEV::FlagAnyWrap)
7467 return SCEV::FlagAnyWrap;
7468
7469 return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap;
7470}
7471
7472const Instruction *
7473ScalarEvolution::getNonTrivialDefiningScopeBound(const SCEV *S) {
7474 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
7475 return &*AddRec->getLoop()->getHeader()->begin();
7476 if (auto *U = dyn_cast<SCEVUnknown>(S))
7477 if (auto *I = dyn_cast<Instruction>(U->getValue()))
7478 return I;
7479 return nullptr;
7480}
7481
7482const Instruction *ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops,
7483 bool &Precise) {
7484 Precise = true;
7485 // Do a bounded search of the def relation of the requested SCEVs.
7486 SmallPtrSet<const SCEV *, 16> Visited;
7487 SmallVector<SCEVUse> Worklist;
7488 auto pushOp = [&](const SCEV *S) {
7489 if (!Visited.insert(S).second)
7490 return;
7491 // Threshold of 30 here is arbitrary.
7492 if (Visited.size() > 30) {
7493 Precise = false;
7494 return;
7495 }
7496 Worklist.push_back(S);
7497 };
7498
7499 for (SCEVUse S : Ops)
7500 pushOp(S);
7501
7502 const Instruction *Bound = nullptr;
7503 while (!Worklist.empty()) {
7504 SCEVUse S = Worklist.pop_back_val();
7505 if (auto *DefI = getNonTrivialDefiningScopeBound(S)) {
7506 if (!Bound || DT.dominates(Bound, DefI))
7507 Bound = DefI;
7508 } else {
7509 for (SCEVUse Op : S->operands())
7510 pushOp(Op);
7511 }
7512 }
7513 return Bound ? Bound : &*F.getEntryBlock().begin();
7514}
7515
7516const Instruction *
7517ScalarEvolution::getDefiningScopeBound(ArrayRef<SCEVUse> Ops) {
7518 bool Discard;
7519 return getDefiningScopeBound(Ops, Discard);
7520}
7521
7522bool ScalarEvolution::isGuaranteedToTransferExecutionTo(const Instruction *A,
7523 const Instruction *B) {
7524 if (A->getParent() == B->getParent() &&
7526 B->getIterator()))
7527 return true;
7528
7529 auto *BLoop = LI.getLoopFor(B->getParent());
7530 if (BLoop && BLoop->getHeader() == B->getParent() &&
7531 BLoop->getLoopPreheader() == A->getParent() &&
7533 A->getParent()->end()) &&
7534 isGuaranteedToTransferExecutionToSuccessor(B->getParent()->begin(),
7535 B->getIterator()))
7536 return true;
7537 return false;
7538}
7539
7540bool ScalarEvolution::isGuaranteedNotToBePoison(const SCEV *Op) {
7541 SCEVPoisonCollector PC(/* LookThroughMaybePoisonBlocking */ true);
7542 visitAll(Op, PC);
7543 return PC.MaybePoison.empty();
7544}
7545
7546bool ScalarEvolution::isGuaranteedNotToCauseUB(const SCEV *Op) {
7547 return !SCEVExprContains(Op, [this](const SCEV *S) {
7548 const SCEV *Op1;
7549 bool M = match(S, m_scev_UDiv(m_SCEV(), m_SCEV(Op1)));
7550 // The UDiv may be UB if the divisor is poison or zero. Unless the divisor
7551 // is a non-zero constant, we have to assume the UDiv may be UB.
7552 return M && (!isKnownNonZero(Op1) || !isGuaranteedNotToBePoison(Op1));
7553 });
7554}
7555
7556bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) {
7557 // Only proceed if we can prove that I does not yield poison.
7559 return false;
7560
7561 // At this point we know that if I is executed, then it does not wrap
7562 // according to at least one of NSW or NUW. If I is not executed, then we do
7563 // not know if the calculation that I represents would wrap. Multiple
7564 // instructions can map to the same SCEV. If we apply NSW or NUW from I to
7565 // the SCEV, we must guarantee no wrapping for that SCEV also when it is
7566 // derived from other instructions that map to the same SCEV. We cannot make
7567 // that guarantee for cases where I is not executed. So we need to find a
7568 // upper bound on the defining scope for the SCEV, and prove that I is
7569 // executed every time we enter that scope. When the bounding scope is a
7570 // loop (the common case), this is equivalent to proving I executes on every
7571 // iteration of that loop.
7572 SmallVector<SCEVUse> SCEVOps;
7573 for (const Use &Op : I->operands()) {
7574 // I could be an extractvalue from a call to an overflow intrinsic.
7575 // TODO: We can do better here in some cases.
7576 if (isSCEVable(Op->getType()))
7577 SCEVOps.push_back(getSCEV(Op));
7578 }
7579 auto *DefI = getDefiningScopeBound(SCEVOps);
7580 return isGuaranteedToTransferExecutionTo(DefI, I);
7581}
7582
7583bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
7584 // If we know that \c I can never be poison period, then that's enough.
7585 if (isSCEVExprNeverPoison(I))
7586 return true;
7587
7588 // If the loop only has one exit, then we know that, if the loop is entered,
7589 // any instruction dominating that exit will be executed. If any such
7590 // instruction would result in UB, the addrec cannot be poison.
7591 //
7592 // This is basically the same reasoning as in isSCEVExprNeverPoison(), but
7593 // also handles uses outside the loop header (they just need to dominate the
7594 // single exit).
7595
7596 auto *ExitingBB = L->getExitingBlock();
7597 if (!ExitingBB || !loopHasNoAbnormalExits(L))
7598 return false;
7599
7600 SmallPtrSet<const Value *, 16> KnownPoison;
7602
7603 // We start by assuming \c I, the post-inc add recurrence, is poison. Only
7604 // things that are known to be poison under that assumption go on the
7605 // Worklist.
7606 KnownPoison.insert(I);
7607 Worklist.push_back(I);
7608
7609 while (!Worklist.empty()) {
7610 const Instruction *Poison = Worklist.pop_back_val();
7611
7612 for (const Use &U : Poison->uses()) {
7613 const Instruction *PoisonUser = cast<Instruction>(U.getUser());
7614 if (mustTriggerUB(PoisonUser, KnownPoison) &&
7615 DT.dominates(PoisonUser->getParent(), ExitingBB))
7616 return true;
7617
7618 if (propagatesPoison(U) && L->contains(PoisonUser))
7619 if (KnownPoison.insert(PoisonUser).second)
7620 Worklist.push_back(PoisonUser);
7621 }
7622 }
7623
7624 return false;
7625}
7626
7627ScalarEvolution::LoopProperties
7628ScalarEvolution::getLoopProperties(const Loop *L) {
7629 using LoopProperties = ScalarEvolution::LoopProperties;
7630
7631 auto Itr = LoopPropertiesCache.find(L);
7632 if (Itr == LoopPropertiesCache.end()) {
7633 auto HasSideEffects = [](Instruction *I) {
7634 if (auto *SI = dyn_cast<StoreInst>(I))
7635 return !SI->isSimple();
7636
7637 if (I->mayThrow())
7638 return true;
7639
7640 // Non-volatile memset / memcpy do not count as side-effect for forward
7641 // progress.
7642 if (isa<MemIntrinsic>(I) && !I->isVolatile())
7643 return false;
7644
7645 return I->mayWriteToMemory();
7646 };
7647
7648 LoopProperties LP = {/* HasNoAbnormalExits */ true,
7649 /*HasNoSideEffects*/ true};
7650
7651 for (auto *BB : L->getBlocks())
7652 for (auto &I : *BB) {
7654 LP.HasNoAbnormalExits = false;
7655 if (HasSideEffects(&I))
7656 LP.HasNoSideEffects = false;
7657 if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects)
7658 break; // We're already as pessimistic as we can get.
7659 }
7660
7661 auto InsertPair = LoopPropertiesCache.insert({L, LP});
7662 assert(InsertPair.second && "We just checked!");
7663 Itr = InsertPair.first;
7664 }
7665
7666 return Itr->second;
7667}
7668
7670 // A mustprogress loop without side effects must be finite.
7671 // TODO: The check used here is very conservative. It's only *specific*
7672 // side effects which are well defined in infinite loops.
7673 return isFinite(L) || (isMustProgress(L) && loopHasNoSideEffects(L));
7674}
7675
7676const SCEV *ScalarEvolution::createSCEVIter(Value *V) {
7677 // Worklist item with a Value and a bool indicating whether all operands have
7678 // been visited already.
7681
7682 Stack.emplace_back(V, true);
7683 Stack.emplace_back(V, false);
7684 while (!Stack.empty()) {
7685 auto E = Stack.pop_back_val();
7686 Value *CurV = E.getPointer();
7687
7688 if (getExistingSCEV(CurV))
7689 continue;
7690
7692 const SCEV *CreatedSCEV = nullptr;
7693 // If all operands have been visited already, create the SCEV.
7694 if (E.getInt()) {
7695 CreatedSCEV = createSCEV(CurV);
7696 } else {
7697 // Otherwise get the operands we need to create SCEV's for before creating
7698 // the SCEV for CurV. If the SCEV for CurV can be constructed trivially,
7699 // just use it.
7700 CreatedSCEV = getOperandsToCreate(CurV, Ops);
7701 }
7702
7703 if (CreatedSCEV) {
7704 insertValueToMap(CurV, CreatedSCEV);
7705 } else {
7706 // Queue CurV for SCEV creation, followed by its's operands which need to
7707 // be constructed first.
7708 Stack.emplace_back(CurV, true);
7709 for (Value *Op : Ops)
7710 Stack.emplace_back(Op, false);
7711 }
7712 }
7713
7714 return getExistingSCEV(V);
7715}
7716
7717const SCEV *
7718ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
7719 if (!isSCEVable(V->getType()))
7720 return getUnknown(V);
7721
7722 if (Instruction *I = dyn_cast<Instruction>(V)) {
7723 // Don't attempt to analyze instructions in blocks that aren't
7724 // reachable. Such instructions don't matter, and they aren't required
7725 // to obey basic rules for definitions dominating uses which this
7726 // analysis depends on.
7727 if (!DT.isReachableFromEntry(I->getParent()))
7728 return getUnknown(PoisonValue::get(V->getType()));
7729 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7730 return getConstant(CI);
7731 else if (isa<GlobalAlias>(V))
7732 return getUnknown(V);
7733 else if (!isa<ConstantExpr>(V))
7734 return getUnknown(V);
7735
7737 if (auto BO =
7739 bool IsConstArg = isa<ConstantInt>(BO->RHS);
7740 switch (BO->Opcode) {
7741 case Instruction::Add:
7742 case Instruction::Mul: {
7743 // For additions and multiplications, traverse add/mul chains for which we
7744 // can potentially create a single SCEV, to reduce the number of
7745 // get{Add,Mul}Expr calls.
7746 do {
7747 if (BO->Op) {
7748 if (BO->Op != V && getExistingSCEV(BO->Op)) {
7749 Ops.push_back(BO->Op);
7750 break;
7751 }
7752 }
7753 Ops.push_back(BO->RHS);
7754 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
7756 if (!NewBO ||
7757 (BO->Opcode == Instruction::Add &&
7758 (NewBO->Opcode != Instruction::Add &&
7759 NewBO->Opcode != Instruction::Sub)) ||
7760 (BO->Opcode == Instruction::Mul &&
7761 NewBO->Opcode != Instruction::Mul)) {
7762 Ops.push_back(BO->LHS);
7763 break;
7764 }
7765 // CreateSCEV calls getNoWrapFlagsFromUB, which under certain conditions
7766 // requires a SCEV for the LHS.
7767 if (BO->Op && (BO->IsNSW || BO->IsNUW)) {
7768 auto *I = dyn_cast<Instruction>(BO->Op);
7769 if (I && programUndefinedIfPoison(I)) {
7770 Ops.push_back(BO->LHS);
7771 break;
7772 }
7773 }
7774 BO = NewBO;
7775 } while (true);
7776 return nullptr;
7777 }
7778 case Instruction::Sub:
7779 case Instruction::UDiv:
7780 case Instruction::URem:
7781 break;
7782 case Instruction::AShr:
7783 case Instruction::Shl:
7784 case Instruction::Xor:
7785 if (!IsConstArg)
7786 return nullptr;
7787 break;
7788 case Instruction::And:
7789 case Instruction::Or:
7790 if (!IsConstArg && !BO->LHS->getType()->isIntegerTy(1))
7791 return nullptr;
7792 break;
7793 case Instruction::LShr:
7794 return getUnknown(V);
7795 default:
7796 llvm_unreachable("Unhandled binop");
7797 break;
7798 }
7799
7800 Ops.push_back(BO->LHS);
7801 Ops.push_back(BO->RHS);
7802 return nullptr;
7803 }
7804
7805 switch (U->getOpcode()) {
7806 case Instruction::Trunc:
7807 case Instruction::ZExt:
7808 case Instruction::SExt:
7809 case Instruction::PtrToAddr:
7810 case Instruction::PtrToInt:
7811 Ops.push_back(U->getOperand(0));
7812 return nullptr;
7813
7814 case Instruction::BitCast:
7815 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType())) {
7816 Ops.push_back(U->getOperand(0));
7817 return nullptr;
7818 }
7819 return getUnknown(V);
7820
7821 case Instruction::SDiv:
7822 case Instruction::SRem:
7823 Ops.push_back(U->getOperand(0));
7824 Ops.push_back(U->getOperand(1));
7825 return nullptr;
7826
7827 case Instruction::GetElementPtr:
7828 assert(cast<GEPOperator>(U)->getSourceElementType()->isSized() &&
7829 "GEP source element type must be sized");
7830 llvm::append_range(Ops, U->operands());
7831 return nullptr;
7832
7833 case Instruction::IntToPtr:
7834 return getUnknown(V);
7835
7836 case Instruction::PHI:
7837 // getNodeForPHI has four ways to turn a PHI into a SCEV; retrieve the
7838 // relevant nodes for each of them.
7839 //
7840 // The first is just to call simplifyInstruction, and get something back
7841 // that isn't a PHI.
7842 if (Value *V = simplifyInstruction(
7843 cast<PHINode>(U),
7844 {getDataLayout(), &TLI, &DT, &AC, /*CtxI=*/nullptr,
7845 /*UseInstrInfo=*/true, /*CanUseUndef=*/false})) {
7846 assert(V);
7847 Ops.push_back(V);
7848 return nullptr;
7849 }
7850 // The second is createNodeForPHIWithIdenticalOperands: this looks for
7851 // operands which all perform the same operation, but haven't been
7852 // CSE'ed for whatever reason.
7853 if (BinaryOperator *BO = getCommonInstForPHI(cast<PHINode>(U))) {
7854 assert(BO);
7855 Ops.push_back(BO);
7856 return nullptr;
7857 }
7858 // The third is createNodeFromSelectLikePHI; this takes a PHI which
7859 // is equivalent to a select, and analyzes it like a select.
7860 {
7861 Value *Cond = nullptr, *LHS = nullptr, *RHS = nullptr;
7863 assert(Cond);
7864 assert(LHS);
7865 assert(RHS);
7866 if (auto *CondICmp = dyn_cast<ICmpInst>(Cond)) {
7867 Ops.push_back(CondICmp->getOperand(0));
7868 Ops.push_back(CondICmp->getOperand(1));
7869 }
7870 Ops.push_back(Cond);
7871 Ops.push_back(LHS);
7872 Ops.push_back(RHS);
7873 return nullptr;
7874 }
7875 }
7876 // The fourth way is createAddRecFromPHI. It's complicated to handle here,
7877 // so just construct it recursively.
7878 //
7879 // In addition to getNodeForPHI, also construct nodes which might be needed
7880 // by getRangeRef.
7882 for (Value *V : cast<PHINode>(U)->operands())
7883 Ops.push_back(V);
7884 return nullptr;
7885 }
7886 return nullptr;
7887
7888 case Instruction::Select: {
7889 // Check if U is a select that can be simplified to a SCEVUnknown.
7890 auto CanSimplifyToUnknown = [this, U]() {
7891 if (U->getType()->isIntegerTy(1) || isa<ConstantInt>(U->getOperand(0)))
7892 return false;
7893
7894 auto *ICI = dyn_cast<ICmpInst>(U->getOperand(0));
7895 if (!ICI)
7896 return false;
7897 Value *LHS = ICI->getOperand(0);
7898 Value *RHS = ICI->getOperand(1);
7899 if (ICI->getPredicate() == CmpInst::ICMP_EQ ||
7900 ICI->getPredicate() == CmpInst::ICMP_NE) {
7902 return true;
7903 } else if (getTypeSizeInBits(LHS->getType()) >
7904 getTypeSizeInBits(U->getType()))
7905 return true;
7906 return false;
7907 };
7908 if (CanSimplifyToUnknown())
7909 return getUnknown(U);
7910
7911 llvm::append_range(Ops, U->operands());
7912 return nullptr;
7913 break;
7914 }
7915 case Instruction::Call:
7916 case Instruction::Invoke:
7917 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) {
7918 Ops.push_back(RV);
7919 return nullptr;
7920 }
7921
7922 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
7923 switch (II->getIntrinsicID()) {
7924 case Intrinsic::abs:
7925 Ops.push_back(II->getArgOperand(0));
7926 return nullptr;
7927 case Intrinsic::umax:
7928 case Intrinsic::umin:
7929 case Intrinsic::smax:
7930 case Intrinsic::smin:
7931 case Intrinsic::usub_sat:
7932 case Intrinsic::uadd_sat:
7933 Ops.push_back(II->getArgOperand(0));
7934 Ops.push_back(II->getArgOperand(1));
7935 return nullptr;
7936 case Intrinsic::start_loop_iterations:
7937 case Intrinsic::annotation:
7938 case Intrinsic::ptr_annotation:
7939 Ops.push_back(II->getArgOperand(0));
7940 return nullptr;
7941 default:
7942 break;
7943 }
7944 }
7945 break;
7946 }
7947
7948 return nullptr;
7949}
7950
7951const SCEV *ScalarEvolution::createSCEV(Value *V) {
7952 if (!isSCEVable(V->getType()))
7953 return getUnknown(V);
7954
7955 if (Instruction *I = dyn_cast<Instruction>(V)) {
7956 // Don't attempt to analyze instructions in blocks that aren't
7957 // reachable. Such instructions don't matter, and they aren't required
7958 // to obey basic rules for definitions dominating uses which this
7959 // analysis depends on.
7960 if (!DT.isReachableFromEntry(I->getParent()))
7961 return getUnknown(PoisonValue::get(V->getType()));
7962 } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
7963 return getConstant(CI);
7964 else if (isa<GlobalAlias>(V))
7965 return getUnknown(V);
7966 else if (!isa<ConstantExpr>(V))
7967 return getUnknown(V);
7968
7969 const SCEV *LHS;
7970 const SCEV *RHS;
7971
7973 if (auto BO =
7975 switch (BO->Opcode) {
7976 case Instruction::Add: {
7977 // The simple thing to do would be to just call getSCEV on both operands
7978 // and call getAddExpr with the result. However if we're looking at a
7979 // bunch of things all added together, this can be quite inefficient,
7980 // because it leads to N-1 getAddExpr calls for N ultimate operands.
7981 // Instead, gather up all the operands and make a single getAddExpr call.
7982 // LLVM IR canonical form means we need only traverse the left operands.
7984 do {
7985 if (BO->Op) {
7986 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
7987 AddOps.push_back(OpSCEV);
7988 break;
7989 }
7990
7991 // If a NUW or NSW flag can be applied to the SCEV for this
7992 // addition, then compute the SCEV for this addition by itself
7993 // with a separate call to getAddExpr. We need to do that
7994 // instead of pushing the operands of the addition onto AddOps,
7995 // since the flags are only known to apply to this particular
7996 // addition - they may not apply to other additions that can be
7997 // formed with operands from AddOps.
7998 const SCEV *RHS = getSCEV(BO->RHS);
7999 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8000 if (Flags != SCEV::FlagAnyWrap) {
8001 const SCEV *LHS = getSCEV(BO->LHS);
8002 if (BO->Opcode == Instruction::Sub)
8003 AddOps.push_back(getMinusSCEV(LHS, RHS, Flags));
8004 else
8005 AddOps.push_back(getAddExpr(LHS, RHS, Flags));
8006 break;
8007 }
8008 }
8009
8010 if (BO->Opcode == Instruction::Sub)
8011 AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS)));
8012 else
8013 AddOps.push_back(getSCEV(BO->RHS));
8014
8015 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8017 if (!NewBO || (NewBO->Opcode != Instruction::Add &&
8018 NewBO->Opcode != Instruction::Sub)) {
8019 AddOps.push_back(getSCEV(BO->LHS));
8020 break;
8021 }
8022 BO = NewBO;
8023 } while (true);
8024
8025 return getAddExpr(AddOps);
8026 }
8027
8028 case Instruction::Mul: {
8030 do {
8031 if (BO->Op) {
8032 if (auto *OpSCEV = getExistingSCEV(BO->Op)) {
8033 MulOps.push_back(OpSCEV);
8034 break;
8035 }
8036
8037 SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op);
8038 if (Flags != SCEV::FlagAnyWrap) {
8039 LHS = getSCEV(BO->LHS);
8040 RHS = getSCEV(BO->RHS);
8041 MulOps.push_back(getMulExpr(LHS, RHS, Flags));
8042 break;
8043 }
8044 }
8045
8046 MulOps.push_back(getSCEV(BO->RHS));
8047 auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
8049 if (!NewBO || NewBO->Opcode != Instruction::Mul) {
8050 MulOps.push_back(getSCEV(BO->LHS));
8051 break;
8052 }
8053 BO = NewBO;
8054 } while (true);
8055
8056 return getMulExpr(MulOps);
8057 }
8058 case Instruction::UDiv:
8059 LHS = getSCEV(BO->LHS);
8060 RHS = getSCEV(BO->RHS);
8061 return getUDivExpr(LHS, RHS);
8062 case Instruction::URem:
8063 LHS = getSCEV(BO->LHS);
8064 RHS = getSCEV(BO->RHS);
8065 return getURemExpr(LHS, RHS);
8066 case Instruction::Sub: {
8068 if (BO->Op)
8069 Flags = getNoWrapFlagsFromUB(BO->Op);
8070 LHS = getSCEV(BO->LHS);
8071 RHS = getSCEV(BO->RHS);
8072 return getMinusSCEV(LHS, RHS, Flags);
8073 }
8074 case Instruction::And:
8075 // For an expression like x&255 that merely masks off the high bits,
8076 // use zext(trunc(x)) as the SCEV expression.
8077 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8078 if (CI->isZero())
8079 return getSCEV(BO->RHS);
8080 if (CI->isMinusOne())
8081 return getSCEV(BO->LHS);
8082 const APInt &A = CI->getValue();
8083
8084 // Instcombine's ShrinkDemandedConstant may strip bits out of
8085 // constants, obscuring what would otherwise be a low-bits mask.
8086 // Use computeKnownBits to compute what ShrinkDemandedConstant
8087 // knew about to reconstruct a low-bits mask value.
8088 unsigned LZ = A.countl_zero();
8089 unsigned TZ = A.countr_zero();
8090 unsigned BitWidth = A.getBitWidth();
8091 KnownBits Known(BitWidth);
8092 computeKnownBits(BO->LHS, Known, getDataLayout(), &AC, nullptr, &DT);
8093
8094 APInt EffectiveMask =
8095 APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
8096 if ((LZ != 0 || TZ != 0) && !((~A & ~Known.Zero) & EffectiveMask)) {
8097 const SCEV *MulCount = getConstant(APInt::getOneBitSet(BitWidth, TZ));
8098 const SCEV *LHS = getSCEV(BO->LHS);
8099 const SCEV *ShiftedLHS = nullptr;
8100 if (auto *LHSMul = dyn_cast<SCEVMulExpr>(LHS)) {
8101 if (auto *OpC = dyn_cast<SCEVConstant>(LHSMul->getOperand(0))) {
8102 // For an expression like (x * 8) & 8, simplify the multiply.
8103 unsigned MulZeros = OpC->getAPInt().countr_zero();
8104 unsigned GCD = std::min(MulZeros, TZ);
8105 APInt DivAmt = APInt::getOneBitSet(BitWidth, TZ - GCD);
8107 MulOps.push_back(getConstant(OpC->getAPInt().ashr(GCD)));
8108 append_range(MulOps, LHSMul->operands().drop_front());
8109 auto *NewMul = getMulExpr(MulOps, LHSMul->getNoWrapFlags());
8110 ShiftedLHS = getUDivExpr(NewMul, getConstant(DivAmt));
8111 }
8112 }
8113 if (!ShiftedLHS)
8114 ShiftedLHS = getUDivExpr(LHS, MulCount);
8115 return getMulExpr(
8117 getTruncateExpr(ShiftedLHS,
8118 IntegerType::get(getContext(), BitWidth - LZ - TZ)),
8119 BO->LHS->getType()),
8120 MulCount);
8121 }
8122 }
8123 // Binary `and` is a bit-wise `umin`.
8124 if (BO->LHS->getType()->isIntegerTy(1)) {
8125 LHS = getSCEV(BO->LHS);
8126 RHS = getSCEV(BO->RHS);
8127 return getUMinExpr(LHS, RHS);
8128 }
8129 break;
8130
8131 case Instruction::Or:
8132 // Binary `or` is a bit-wise `umax`.
8133 if (BO->LHS->getType()->isIntegerTy(1)) {
8134 LHS = getSCEV(BO->LHS);
8135 RHS = getSCEV(BO->RHS);
8136 return getUMaxExpr(LHS, RHS);
8137 }
8138 break;
8139
8140 case Instruction::Xor:
8141 if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
8142 // If the RHS of xor is -1, then this is a not operation.
8143 if (CI->isMinusOne())
8144 return getNotSCEV(getSCEV(BO->LHS));
8145
8146 // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
8147 // This is a variant of the check for xor with -1, and it handles
8148 // the case where instcombine has trimmed non-demanded bits out
8149 // of an xor with -1.
8150 if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS))
8151 if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1)))
8152 if (LBO->getOpcode() == Instruction::And &&
8153 LCI->getValue() == CI->getValue())
8154 if (const SCEVZeroExtendExpr *Z =
8156 Type *UTy = BO->LHS->getType();
8157 const SCEV *Z0 = Z->getOperand();
8158 Type *Z0Ty = Z0->getType();
8159 unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
8160
8161 // If C is a low-bits mask, the zero extend is serving to
8162 // mask off the high bits. Complement the operand and
8163 // re-apply the zext.
8164 if (CI->getValue().isMask(Z0TySize))
8165 return getZeroExtendExpr(getNotSCEV(Z0), UTy);
8166
8167 // If C is a single bit, it may be in the sign-bit position
8168 // before the zero-extend. In this case, represent the xor
8169 // using an add, which is equivalent, and re-apply the zext.
8170 APInt Trunc = CI->getValue().trunc(Z0TySize);
8171 if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
8172 Trunc.isSignMask())
8173 return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
8174 UTy);
8175 }
8176 }
8177 break;
8178
8179 case Instruction::Shl:
8180 // Turn shift left of a constant amount into a multiply.
8181 if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
8182 uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
8183
8184 // If the shift count is not less than the bitwidth, the result of
8185 // the shift is undefined. Don't try to analyze it, because the
8186 // resolution chosen here may differ from the resolution chosen in
8187 // other parts of the compiler.
8188 if (SA->getValue().uge(BitWidth))
8189 break;
8190
8191 // We can safely preserve the nuw flag in all cases. It's also safe to
8192 // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
8193 // requires special handling. It can be preserved as long as we're not
8194 // left shifting by bitwidth - 1.
8195 auto Flags = SCEV::FlagAnyWrap;
8196 if (BO->Op) {
8197 auto MulFlags = getNoWrapFlagsFromUB(BO->Op);
8198 if (any(MulFlags & SCEV::FlagNSW) &&
8199 (any(MulFlags & SCEV::FlagNUW) ||
8200 SA->getValue().ult(BitWidth - 1)))
8202 if (any(MulFlags & SCEV::FlagNUW))
8204 }
8205
8206 ConstantInt *X = ConstantInt::get(
8207 getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
8208 return getMulExpr(getSCEV(BO->LHS), getConstant(X), Flags);
8209 }
8210 break;
8211
8212 case Instruction::AShr:
8213 // AShr X, C, where C is a constant.
8214 ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
8215 if (!CI)
8216 break;
8217
8218 Type *OuterTy = BO->LHS->getType();
8219 uint64_t BitWidth = getTypeSizeInBits(OuterTy);
8220 // If the shift count is not less than the bitwidth, the result of
8221 // the shift is undefined. Don't try to analyze it, because the
8222 // resolution chosen here may differ from the resolution chosen in
8223 // other parts of the compiler.
8224 if (CI->getValue().uge(BitWidth))
8225 break;
8226
8227 if (CI->isZero())
8228 return getSCEV(BO->LHS); // shift by zero --> noop
8229
8230 uint64_t AShrAmt = CI->getZExtValue();
8231 Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
8232
8233 Operator *L = dyn_cast<Operator>(BO->LHS);
8234 const SCEV *AddTruncateExpr = nullptr;
8235 ConstantInt *ShlAmtCI = nullptr;
8236 const SCEV *AddConstant = nullptr;
8237
8238 if (L && L->getOpcode() == Instruction::Add) {
8239 // X = Shl A, n
8240 // Y = Add X, c
8241 // Z = AShr Y, m
8242 // n, c and m are constants.
8243
8244 Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
8245 ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
8246 if (LShift && LShift->getOpcode() == Instruction::Shl) {
8247 if (AddOperandCI) {
8248 const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
8249 ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
8250 // since we truncate to TruncTy, the AddConstant should be of the
8251 // same type, so create a new Constant with type same as TruncTy.
8252 // Also, the Add constant should be shifted right by AShr amount.
8253 APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
8254 AddConstant = getConstant(AddOperand.trunc(BitWidth - AShrAmt));
8255 // we model the expression as sext(add(trunc(A), c << n)), since the
8256 // sext(trunc) part is already handled below, we create a
8257 // AddExpr(TruncExp) which will be used later.
8258 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8259 }
8260 }
8261 } else if (L && L->getOpcode() == Instruction::Shl) {
8262 // X = Shl A, n
8263 // Y = AShr X, m
8264 // Both n and m are constant.
8265
8266 const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
8267 ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
8268 AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
8269 }
8270
8271 if (AddTruncateExpr && ShlAmtCI) {
8272 // We can merge the two given cases into a single SCEV statement,
8273 // incase n = m, the mul expression will be 2^0, so it gets resolved to
8274 // a simpler case. The following code handles the two cases:
8275 //
8276 // 1) For a two-shift sext-inreg, i.e. n = m,
8277 // use sext(trunc(x)) as the SCEV expression.
8278 //
8279 // 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
8280 // expression. We already checked that ShlAmt < BitWidth, so
8281 // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
8282 // ShlAmt - AShrAmt < Amt.
8283 const APInt &ShlAmt = ShlAmtCI->getValue();
8284 if (ShlAmt.ult(BitWidth) && ShlAmt.uge(AShrAmt)) {
8285 APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
8286 ShlAmtCI->getZExtValue() - AShrAmt);
8287 const SCEV *CompositeExpr =
8288 getMulExpr(AddTruncateExpr, getConstant(Mul));
8289 if (L->getOpcode() != Instruction::Shl)
8290 CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
8291
8292 return getSignExtendExpr(CompositeExpr, OuterTy);
8293 }
8294 }
8295 break;
8296 }
8297 }
8298
8299 switch (U->getOpcode()) {
8300 case Instruction::Trunc:
8301 return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
8302
8303 case Instruction::ZExt:
8304 return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8305
8306 case Instruction::SExt:
8307 if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
8309 // The NSW flag of a subtract does not always survive the conversion to
8310 // A + (-1)*B. By pushing sign extension onto its operands we are much
8311 // more likely to preserve NSW and allow later AddRec optimisations.
8312 //
8313 // NOTE: This is effectively duplicating this logic from getSignExtend:
8314 // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
8315 // but by that point the NSW information has potentially been lost.
8316 if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
8317 Type *Ty = U->getType();
8318 auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
8319 auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
8320 return getMinusSCEV(V1, V2, SCEV::FlagNSW);
8321 }
8322 }
8323 return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
8324
8325 case Instruction::BitCast:
8326 // BitCasts are no-op casts so we just eliminate the cast.
8327 if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
8328 return getSCEV(U->getOperand(0));
8329 break;
8330
8331 case Instruction::PtrToAddr: {
8332 const SCEV *IntOp = getPtrToAddrExpr(getSCEV(U->getOperand(0)));
8333 if (isa<SCEVCouldNotCompute>(IntOp))
8334 return getUnknown(V);
8335 return IntOp;
8336 }
8337
8338 case Instruction::PtrToInt: {
8339 // Pointer to integer cast is straight-forward, so do model it.
8340 const SCEV *Op = getSCEV(U->getOperand(0));
8341 Type *DstIntTy = U->getType();
8342 // But only if effective SCEV (integer) type is wide enough to represent
8343 // all possible pointer values.
8344 const SCEV *IntOp = getPtrToIntExpr(Op, DstIntTy);
8345 if (isa<SCEVCouldNotCompute>(IntOp))
8346 return getUnknown(V);
8347 return IntOp;
8348 }
8349 case Instruction::IntToPtr:
8350 // Just don't deal with inttoptr casts.
8351 return getUnknown(V);
8352
8353 case Instruction::SDiv:
8354 // If both operands are non-negative, this is just an udiv.
8355 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8356 isKnownNonNegative(getSCEV(U->getOperand(1))))
8357 return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8358 break;
8359
8360 case Instruction::SRem:
8361 // If both operands are non-negative, this is just an urem.
8362 if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
8363 isKnownNonNegative(getSCEV(U->getOperand(1))))
8364 return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
8365 break;
8366
8367 case Instruction::GetElementPtr:
8368 return createNodeForGEP(cast<GEPOperator>(U));
8369
8370 case Instruction::PHI:
8371 return createNodeForPHI(cast<PHINode>(U));
8372
8373 case Instruction::Select:
8374 return createNodeForSelectOrPHI(U, U->getOperand(0), U->getOperand(1),
8375 U->getOperand(2));
8376
8377 case Instruction::Call:
8378 case Instruction::Invoke:
8379 if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
8380 return getSCEV(RV);
8381
8382 if (auto *II = dyn_cast<IntrinsicInst>(U)) {
8383 switch (II->getIntrinsicID()) {
8384 case Intrinsic::abs:
8385 return getAbsExpr(
8386 getSCEV(II->getArgOperand(0)),
8387 /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
8388 case Intrinsic::umax:
8389 LHS = getSCEV(II->getArgOperand(0));
8390 RHS = getSCEV(II->getArgOperand(1));
8391 return getUMaxExpr(LHS, RHS);
8392 case Intrinsic::umin:
8393 LHS = getSCEV(II->getArgOperand(0));
8394 RHS = getSCEV(II->getArgOperand(1));
8395 return getUMinExpr(LHS, RHS);
8396 case Intrinsic::smax:
8397 LHS = getSCEV(II->getArgOperand(0));
8398 RHS = getSCEV(II->getArgOperand(1));
8399 return getSMaxExpr(LHS, RHS);
8400 case Intrinsic::smin:
8401 LHS = getSCEV(II->getArgOperand(0));
8402 RHS = getSCEV(II->getArgOperand(1));
8403 return getSMinExpr(LHS, RHS);
8404 case Intrinsic::usub_sat: {
8405 const SCEV *X = getSCEV(II->getArgOperand(0));
8406 const SCEV *Y = getSCEV(II->getArgOperand(1));
8407 const SCEV *ClampedY = getUMinExpr(X, Y);
8408 return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
8409 }
8410 case Intrinsic::uadd_sat: {
8411 const SCEV *X = getSCEV(II->getArgOperand(0));
8412 const SCEV *Y = getSCEV(II->getArgOperand(1));
8413 const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
8414 return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
8415 }
8416 case Intrinsic::start_loop_iterations:
8417 case Intrinsic::annotation:
8418 case Intrinsic::ptr_annotation:
8419 // A start_loop_iterations or llvm.annotation or llvm.prt.annotation is
8420 // just eqivalent to the first operand for SCEV purposes.
8421 return getSCEV(II->getArgOperand(0));
8422 case Intrinsic::vscale:
8423 return getVScale(II->getType());
8424 default:
8425 break;
8426 }
8427 }
8428 break;
8429 }
8430
8431 return getUnknown(V);
8432}
8433
8434//===----------------------------------------------------------------------===//
8435// Iteration Count Computation Code
8436//
8437
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  // Evaluate the trip count in an integer type one bit wider than the exit
  // count, so the +1 performed by the three-argument overload can never wrap.
  auto *ExitCountType = ExitCount->getType();
  assert(ExitCountType->isIntegerTy());
  auto *EvalTy = Type::getIntNTy(ExitCountType->getContext(),
                                 1 + ExitCountType->getScalarSizeInBits());
  return getTripCountFromExitCount(ExitCount, EvalTy, nullptr);
}
8448
8450 Type *EvalTy,
8451 const Loop *L) {
  if (isa<SCEVCouldNotCompute>(ExitCount))
    return getCouldNotCompute();

  unsigned ExitCountSize = getTypeSizeInBits(ExitCount->getType());
  unsigned EvalSize = EvalTy->getPrimitiveSizeInBits();

  // Returns true when ExitCount provably cannot be the all-ones value:
  // either its unsigned range excludes the max value, or the loop (if
  // provided) is only entered under a guard that ExitCount != -1.
  auto CanAddOneWithoutOverflow = [&]() {
    ConstantRange ExitCountRange =
        getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED);
    if (!ExitCountRange.contains(APInt::getMaxValue(ExitCountSize)))
      return true;

    return L && isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, ExitCount,
                                         getMinusOne(ExitCount->getType()));
  };

  // If we need to zero extend the backedge count, check if we can add one to
  // it prior to zero extending without overflow. Provided this is safe, it
  // allows better simplification of the +1.
  if (EvalSize > ExitCountSize && CanAddOneWithoutOverflow())
    return getZeroExtendExpr(
        getAddExpr(ExitCount, getOne(ExitCount->getType())), EvalTy);

  // Get the total trip count from the count by adding 1. This may wrap.
  return getAddExpr(getTruncateOrZeroExtend(ExitCount, EvalTy), getOne(EvalTy));
}
8478
8479static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
8480 if (!ExitCount)
8481 return 0;
8482
8483 ConstantInt *ExitConst = ExitCount->getValue();
8484
8485 // Guard against huge trip counts.
8486 if (ExitConst->getValue().getActiveBits() > 32)
8487 return 0;
8488
8489 // In case of integer overflow, this returns 0, which is correct.
8490 return ((unsigned)ExitConst->getZExtValue()) + 1;
8491}
8492
  // Delegate to the SCEVConstant-based helper; a non-constant (or >32-bit)
  // backedge-taken count yields 0.
  auto *ExitCount = dyn_cast<SCEVConstant>(getBackedgeTakenCount(L, Exact));
  return getConstantTripCount(ExitCount);
}
8497
8498unsigned
8500 const BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  // Trip count through this specific exit: exit-count constant + 1, or 0 if
  // the count for this exiting block is unknown or does not fit in 32 bits.
  const SCEVConstant *ExitCount =
      dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock));
  return getConstantTripCount(ExitCount);
}
8508
8510 const Loop *L, SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8511
8512 const auto *MaxExitCount =
8513 Predicates ? getPredicatedConstantMaxBackedgeTakenCount(L, *Predicates)
8515 return getConstantTripCount(dyn_cast<SCEVConstant>(MaxExitCount));
8516}
8517
  SmallVector<BasicBlock *, 8> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);

  // The overall trip multiple is the GCD of the per-exit multiples: the trip
  // count must be divisible by it no matter which exit is taken.
  std::optional<unsigned> Res;
  for (auto *ExitingBB : ExitingBlocks) {
    unsigned Multiple = getSmallConstantTripMultiple(L, ExitingBB);
    if (!Res)
      Res = Multiple;
    Res = std::gcd(*Res, Multiple);
  }
  // No exiting blocks: nothing is known, and 1 divides everything.
  return Res.value_or(1);
}
8531
8533 const SCEV *ExitCount) {
8534 if (isa<SCEVCouldNotCompute>(ExitCount))
8535 return 1;
8536
8537 // Get the trip count
8538 const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L));
8539
8540 APInt Multiple = getNonZeroConstantMultiple(TCExpr);
8541 // If a trip multiple is huge (>=2^32), the trip count is still divisible by
8542 // the greatest power of 2 divisor less than 2^32.
8543 return Multiple.getActiveBits() > 32
8544 ? 1U << std::min(31U, Multiple.countTrailingZeros())
8545 : (unsigned)Multiple.getZExtValue();
8546}
8547
8548/// Returns the largest constant divisor of the trip count of this loop as a
8549/// normal unsigned value, if possible. This means that the actual trip count is
8550/// always a multiple of the returned value (don't forget the trip count could
8551/// very well be zero as well!).
8552///
8553/// Returns 1 if the trip count is unknown or not guaranteed to be the
8554/// multiple of a constant (which is also the case if the trip count is simply
8555/// constant, use getSmallConstantTripCount for that case), Will also return 1
8556/// if the trip count is very large (>= 2^32).
8557///
8558/// As explained in the comments for getSmallConstantTripCount, this assumes
8559/// that control exits the loop via ExitingBlock.
8560unsigned
8562 const BasicBlock *ExitingBlock) {
  assert(ExitingBlock && "Must pass a non-null exiting block!");
  assert(L->isLoopExiting(ExitingBlock) &&
         "Exiting block must actually branch out of the loop!");
  // Compute the multiple from this specific exit's count.
  const SCEV *ExitCount = getExitCount(L, ExitingBlock);
  return getSmallConstantTripMultiple(L, ExitCount);
}
8569
8571 const BasicBlock *ExitingBlock,
8572 ExitCountKind Kind) {
  // Dispatch on the requested flavour of exit count for this exiting block.
  switch (Kind) {
  case Exact:
    return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
  case SymbolicMaximum:
    return getBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this);
  case ConstantMaximum:
    return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8583
8585 const Loop *L, const BasicBlock *ExitingBlock,
  // Same dispatch as getExitCount, but from the predicated cache; required
  // predicates are appended to the caller-provided list.
  switch (Kind) {
  case Exact:
    return getPredicatedBackedgeTakenInfo(L).getExact(ExitingBlock, this,
                                                      Predicates);
  case SymbolicMaximum:
    return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(ExitingBlock, this,
                                                            Predicates);
  case ConstantMaximum:
    return getPredicatedBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this,
                                                            Predicates);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8600
  // Exact predicated count over all exits; the predicates required for this
  // result are appended to Preds.
  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
}
8605
8607 ExitCountKind Kind) {
  // Dispatch on the requested flavour of whole-loop backedge-taken count.
  switch (Kind) {
  case Exact:
    return getBackedgeTakenInfo(L).getExact(L, this);
  case ConstantMaximum:
    return getBackedgeTakenInfo(L).getConstantMax(this);
  case SymbolicMaximum:
    return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
  };
  llvm_unreachable("Invalid ExitCountKind!");
}
8618
  // Symbolic maximum count; required predicates are appended to Preds.
  return getPredicatedBackedgeTakenInfo(L).getSymbolicMax(L, this, &Preds);
}
8623
  // Constant maximum count; required predicates are appended to Preds.
  return getPredicatedBackedgeTakenInfo(L).getConstantMax(this, &Preds);
}
8628
  // True when the count is known to be either the constant maximum or zero.
  return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
}
8632
8633/// Push PHI nodes in the header of the given loop onto the given Worklist.
8634static void PushLoopPHIs(const Loop *L,
8637 BasicBlock *Header = L->getHeader();
8638
8639 // Push all Loop-header PHIs onto the Worklist stack.
8640 for (PHINode &PN : Header->phis())
8641 if (Visited.insert(&PN).second)
8642 Worklist.push_back(&PN);
8643}
8644
8645ScalarEvolution::BackedgeTakenInfo &
8646ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
8647 auto &BTI = getBackedgeTakenInfo(L);
8648 if (BTI.hasFullInfo())
8649 return BTI;
8650
8651 auto Pair = PredicatedBackedgeTakenCounts.try_emplace(L);
8652
8653 if (!Pair.second)
8654 return Pair.first->second;
8655
8656 BackedgeTakenInfo Result =
8657 computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
8658
8659 return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
8660}
8661
8662ScalarEvolution::BackedgeTakenInfo &
8663ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
8664 // Initially insert an invalid entry for this loop. If the insertion
8665 // succeeds, proceed to actually compute a backedge-taken count and
8666 // update the value. The temporary CouldNotCompute value tells SCEV
8667 // code elsewhere that it shouldn't attempt to request a new
8668 // backedge-taken count, which could result in infinite recursion.
8669 std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
8670 BackedgeTakenCounts.try_emplace(L);
8671 if (!Pair.second)
8672 return Pair.first->second;
8673
8674 // computeBackedgeTakenCount may allocate memory for its result. Inserting it
8675 // into the BackedgeTakenCounts map transfers ownership. Otherwise, the result
8676 // must be cleared in this scope.
8677 BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
8678
8679 // Now that we know more about the trip count for this loop, forget any
8680 // existing SCEV values for PHI nodes in this loop since they are only
8681 // conservative estimates made without the benefit of trip count
8682 // information. This invalidation is not necessary for correctness, and is
8683 // only done to produce more precise results.
8684 if (Result.hasAnyInfo()) {
8685 // Invalidate any expression using an addrec in this loop.
8686 SmallVector<SCEVUse, 8> ToForget;
8687 auto LoopUsersIt = LoopUsers.find(L);
8688 if (LoopUsersIt != LoopUsers.end())
8689 append_range(ToForget, LoopUsersIt->second);
8690 forgetMemoizedResults(ToForget);
8691
8692 // Invalidate constant-evolved loop header phis.
8693 for (PHINode &PN : L->getHeader()->phis())
8694 ConstantEvolutionLoopExitValue.erase(&PN);
8695 }
8696
8697 // Re-lookup the insert position, since the call to
8698 // computeBackedgeTakenCount above could result in a
8699 // recusive call to getBackedgeTakenInfo (on a different
8700 // loop), which would invalidate the iterator computed
8701 // earlier.
8702 return BackedgeTakenCounts.find(L)->second = std::move(Result);
8703}
8704
  // This method is intended to forget all info about loops. It should
  // invalidate caches as if the following happened:
  // - The trip counts of all loops have changed arbitrarily
  // - Every llvm::Value has been updated in place to produce a different
  // result.
  BackedgeTakenCounts.clear();
  PredicatedBackedgeTakenCounts.clear();
  BECountUsers.clear();
  LoopPropertiesCache.clear();
  ConstantEvolutionLoopExitValue.clear();
  // Value-to-SCEV mappings and every cache derived from them must go too.
  ValueExprMap.clear();
  ValuesAtScopes.clear();
  ValuesAtScopesUsers.clear();
  LoopDispositions.clear();
  BlockDispositions.clear();
  UnsignedRanges.clear();
  SignedRanges.clear();
  ExprValueMap.clear();
  HasRecMap.clear();
  ConstantMultipleCache.clear();
  PredicatedSCEVRewrites.clear();
  FoldCache.clear();
  FoldCacheUser.clear();
}
8730void ScalarEvolution::visitAndClearUsers(
8733 SmallVectorImpl<SCEVUse> &ToForget) {
8734 while (!Worklist.empty()) {
8735 Instruction *I = Worklist.pop_back_val();
8736 if (!isSCEVable(I->getType()) && !isa<WithOverflowInst>(I))
8737 continue;
8738
8740 ValueExprMap.find_as(static_cast<Value *>(I));
8741 if (It != ValueExprMap.end()) {
8742 eraseValueFromMap(It->first);
8743 ToForget.push_back(It->second);
8744 if (PHINode *PN = dyn_cast<PHINode>(I))
8745 ConstantEvolutionLoopExitValue.erase(PN);
8746 }
8747
8748 PushDefUseChildren(I, Worklist, Visited);
8749 }
8750}
8751
8753 SmallVector<const Loop *, 16> LoopWorklist(1, L);
  SmallVector<SCEVUse, 16> ToForget;

  // Iterate over all the loops and sub-loops to drop SCEV information.
  while (!LoopWorklist.empty()) {
    auto *CurrL = LoopWorklist.pop_back_val();

    // Drop any stored trip count value.
    forgetBackedgeTakenCounts(CurrL, /* Predicated */ false);
    forgetBackedgeTakenCounts(CurrL, /* Predicated */ true);

    // Drop information about predicated SCEV rewrites for this loop.
    // (Erase-while-iterating: the post-increment keeps the iterator valid
    // across the erase.)
    for (auto I = PredicatedSCEVRewrites.begin();
         I != PredicatedSCEVRewrites.end();) {
      std::pair<const SCEV *, const Loop *> Entry = I->first;
      if (Entry.second == CurrL)
        PredicatedSCEVRewrites.erase(I++);
      else
        ++I;
    }

    // Queue every expression known to depend on this loop for forgetting.
    auto LoopUsersItr = LoopUsers.find(CurrL);
    if (LoopUsersItr != LoopUsers.end())
      llvm::append_range(ToForget, LoopUsersItr->second);

    // Drop information about expressions based on loop-header PHIs.
    PushLoopPHIs(CurrL, Worklist, Visited);
    visitAndClearUsers(Worklist, Visited, ToForget);

    LoopPropertiesCache.erase(CurrL);
    // Forget all contained loops too, to avoid dangling entries in the
    // ValuesAtScopes map.
    LoopWorklist.append(CurrL->begin(), CurrL->end());
  }
  forgetMemoizedResults(ToForget);
}
8791
  // Forgetting the outermost loop also forgets every loop it contains.
  forgetLoop(L->getOutermostLoop());
}
8795
8798 if (!I) return;
8799
8800 // Drop information about expressions based on loop-header PHIs.
8803 SmallVector<SCEVUse, 8> ToForget;
8804 Worklist.push_back(I);
8805 Visited.insert(I);
8806 visitAndClearUsers(Worklist, Visited, ToForget);
8807
8808 forgetMemoizedResults(ToForget);
8809}
8810
8812 if (!isSCEVable(V->getType()))
8813 return;
8814
8815 // If SCEV looked through a trivial LCSSA phi node, we might have SCEV's
8816 // directly using a SCEVUnknown/SCEVAddRec defined in the loop. After an
8817 // extra predecessor is added, this is no longer valid. Find all Unknowns and
8818 // AddRecs defined in the loop and invalidate any SCEV's making use of them.
8819 if (const SCEV *S = getExistingSCEV(V)) {
8820 struct InvalidationRootCollector {
8821 Loop *L;
8823
8824 InvalidationRootCollector(Loop *L) : L(L) {}
8825
8826 bool follow(const SCEV *S) {
8827 if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
8828 if (auto *I = dyn_cast<Instruction>(SU->getValue()))
8829 if (L->contains(I))
8830 Roots.push_back(S);
8831 } else if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
8832 if (L->contains(AddRec->getLoop()))
8833 Roots.push_back(S);
8834 }
8835 return true;
8836 }
8837 bool isDone() const { return false; }
8838 };
8839
8840 InvalidationRootCollector C(L);
8841 visitAll(S, C);
8842 forgetMemoizedResults(C.Roots);
8843 }
8844
8845 // Also perform the normal invalidation.
8846 forgetValue(V);
8847}
8848
8849void ScalarEvolution::forgetLoopDispositions() { LoopDispositions.clear(); }
8850
8852 // Unless a specific value is passed to invalidation, completely clear both
8853 // caches.
8854 if (!V) {
8855 BlockDispositions.clear();
8856 LoopDispositions.clear();
8857 return;
8858 }
8859
8860 if (!isSCEVable(V->getType()))
8861 return;
8862
8863 const SCEV *S = getExistingSCEV(V);
8864 if (!S)
8865 return;
8866
8867 // Invalidate the block and loop dispositions cached for S. Dispositions of
8868 // S's users may change if S's disposition changes (i.e. a user may change to
8869 // loop-invariant, if S changes to loop invariant), so also invalidate
8870 // dispositions of S's users recursively.
8871 SmallVector<SCEVUse, 8> Worklist = {S};
8873 while (!Worklist.empty()) {
8874 const SCEV *Curr = Worklist.pop_back_val();
8875 bool LoopDispoRemoved = LoopDispositions.erase(Curr);
8876 bool BlockDispoRemoved = BlockDispositions.erase(Curr);
8877 if (!LoopDispoRemoved && !BlockDispoRemoved)
8878 continue;
8879 auto Users = SCEVUsers.find(Curr);
8880 if (Users != SCEVUsers.end())
8881 for (const auto *User : Users->second)
8882 if (Seen.insert(User).second)
8883 Worklist.push_back(User);
8884 }
8885}
8886
8887/// Get the exact loop backedge taken count considering all loop exits. A
8888/// computable result can only be returned for loops with all exiting blocks
8889/// dominating the latch. howFarToZero assumes that the limit of each loop test
8890/// is never skipped. This is a valid assumption as long as the loop exits via
8891/// that test. For precise results, it is the caller's responsibility to specify
8892/// the relevant loop exiting block using getExact(ExitingBlock, SE).
8893const SCEV *ScalarEvolution::BackedgeTakenInfo::getExact(
8894 const Loop *L, ScalarEvolution *SE,
8896 // If any exits were not computable, the loop is not computable.
8897 if (!isComplete() || ExitNotTaken.empty())
8898 return SE->getCouldNotCompute();
8899
8900 const BasicBlock *Latch = L->getLoopLatch();
8901 // All exiting blocks we have collected must dominate the only backedge.
8902 if (!Latch)
8903 return SE->getCouldNotCompute();
8904
8905 // All exiting blocks we have gathered dominate loop's latch, so exact trip
8906 // count is simply a minimum out of all these calculated exit counts.
8908 for (const auto &ENT : ExitNotTaken) {
8909 const SCEV *BECount = ENT.ExactNotTaken;
8910 assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
8911 assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
8912 "We should only have known counts for exiting blocks that dominate "
8913 "latch!");
8914
8915 Ops.push_back(BECount);
8916
8917 if (Preds)
8918 append_range(*Preds, ENT.Predicates);
8919
8920 assert((Preds || ENT.hasAlwaysTruePredicate()) &&
8921 "Predicate should be always true!");
8922 }
8923
8924 // If an earlier exit exits on the first iteration (exit count zero), then
8925 // a later poison exit count should not propagate into the result. This are
8926 // exactly the semantics provided by umin_seq.
8927 return SE->getUMinFromMismatchedTypes(Ops, /* Sequential */ true);
8928}
8929
8930const ScalarEvolution::ExitNotTakenInfo *
8931ScalarEvolution::BackedgeTakenInfo::getExitNotTaken(
8932 const BasicBlock *ExitingBlock,
8933 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8934 for (const auto &ENT : ExitNotTaken)
8935 if (ENT.ExitingBlock == ExitingBlock) {
8936 if (ENT.hasAlwaysTruePredicate())
8937 return &ENT;
8938 else if (Predicates) {
8939 append_range(*Predicates, ENT.Predicates);
8940 return &ENT;
8941 }
8942 }
8943
8944 return nullptr;
8945}
8946
8947/// getConstantMax - Get the constant max backedge taken count for the loop.
8948const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
8949 ScalarEvolution *SE,
8950 SmallVectorImpl<const SCEVPredicate *> *Predicates) const {
8951 if (!getConstantMax())
8952 return SE->getCouldNotCompute();
8953
8954 for (const auto &ENT : ExitNotTaken)
8955 if (!ENT.hasAlwaysTruePredicate()) {
8956 if (!Predicates)
8957 return SE->getCouldNotCompute();
8958 append_range(*Predicates, ENT.Predicates);
8959 }
8960
8961 assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
8962 isa<SCEVConstant>(getConstantMax())) &&
8963 "No point in having a non-constant max backedge taken count!");
8964 return getConstantMax();
8965}
8966
8967const SCEV *ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(
8968 const Loop *L, ScalarEvolution *SE,
8969 SmallVectorImpl<const SCEVPredicate *> *Predicates) {
8970 if (!SymbolicMax) {
8971 // Form an expression for the maximum exit count possible for this loop. We
8972 // merge the max and exact information to approximate a version of
8973 // getConstantMaxBackedgeTakenCount which isn't restricted to just
8974 // constants.
8975 SmallVector<SCEVUse, 4> ExitCounts;
8976
8977 for (const auto &ENT : ExitNotTaken) {
8978 const SCEV *ExitCount = ENT.SymbolicMaxNotTaken;
8979 if (!isa<SCEVCouldNotCompute>(ExitCount)) {
8980 assert(SE->DT.dominates(ENT.ExitingBlock, L->getLoopLatch()) &&
8981 "We should only have known counts for exiting blocks that "
8982 "dominate latch!");
8983 ExitCounts.push_back(ExitCount);
8984 if (Predicates)
8985 append_range(*Predicates, ENT.Predicates);
8986
8987 assert((Predicates || ENT.hasAlwaysTruePredicate()) &&
8988 "Predicate should be always true!");
8989 }
8990 }
8991 if (ExitCounts.empty())
8992 SymbolicMax = SE->getCouldNotCompute();
8993 else
8994 SymbolicMax =
8995 SE->getUMinFromMismatchedTypes(ExitCounts, /*Sequential*/ true);
8996 }
8997 return SymbolicMax;
8998}
8999
9000bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
9001 ScalarEvolution *SE) const {
9002 auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
9003 return !ENT.hasAlwaysTruePredicate();
9004 };
9005 return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);
9006}
9007
9010
// NOTE(review): this is the body of the ScalarEvolution::ExitLimit
// constructor. Its signature and several assertion condition lines (original
// lines 9008-9016, 9025-9037, 9048) are elided in this rendering, so the
// fragment below is incomplete; code is left byte-identical.
9012 const SCEV *E, const SCEV *ConstantMaxNotTaken,
9013 const SCEV *SymbolicMaxNotTaken, bool MaxOrZero,
9017 // If we prove the max count is zero, so is the symbolic bound. This happens
9018 // in practice due to differences in a) how context sensitive we've chosen
9019 // to be and b) how we reason about bounds implied by UB.
9020 if (ConstantMaxNotTaken->isZero()) {
9021 this->ExactNotTaken = E = ConstantMaxNotTaken;
9022 this->SymbolicMaxNotTaken = SymbolicMaxNotTaken = ConstantMaxNotTaken;
9023 }
9024
// The (elided) assertions check the relative precision of the three counts.
9027 "Exact is not allowed to be less precise than Constant Max");
9030 "Exact is not allowed to be less precise than Symbolic Max");
9033 "Symbolic Max is not allowed to be less precise than Constant Max");
9036 "No point in having a non-constant max backedge taken count!");
// Deduplicate the incoming predicate lists into the member Predicates,
// accepting only leaf (non-union) predicates.
9038 for (const auto PredList : PredLists)
9039 for (const auto *P : PredList) {
9040 if (SeenPreds.contains(P))
9041 continue;
9042 assert(!isa<SCEVUnionPredicate>(P) && "Only add leaf predicates here!");
9043 SeenPreds.insert(P);
9044 Predicates.push_back(P);
9045 }
9046 assert((isa<SCEVCouldNotCompute>(E) || !E->getType()->isPointerTy()) &&
9047 "Backedge count should be int");
9049 !ConstantMaxNotTaken->getType()->isPointerTy()) &&
9050 "Max backedge count should be int");
9051}
9052
9060
9061/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
9062/// computable exit into a persistent ExitNotTakenInfo array.
// NOTE(review): the first constructor parameter (original line 9064,
// presumably the rvalue ExitCounts vector — confirm against upstream) is
// elided in this rendering.
9063ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
9065 bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
9066 : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
9067 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9068
// Copy each (block, limit) pair into the persistent ExitNotTaken array.
9069 ExitNotTaken.reserve(ExitCounts.size());
9070 std::transform(ExitCounts.begin(), ExitCounts.end(),
9071 std::back_inserter(ExitNotTaken),
9072 [&](const EdgeExitInfo &EEI) {
9073 BasicBlock *ExitBB = EEI.first;
9074 const ExitLimit &EL = EEI.second;
9075 return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken,
9076 EL.ConstantMaxNotTaken, EL.SymbolicMaxNotTaken,
9077 EL.Predicates);
9078 });
9079 assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
9080 isa<SCEVConstant>(ConstantMax)) &&
9081 "No point in having a non-constant max backedge taken count!");
9082}
9083
9084/// Compute the number of times the backedge of the specified loop will execute.
9085ScalarEvolution::BackedgeTakenInfo
9086ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
9087 bool AllowPredicates) {
9088 SmallVector<BasicBlock *, 8> ExitingBlocks;
9089 L->getExitingBlocks(ExitingBlocks);
9090
9091 using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
9092
// NOTE(review): the declaration of the local ExitCounts vector (original
// line 9093) is elided in this rendering.
9094 bool CouldComputeBECount = true;
9095 BasicBlock *Latch = L->getLoopLatch(); // may be NULL.
9096 const SCEV *MustExitMaxBECount = nullptr;
9097 const SCEV *MayExitMaxBECount = nullptr;
9098 bool MustExitMaxOrZero = false;
9099 bool IsOnlyExit = ExitingBlocks.size() == 1;
9100
9101 // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts
9102 // and compute maxBECount.
9103 // Do a union of all the predicates here.
9104 for (BasicBlock *ExitBB : ExitingBlocks) {
9105 // We canonicalize untaken exits to br (constant), ignore them so that
9106 // proving an exit untaken doesn't negatively impact our ability to reason
9107 // about the loop as whole.
9108 if (auto *BI = dyn_cast<CondBrInst>(ExitBB->getTerminator()))
9109 if (auto *CI = dyn_cast<ConstantInt>(BI->getCondition())) {
9110 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9111 if (ExitIfTrue == CI->isZero())
9112 continue;
9113 }
9114
9115 ExitLimit EL = computeExitLimit(L, ExitBB, IsOnlyExit, AllowPredicates);
9116
9117 assert((AllowPredicates || EL.Predicates.empty()) &&
9118 "Predicated exit limit when predicates are not allowed!");
9119
9120 // 1. For each exit that can be computed, add an entry to ExitCounts.
9121 // CouldComputeBECount is true only if all exits can be computed.
9122 if (EL.ExactNotTaken != getCouldNotCompute())
9123 ++NumExitCountsComputed;
9124 else
9125 // We couldn't compute an exact value for this exit, so
9126 // we won't be able to compute an exact value for the loop.
9127 CouldComputeBECount = false;
9128 // Remember exit count if either exact or symbolic is known. Because
9129 // Exact always implies symbolic, only check symbolic.
9130 if (EL.SymbolicMaxNotTaken != getCouldNotCompute())
9131 ExitCounts.emplace_back(ExitBB, EL);
9132 else {
9133 assert(EL.ExactNotTaken == getCouldNotCompute() &&
9134 "Exact is known but symbolic isn't?");
9135 ++NumExitCountsNotComputed;
9136 }
9137
9138 // 2. Derive the loop's MaxBECount from each exit's max number of
9139 // non-exiting iterations. Partition the loop exits into two kinds:
9140 // LoopMustExits and LoopMayExits.
9141 //
9142 // If the exit dominates the loop latch, it is a LoopMustExit otherwise it
9143 // is a LoopMayExit. If any computable LoopMustExit is found, then
9144 // MaxBECount is the minimum EL.ConstantMaxNotTaken of computable
9145 // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum
9146 // EL.ConstantMaxNotTaken, where CouldNotCompute is considered greater than
9147 // any
9148 // computable EL.ConstantMaxNotTaken.
9149 if (EL.ConstantMaxNotTaken != getCouldNotCompute() && Latch &&
9150 DT.dominates(ExitBB, Latch)) {
9151 if (!MustExitMaxBECount) {
9152 MustExitMaxBECount = EL.ConstantMaxNotTaken;
9153 MustExitMaxOrZero = EL.MaxOrZero;
9154 } else {
9155 MustExitMaxBECount = getUMinFromMismatchedTypes(MustExitMaxBECount,
9156 EL.ConstantMaxNotTaken);
9157 }
9158 } else if (MayExitMaxBECount != getCouldNotCompute()) {
9159 if (!MayExitMaxBECount || EL.ConstantMaxNotTaken == getCouldNotCompute())
9160 MayExitMaxBECount = EL.ConstantMaxNotTaken;
9161 else {
9162 MayExitMaxBECount = getUMaxFromMismatchedTypes(MayExitMaxBECount,
9163 EL.ConstantMaxNotTaken);
9164 }
9165 }
9166 }
// Must-exit bounds take priority over may-exit bounds when forming the
// loop-wide constant max.
9167 const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :
9168 (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute());
9169 // The loop backedge will be taken the maximum or zero times if there's
9170 // a single exit that must be taken the maximum or zero times.
9171 bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1);
9172
9173 // Remember which SCEVs are used in exit limits for invalidation purposes.
9174 // We only care about non-constant SCEVs here, so we can ignore
9175 // EL.ConstantMaxNotTaken
9176 // and MaxBECount, which must be SCEVConstant.
9177 for (const auto &Pair : ExitCounts) {
9178 if (!isa<SCEVConstant>(Pair.second.ExactNotTaken))
9179 BECountUsers[Pair.second.ExactNotTaken].insert({L, AllowPredicates});
9180 if (!isa<SCEVConstant>(Pair.second.SymbolicMaxNotTaken))
9181 BECountUsers[Pair.second.SymbolicMaxNotTaken].insert(
9182 {L, AllowPredicates});
9183 }
9184 return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount,
9185 MaxBECount, MaxOrZero);
9186}
9187
9188ScalarEvolution::ExitLimit
9189ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
9190 bool IsOnlyExit, bool AllowPredicates) {
9191 assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
9192 // If our exiting block does not dominate the latch, then its connection with
9193 // loop's exit limit may be far from trivial.
9194 const BasicBlock *Latch = L->getLoopLatch();
9195 if (!Latch || !DT.dominates(ExitingBlock, Latch))
9196 return getCouldNotCompute();
9197
9198 Instruction *Term = ExitingBlock->getTerminator();
9199 if (CondBrInst *BI = dyn_cast<CondBrInst>(Term)) {
9200 bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
9201 assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
9202 "It should have one successor in loop and one exit block!");
9203 // Proceed to the next level to examine the exit condition expression.
9204 return computeExitLimitFromCond(L, BI->getCondition(), ExitIfTrue,
9205 /*ControlsOnlyExit=*/IsOnlyExit,
9206 AllowPredicates);
9207 }
9208
9209 if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
9210 // For switch, make sure that there is a single exit from the loop.
9211 BasicBlock *Exit = nullptr;
9212 for (auto *SBB : successors(ExitingBlock))
9213 if (!L->contains(SBB)) {
9214 if (Exit) // Multiple exit successors.
9215 return getCouldNotCompute();
9216 Exit = SBB;
9217 }
9218 assert(Exit && "Exiting block must have at least one exit");
9219 return computeExitLimitFromSingleExitSwitch(
9220 L, SI, Exit, /*ControlsOnlyExit=*/IsOnlyExit);
9221 }
9222
9223 return getCouldNotCompute();
9224}
9225
// NOTE(review): this is computeExitLimitFromCond; its return-type/name line
// (original line 9226) is elided in this rendering. It creates a fresh
// per-query cache and delegates to the cached variant.
9227 const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9228 bool AllowPredicates) {
9229 ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
9230 return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
9231 ControlsOnlyExit, AllowPredicates);
9232}
9233
9234std::optional<ScalarEvolution::ExitLimit>
9235ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
9236 bool ExitIfTrue, bool ControlsOnlyExit,
9237 bool AllowPredicates) {
9238 (void)this->L;
9239 (void)this->ExitIfTrue;
9240 (void)this->AllowPredicates;
9241
9242 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9243 this->AllowPredicates == AllowPredicates &&
9244 "Variance in assumed invariant key components!");
9245 auto Itr = TripCountMap.find({ExitCond, ControlsOnlyExit});
9246 if (Itr == TripCountMap.end())
9247 return std::nullopt;
9248 return Itr->second;
9249}
9250
9251void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
9252 bool ExitIfTrue,
9253 bool ControlsOnlyExit,
9254 bool AllowPredicates,
9255 const ExitLimit &EL) {
9256 assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
9257 this->AllowPredicates == AllowPredicates &&
9258 "Variance in assumed invariant key components!");
9259
9260 auto InsertResult = TripCountMap.insert({{ExitCond, ControlsOnlyExit}, EL});
9261 assert(InsertResult.second && "Expected successful insertion!");
9262 (void)InsertResult;
9263 (void)ExitIfTrue;
9264}
9265
9266ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
9267 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9268 bool ControlsOnlyExit, bool AllowPredicates) {
9269
9270 if (auto MaybeEL = Cache.find(L, ExitCond, ExitIfTrue, ControlsOnlyExit,
9271 AllowPredicates))
9272 return *MaybeEL;
9273
9274 ExitLimit EL = computeExitLimitFromCondImpl(
9275 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates);
9276 Cache.insert(L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates, EL);
9277 return EL;
9278}
9279
// NOTE(review): two lines of this function are elided in this rendering — the
// ConstantRange::makeExactNoWrapRegion(...) call head (original line 9322) and
// the LHS adjustment under "if (Offset != 0)" (original line 9331). Code is
// left byte-identical; confirm against upstream before editing.
9280ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
9281 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9282 bool ControlsOnlyExit, bool AllowPredicates) {
9283 // Handle BinOp conditions (And, Or).
9284 if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
9285 Cache, L, ExitCond, ExitIfTrue, ControlsOnlyExit, AllowPredicates))
9286 return *LimitFromBinOp;
9287
9288 // With an icmp, it may be feasible to compute an exact backedge-taken count.
9289 // Proceed to the next level to examine the icmp.
9290 if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
9291 ExitLimit EL =
9292 computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsOnlyExit);
9293 if (EL.hasFullInfo() || !AllowPredicates)
9294 return EL;
9295
9296 // Try again, but use SCEV predicates this time.
9297 return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue,
9298 ControlsOnlyExit,
9299 /*AllowPredicates=*/true);
9300 }
9301
9302 // Check for a constant condition. These are normally stripped out by
9303 // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
9304 // preserve the CFG and is temporarily leaving constant conditions
9305 // in place.
9306 if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
9307 if (ExitIfTrue == !CI->getZExtValue())
9308 // The backedge is always taken.
9309 return getCouldNotCompute();
9310 // The backedge is never taken.
9311 return getZero(CI->getType());
9312 }
9313
9314 // If we're exiting based on the overflow flag of an x.with.overflow intrinsic
9315 // with a constant step, we can form an equivalent icmp predicate and figure
9316 // out how many iterations will be taken before we exit.
9317 const WithOverflowInst *WO;
9318 const APInt *C;
9319 if (match(ExitCond, m_ExtractValue<1>(m_WithOverflowInst(WO))) &&
9320 match(WO->getRHS(), m_APInt(C))) {
9321 ConstantRange NWR =
9323 WO->getNoWrapKind());
9324 CmpInst::Predicate Pred;
9325 APInt NewRHSC, Offset;
9326 NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
9327 if (!ExitIfTrue)
9328 Pred = ICmpInst::getInversePredicate(Pred);
9329 auto *LHS = getSCEV(WO->getLHS());
9330 if (Offset != 0)
9332 auto EL = computeExitLimitFromICmp(L, Pred, LHS, getConstant(NewRHSC),
9333 ControlsOnlyExit, AllowPredicates);
9334 if (EL.hasAnyInfo())
9335 return EL;
9336 }
9337
9338 // If it's not an integer or pointer comparison then compute it the hard way.
9339 return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9340}
9341
9342std::optional<ScalarEvolution::ExitLimit>
9343ScalarEvolution::computeExitLimitFromCondFromBinOp(
9344 ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
9345 bool ControlsOnlyExit, bool AllowPredicates) {
9346 // Check if the controlling expression for this loop is an And or Or.
9347 Value *Op0, *Op1;
9348 bool IsAnd;
9349 if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
9350 IsAnd = true;
9351 else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
9352 IsAnd = false;
9353 else
9354 return std::nullopt;
9355
9356 // Be robust against unsimplified IR for the form "op i1 X, NeutralElement".
9357 const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
9358 if (Op0 == NeutralElement)
9359 std::swap(Op0, Op1);
9360 if (Op1 == NeutralElement)
9361 return computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
9362 ControlsOnlyExit, AllowPredicates);
9363
9364 // A sub-condition of a non-trivial binop never solely controls the exit,
9365 // whether we exit always depends on both conditions.
9366 ExitLimit EL0 = computeExitLimitFromCondCached(
9367 Cache, L, Op0, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9368 ExitLimit EL1 = computeExitLimitFromCondCached(
9369 Cache, L, Op1, ExitIfTrue, /*ControlsOnlyExit=*/false, AllowPredicates);
9370
9371 // EitherMayExit is true in these two cases:
9372 // br (and Op0 Op1), loop, exit
9373 // br (or Op0 Op1), exit, loop
9374 bool EitherMayExit = IsAnd ^ ExitIfTrue;
9375
9376 const SCEV *BECount = getCouldNotCompute();
9377 const SCEV *ConstantMaxBECount = getCouldNotCompute();
9378 const SCEV *SymbolicMaxBECount = getCouldNotCompute();
9379 if (EitherMayExit) {
9380 bool UseSequentialUMin = !isa<BinaryOperator>(ExitCond);
9381 // Both conditions must be same for the loop to continue executing.
9382 // Choose the less conservative count.
9383 if (EL0.ExactNotTaken != getCouldNotCompute() &&
9384 EL1.ExactNotTaken != getCouldNotCompute()) {
9385 BECount = getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken,
9386 UseSequentialUMin);
9387 }
9388 if (EL0.ConstantMaxNotTaken == getCouldNotCompute())
9389 ConstantMaxBECount = EL1.ConstantMaxNotTaken;
9390 else if (EL1.ConstantMaxNotTaken == getCouldNotCompute())
9391 ConstantMaxBECount = EL0.ConstantMaxNotTaken;
9392 else
9393 ConstantMaxBECount = getUMinFromMismatchedTypes(EL0.ConstantMaxNotTaken,
9394 EL1.ConstantMaxNotTaken);
9395 if (EL0.SymbolicMaxNotTaken == getCouldNotCompute())
9396 SymbolicMaxBECount = EL1.SymbolicMaxNotTaken;
9397 else if (EL1.SymbolicMaxNotTaken == getCouldNotCompute())
9398 SymbolicMaxBECount = EL0.SymbolicMaxNotTaken;
9399 else
9400 SymbolicMaxBECount = getUMinFromMismatchedTypes(
9401 EL0.SymbolicMaxNotTaken, EL1.SymbolicMaxNotTaken, UseSequentialUMin);
9402 } else {
9403 // Both conditions must be same at the same time for the loop to exit.
9404 // For now, be conservative.
9405 if (EL0.ExactNotTaken == EL1.ExactNotTaken)
9406 BECount = EL0.ExactNotTaken;
9407 }
9408
9409 // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
9410 // to be more aggressive when computing BECount than when computing
9411 // ConstantMaxBECount. In these cases it is possible for EL0.ExactNotTaken
9412 // and
9413 // EL1.ExactNotTaken to match, but for EL0.ConstantMaxNotTaken and
9414 // EL1.ConstantMaxNotTaken to not.
9415 if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
9416 !isa<SCEVCouldNotCompute>(BECount))
9417 ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));
9418 if (isa<SCEVCouldNotCompute>(SymbolicMaxBECount))
9419 SymbolicMaxBECount =
9420 isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
9421 return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
9422 {ArrayRef(EL0.Predicates), ArrayRef(EL1.Predicates)});
9423}
9424
9425ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9426 const Loop *L, ICmpInst *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit,
9427 bool AllowPredicates) {
9428 // If the condition was exit on true, convert the condition to exit on false
9429 CmpPredicate Pred;
9430 if (!ExitIfTrue)
9431 Pred = ExitCond->getCmpPredicate();
9432 else
9433 Pred = ExitCond->getInverseCmpPredicate();
9434 const ICmpInst::Predicate OriginalPred = Pred;
9435
9436 const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
9437 const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
9438
9439 ExitLimit EL = computeExitLimitFromICmp(L, Pred, LHS, RHS, ControlsOnlyExit,
9440 AllowPredicates);
9441 if (EL.hasAnyInfo())
9442 return EL;
9443
9444 auto *ExhaustiveCount =
9445 computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
9446
9447 if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
9448 return ExhaustiveCount;
9449
9450 return computeShiftCompareExitLimit(ExitCond->getOperand(0),
9451 ExitCond->getOperand(1), L, OriginalPred);
9452}
// NOTE(review): several lines of this function are elided in this rendering —
// the predicate swap after std::swap(LHS, RHS) (original line 9466), part of
// the ControllingFiniteLoop initializer (line 9470), the pointer-type handling
// bodies in the ICMP_NE/ICMP_EQ cases (lines 9534-9535, 9539-9540, 9552-9558),
// and lines 9589/9607. Code is left byte-identical; confirm against upstream
// before editing.
9453ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(
9454 const Loop *L, CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS,
9455 bool ControlsOnlyExit, bool AllowPredicates) {
9456
9457 // Try to evaluate any dependencies out of the loop.
9458 LHS = getSCEVAtScope(LHS, L);
9459 RHS = getSCEVAtScope(RHS, L);
9460
9461 // At this point, we would like to compute how many iterations of the
9462 // loop the predicate will return true for these inputs.
9463 if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
9464 // If there is a loop-invariant, force it into the RHS.
9465 std::swap(LHS, RHS);
9467 }
9468
9469 bool ControllingFiniteLoop = ControlsOnlyExit && loopHasNoAbnormalExits(L) &&
9471 // Simplify the operands before analyzing them.
9472 (void)SimplifyICmpOperands(Pred, LHS, RHS, /*Depth=*/0);
9473
9474 // If we have a comparison of a chrec against a constant, try to use value
9475 // ranges to answer this query.
9476 if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
9477 if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
9478 if (AddRec->getLoop() == L) {
9479 // Form the constant range.
9480 ConstantRange CompRange =
9481 ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
9482
9483 const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
9484 if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
9485 }
9486
9487 // If this loop must exit based on this condition (or execute undefined
9488 // behaviour), see if we can improve wrap flags. This is essentially
9489 // a must execute style proof.
9490 if (ControllingFiniteLoop && isLoopInvariant(RHS, L)) {
9491 // If we can prove the test sequence produced must repeat the same values
9492 // on self-wrap of the IV, then we can infer that IV doesn't self wrap
9493 // because if it did, we'd have an infinite (undefined) loop.
9494 // TODO: We can peel off any functions which are invertible *in L*. Loop
9495 // invariant terms are effectively constants for our purposes here.
9496 SCEVUse InnerLHS = LHS;
9497 if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS))
9498 InnerLHS = ZExt->getOperand();
9499 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(InnerLHS);
9500 AR && !AR->hasNoSelfWrap() && AR->getLoop() == L && AR->isAffine() &&
9501 isKnownToBeAPowerOfTwo(AR->getStepRecurrence(*this), /*OrZero=*/true,
9502 /*OrNegative=*/true)) {
9503 auto Flags = AR->getNoWrapFlags();
9504 Flags = setFlags(Flags, SCEV::FlagNW);
9505 SmallVector<SCEVUse> Operands{AR->operands()};
9506 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9507 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9508 }
9509
9510 // For a slt/ult condition with a positive step, can we prove nsw/nuw?
9511 // From no-self-wrap, this follows trivially from the fact that every
9512 // (un)signed-wrapped, but not self-wrapped value must be LT than the
9513 // last value before (un)signed wrap. Since we know that last value
9514 // didn't exit, nor will any smaller one.
9515 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT) {
9516 auto WrapType = Pred == ICmpInst::ICMP_SLT ? SCEV::FlagNSW : SCEV::FlagNUW;
9517 if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS);
9518 AR && AR->getLoop() == L && AR->isAffine() &&
9519 !AR->getNoWrapFlags(WrapType) && AR->hasNoSelfWrap() &&
9520 isKnownPositive(AR->getStepRecurrence(*this))) {
9521 auto Flags = AR->getNoWrapFlags();
9522 Flags = setFlags(Flags, WrapType);
9523 SmallVector<SCEVUse> Operands{AR->operands()};
9524 Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
9525 setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
9526 }
9527 }
9528 }
9529
// Dispatch on the (normalized) predicate to the specialized trip-count
// helpers.
9530 switch (Pred) {
9531 case ICmpInst::ICMP_NE: { // while (X != Y)
9532 // Convert to: while (X-Y != 0)
9533 if (LHS->getType()->isPointerTy()) {
9536 return LHS;
9537 }
9538 if (RHS->getType()->isPointerTy()) {
9541 return RHS;
9542 }
9543 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit,
9544 AllowPredicates);
9545 if (EL.hasAnyInfo())
9546 return EL;
9547 break;
9548 }
9549 case ICmpInst::ICMP_EQ: { // while (X == Y)
9550 // Convert to: while (X-Y == 0)
9551 if (LHS->getType()->isPointerTy()) {
9554 return LHS;
9555 }
9556 if (RHS->getType()->isPointerTy()) {
9559 return RHS;
9560 }
9561 ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L);
9562 if (EL.hasAnyInfo()) return EL;
9563 break;
9564 }
9565 case ICmpInst::ICMP_SLE:
9566 case ICmpInst::ICMP_ULE:
9567 // Since the loop is finite, an invariant RHS cannot include the boundary
9568 // value, otherwise it would loop forever.
9569 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9570 !isLoopInvariant(RHS, L)) {
9571 // Otherwise, perform the addition in a wider type, to avoid overflow.
9572 // If the LHS is an addrec with the appropriate nowrap flag, the
9573 // extension will be sunk into it and the exit count can be analyzed.
9574 auto *OldType = dyn_cast<IntegerType>(LHS->getType());
9575 if (!OldType)
9576 break;
9577 // Prefer doubling the bitwidth over adding a single bit to make it more
9578 // likely that we use a legal type.
9579 auto *NewType =
9580 Type::getIntNTy(OldType->getContext(), OldType->getBitWidth() * 2);
9581 if (ICmpInst::isSigned(Pred)) {
9582 LHS = getSignExtendExpr(LHS, NewType);
9583 RHS = getSignExtendExpr(RHS, NewType);
9584 } else {
9585 LHS = getZeroExtendExpr(LHS, NewType);
9586 RHS = getZeroExtendExpr(RHS, NewType);
9587 }
9588 }
9590 [[fallthrough]];
9591 case ICmpInst::ICMP_SLT:
9592 case ICmpInst::ICMP_ULT: { // while (X < Y)
9593 bool IsSigned = ICmpInst::isSigned(Pred);
9594 ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9595 AllowPredicates);
9596 if (EL.hasAnyInfo())
9597 return EL;
9598 break;
9599 }
9600 case ICmpInst::ICMP_SGE:
9601 case ICmpInst::ICMP_UGE:
9602 // Since the loop is finite, an invariant RHS cannot include the boundary
9603 // value, otherwise it would loop forever.
9604 if (!EnableFiniteLoopControl || !ControllingFiniteLoop ||
9605 !isLoopInvariant(RHS, L))
9606 break;
9608 [[fallthrough]];
9609 case ICmpInst::ICMP_SGT:
9610 case ICmpInst::ICMP_UGT: { // while (X > Y)
9611 bool IsSigned = ICmpInst::isSigned(Pred);
9612 ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsOnlyExit,
9613 AllowPredicates);
9614 if (EL.hasAnyInfo())
9615 return EL;
9616 break;
9617 }
9618 default:
9619 break;
9620 }
9621
9622 return getCouldNotCompute();
9623}
9624
9625ScalarEvolution::ExitLimit
9626ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L,
9627 SwitchInst *Switch,
9628 BasicBlock *ExitingBlock,
9629 bool ControlsOnlyExit) {
9630 assert(!L->contains(ExitingBlock) && "Not an exiting block!");
9631
9632 // Give up if the exit is the default dest of a switch.
9633 if (Switch->getDefaultDest() == ExitingBlock)
9634 return getCouldNotCompute();
9635
9636 assert(L->contains(Switch->getDefaultDest()) &&
9637 "Default case must not exit the loop!");
9638 const SCEV *LHS = getSCEVAtScope(Switch->getCondition(), L);
9639 const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock));
9640
9641 // while (X != Y) --> while (X-Y != 0)
9642 ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsOnlyExit);
9643 if (EL.hasAnyInfo())
9644 return EL;
9645
9646 return getCouldNotCompute();
9647}
9648
// NOTE(review): this is EvaluateConstantChrecAtConstant; its first parameter
// line (original line 9650) and the assert condition line (line 9654) are
// elided in this rendering. It folds a constant add-recurrence at a constant
// iteration number and returns the resulting ConstantInt.
9649static ConstantInt *
9651 ScalarEvolution &SE) {
9652 const SCEV *InVal = SE.getConstant(C);
9653 const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
9655 "Evaluation of SCEV at constant didn't fold correctly?");
9656 return cast<SCEVConstant>(Val)->getValue();
9657}
9658
// NOTE(review): a few declaration lines are elided in this rendering — the
// shift-opcode locals (original lines 9706 and 9745, presumably
// "Instruction::BinaryOps OpC;" / "OpCode;") and part of the UpperBound
// computation (line 9795). Code is left byte-identical; confirm against
// upstream before editing.
9659ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(
9660 Value *LHS, Value *RHSV, const Loop *L, ICmpInst::Predicate Pred) {
9661 ConstantInt *RHS = dyn_cast<ConstantInt>(RHSV);
9662 if (!RHS)
9663 return getCouldNotCompute();
9664
9665 const BasicBlock *Latch = L->getLoopLatch();
9666 if (!Latch)
9667 return getCouldNotCompute();
9668
9669 const BasicBlock *Predecessor = L->getLoopPredecessor();
9670 if (!Predecessor)
9671 return getCouldNotCompute();
9672
9673 // Return true if V is of the form "LHS `shift_op` <positive constant>".
9674 // Return LHS in OutLHS and shift_opt in OutOpCode.
9675 auto MatchPositiveShift =
9676 [](Value *V, Value *&OutLHS, Instruction::BinaryOps &OutOpCode) {
9677
9678 using namespace PatternMatch;
9679
9680 ConstantInt *ShiftAmt;
9681 if (match(V, m_LShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9682 OutOpCode = Instruction::LShr;
9683 else if (match(V, m_AShr(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9684 OutOpCode = Instruction::AShr;
9685 else if (match(V, m_Shl(m_Value(OutLHS), m_ConstantInt(ShiftAmt))))
9686 OutOpCode = Instruction::Shl;
9687 else
9688 return false;
9689
9690 return ShiftAmt->getValue().isStrictlyPositive();
9691 };
9692
9693 // Recognize a "shift recurrence" either of the form %iv or of %iv.shifted in
9694 //
9695 // loop:
9696 // %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]
9697 // %iv.shifted = lshr i32 %iv, <positive constant>
9698 //
9699 // Return true on a successful match. Return the corresponding PHI node (%iv
9700 // above) in PNOut and the opcode of the shift operation in OpCodeOut.
9701 auto MatchShiftRecurrence =
9702 [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) {
9703 std::optional<Instruction::BinaryOps> PostShiftOpCode;
9704
9705 {
9707 Value *V;
9708
9709 // If we encounter a shift instruction, "peel off" the shift operation,
9710 // and remember that we did so. Later when we inspect %iv's backedge
9711 // value, we will make sure that the backedge value uses the same
9712 // operation.
9713 //
9714 // Note: the peeled shift operation does not have to be the same
9715 // instruction as the one feeding into the PHI's backedge value. We only
9716 // really care about it being the same *kind* of shift instruction --
9717 // that's all that is required for our later inferences to hold.
9718 if (MatchPositiveShift(LHS, V, OpC)) {
9719 PostShiftOpCode = OpC;
9720 LHS = V;
9721 }
9722 }
9723
9724 PNOut = dyn_cast<PHINode>(LHS);
9725 if (!PNOut || PNOut->getParent() != L->getHeader())
9726 return false;
9727
9728 Value *BEValue = PNOut->getIncomingValueForBlock(Latch);
9729 Value *OpLHS;
9730
9731 return
9732 // The backedge value for the PHI node must be a shift by a positive
9733 // amount
9734 MatchPositiveShift(BEValue, OpLHS, OpCodeOut) &&
9735
9736 // of the PHI node itself
9737 OpLHS == PNOut &&
9738
9739 // and the kind of shift should be match the kind of shift we peeled
9740 // off, if any.
9741 (!PostShiftOpCode || *PostShiftOpCode == OpCodeOut);
9742 };
9743
9744 PHINode *PN;
9746 if (!MatchShiftRecurrence(LHS, PN, OpCode))
9747 return getCouldNotCompute();
9748
9749 const DataLayout &DL = getDataLayout();
9750
9751 // The key rationale for this optimization is that for some kinds of shift
9752 // recurrences, the value of the recurrence "stabilizes" to either 0 or -1
9753 // within a finite number of iterations. If the condition guarding the
9754 // backedge (in the sense that the backedge is taken if the condition is true)
9755 // is false for the value the shift recurrence stabilizes to, then we know
9756 // that the backedge is taken only a finite number of times.
9757
9758 ConstantInt *StableValue = nullptr;
9759 switch (OpCode) {
9760 default:
9761 llvm_unreachable("Impossible case!");
9762
9763 case Instruction::AShr: {
9764 // {K,ashr,<positive-constant>} stabilizes to signum(K) in at most
9765 // bitwidth(K) iterations.
9766 Value *FirstValue = PN->getIncomingValueForBlock(Predecessor);
9767 KnownBits Known = computeKnownBits(FirstValue, DL, &AC,
9768 Predecessor->getTerminator(), &DT);
9769 auto *Ty = cast<IntegerType>(RHS->getType());
9770 if (Known.isNonNegative())
9771 StableValue = ConstantInt::get(Ty, 0);
9772 else if (Known.isNegative())
9773 StableValue = ConstantInt::get(Ty, -1, true);
9774 else
9775 return getCouldNotCompute();
9776
9777 break;
9778 }
9779 case Instruction::LShr:
9780 case Instruction::Shl:
9781 // Both {K,lshr,<positive-constant>} and {K,shl,<positive-constant>}
9782 // stabilize to 0 in at most bitwidth(K) iterations.
9783 StableValue = ConstantInt::get(cast<IntegerType>(RHS->getType()), 0);
9784 break;
9785 }
9786
// Fold the comparison against the stabilized value; if it is false, the
// backedge can only be taken a bounded number of times.
9787 auto *Result =
9788 ConstantFoldCompareInstOperands(Pred, StableValue, RHS, DL, &TLI);
9789 assert(Result->getType()->isIntegerTy(1) &&
9790 "Otherwise cannot be an operand to a branch instruction");
9791
9792 if (Result->isNullValue()) {
9793 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
9794 const SCEV *UpperBound =
9796 return ExitLimit(getCouldNotCompute(), UpperBound, UpperBound, false);
9797 }
9798
9799 return getCouldNotCompute();
9800}
9801
9802/// Return true if we can constant fold an instruction of the specified type,
9803/// assuming that all operands were constants.
// NOTE(review): the leading isa<> checks of this function (original lines
// 9805-9807, presumably the list of trivially-foldable instruction kinds) are
// elided in this rendering; the bare "return true;" below is their body.
9804static bool CanConstantFold(const Instruction *I) {
9808 return true;
9809
// Calls are foldable only when the callee is a known constant-foldable
// function.
9810 if (const CallInst *CI = dyn_cast<CallInst>(I))
9811 if (const Function *F = CI->getCalledFunction())
9812 return canConstantFoldCallTo(CI, F);
9813 return false;
9814}
9815
9816/// Determine whether this instruction can constant evolve within this loop
9817/// assuming its operands can all constant evolve.
9818static bool canConstantEvolve(Instruction *I, const Loop *L) {
9819 // An instruction outside of the loop can't be derived from a loop PHI.
9820 if (!L->contains(I)) return false;
9821
9822 if (isa<PHINode>(I)) {
9823 // We don't currently keep track of the control flow needed to evaluate
9824 // PHIs, so we cannot handle PHIs inside of loops.
9825 return L->getHeader() == I->getParent();
9826 }
9827
9828 // If we won't be able to constant fold this expression even if the operands
9829 // are constants, bail early.
9830 return CanConstantFold(I);
9831}
9832
/// getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by
/// recursing through each instruction operand until reaching a loop header phi.
///
/// Returns the unique loop-header PHI that every non-constant operand of
/// UseInst ultimately evolves from, or null if the operands evolve from
/// multiple PHIs (or from none).
/// NOTE(review): the leading parameter lines (UseInst, L, PHIMap) and the
/// recursion-depth guard are elided in this excerpt — confirm against the
/// full source.
static PHINode *
                               unsigned Depth) {
    return nullptr;

  // Otherwise, we can evaluate this instruction if all of its operands are
  // constant or derived from a PHI node themselves.
  PHINode *PHI = nullptr;
  for (Value *Op : UseInst->operands()) {
    if (isa<Constant>(Op)) continue;

    // Non-constant operands must be loop-resident instructions that can
    // themselves constant-evolve.
    if (!OpInst || !canConstantEvolve(OpInst, L)) return nullptr;

    PHINode *P = dyn_cast<PHINode>(OpInst);
    if (!P)
      // If this operand is already visited, reuse the prior result.
      // We may have P != PHI if this is the deepest point at which the
      // inconsistent paths meet.
      P = PHIMap.lookup(OpInst);
    if (!P) {
      // Recurse and memoize the results, whether a phi is found or not.
      // This recursive call invalidates pointers into PHIMap.
      P = getConstantEvolvingPHIOperands(OpInst, L, PHIMap, Depth + 1);
      PHIMap[OpInst] = P;
    }
    if (!P)
      return nullptr; // Not evolving from PHI
    if (PHI && PHI != P)
      return nullptr; // Evolving from multiple different PHIs.
    PHI = P;
  }
  // This is a expression evolving from a constant PHI!
  return PHI;
}
9872
/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
/// in the loop that V is derived from. We allow arbitrary operations along the
/// way, but the operands of an operation must either be constants or a value
/// derived from a constant PHI. If this expression does not fit with these
/// constraints, return null.
/// NOTE(review): the signature line and the dyn_cast producing 'I' are elided
/// in this excerpt — confirm against the full source.
  if (!I || !canConstantEvolve(I, L)) return nullptr;

  // A header PHI is itself the answer.
  if (PHINode *PN = dyn_cast<PHINode>(I))
    return PN;

  // Record non-constant instructions contained by the loop.
  return getConstantEvolvingPHIOperands(I, L, PHIMap, 0);
}
9889
/// EvaluateExpression - Given an expression that passes the
/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
/// in the loop has the value PHIVal. If we can't fold this expression for some
/// reason, return null.
/// NOTE(review): the leading parameters (the Value being evaluated, the loop,
/// and the Vals memoization map) are elided in this excerpt — confirm against
/// the full source.
                                    const DataLayout &DL,
                                    const TargetLibraryInfo *TLI) {
  // Convenient constant check, but redundant for recursive calls.
  if (Constant *C = dyn_cast<Constant>(V)) return C;
  if (!I) return nullptr;

  // Memoized result from an earlier visit (or a seeded header-PHI value).
  if (Constant *C = Vals.lookup(I)) return C;

  // An instruction inside the loop depends on a value outside the loop that we
  // weren't given a mapping for, or a value such as a call inside the loop.
  if (!canConstantEvolve(I, L)) return nullptr;

  // An unmapped PHI can be due to a branch or another loop inside this loop,
  // or due to this not being the initial iteration through a loop where we
  // couldn't compute the evolution of this particular PHI last time.
  if (isa<PHINode>(I)) return nullptr;

  std::vector<Constant*> Operands(I->getNumOperands());

  // Recursively reduce every operand to a Constant; fail if any operand
  // cannot be reduced.
  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
    Instruction *Operand = dyn_cast<Instruction>(I->getOperand(i));
    if (!Operand) {
      // Non-instruction operands must already be constants.
      Operands[i] = dyn_cast<Constant>(I->getOperand(i));
      if (!Operands[i]) return nullptr;
      continue;
    }
    Constant *C = EvaluateExpression(Operand, L, Vals, DL, TLI);
    // Memoize even a null result so repeated uses fail fast.
    Vals[Operand] = C;
    if (!C) return nullptr;
    Operands[i] = C;
  }

  // Fold with all-constant operands; refuse folds that would fix a value for
  // a nondeterministic result.
  return ConstantFoldInstOperands(I, Operands, DL, TLI,
                                  /*AllowNonDeterministic=*/false);
}
9932
9933
// If every incoming value to PN except the one for BB is a specific Constant,
// return that, else return nullptr.
// NOTE(review): the signature line (taking PHINode *PN and BasicBlock *BB) is
// elided in this excerpt — confirm against the full source.
  Constant *IncomingVal = nullptr;

  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
    // Skip the edge we were asked to ignore (typically the loop latch).
    if (PN->getIncomingBlock(i) == BB)
      continue;

    auto *CurrentVal = dyn_cast<Constant>(PN->getIncomingValue(i));
    if (!CurrentVal)
      return nullptr;

    // All remaining edges must agree on a single constant value.
    if (IncomingVal != CurrentVal) {
      if (IncomingVal)
        return nullptr;
      IncomingVal = CurrentVal;
    }
  }

  return IncomingVal;
}
9956
/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
/// in the header of its containing loop, we know the loop executes a
/// constant number of times, and the PHI node is just a recurrence
/// involving constants, fold it.
Constant *
ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
                                                   const APInt &BEs,
                                                   const Loop *L) {
  // Check the per-PHI cache first; try_emplace also reserves the slot that is
  // filled in below through RetVal.
  auto [I, Inserted] = ConstantEvolutionLoopExitValue.try_emplace(PN);
  if (!Inserted)
    return I->second;

  // NOTE(review): the guard preceding this return (presumably an
  // iteration-budget check on BEs) is elided in this excerpt — confirm
  // against the full source.
    return nullptr; // Not going to evaluate it.

  Constant *&RetVal = I->second;

  DenseMap<Instruction *, Constant *> CurrentIterVals;
  BasicBlock *Header = L->getHeader();
  assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");

  // A unique latch is required so "the backedge value" is well defined.
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return nullptr;

  // Seed every header PHI whose non-latch incoming value is a single constant.
  for (PHINode &PHI : Header->phis()) {
    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
      CurrentIterVals[&PHI] = StartCST;
  }
  if (!CurrentIterVals.count(PN))
    return RetVal = nullptr;

  Value *BEValue = PN->getIncomingValueForBlock(Latch);

  // Execute the loop symbolically to determine the exit value.
  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");

  unsigned NumIterations = BEs.getZExtValue(); // must be in range
  unsigned IterationNum = 0;
  const DataLayout &DL = getDataLayout();
  for (; ; ++IterationNum) {
    if (IterationNum == NumIterations)
      return RetVal = CurrentIterVals[PN]; // Got exit value!

    // Compute the value of the PHIs for the next iteration.
    // EvaluateExpression adds non-phi values to the CurrentIterVals map.
    DenseMap<Instruction *, Constant *> NextIterVals;
    Constant *NextPHI =
        EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
    if (!NextPHI)
      return nullptr; // Couldn't evaluate!
    NextIterVals[PN] = NextPHI;

    bool StoppedEvolving = NextPHI == CurrentIterVals[PN];

    // Also evaluate the other PHI nodes. However, we don't get to stop if we
    // cease to be able to evaluate one of them or if they stop evolving,
    // because that doesn't necessarily prevent us from computing PN.
    for (const auto &I : CurrentIterVals) {
      PHINode *PHI = dyn_cast<PHINode>(I.first);
      if (!PHI || PHI == PN || PHI->getParent() != Header) continue;
      PHIsToCompute.emplace_back(PHI, I.second);
    }
    // We use two distinct loops because EvaluateExpression may invalidate any
    // iterators into CurrentIterVals.
    for (const auto &I : PHIsToCompute) {
      PHINode *PHI = I.first;
      Constant *&NextPHI = NextIterVals[PHI];
      if (!NextPHI) { // Not already computed.
        Value *BEValue = PHI->getIncomingValueForBlock(Latch);
        NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
      }
      if (NextPHI != I.second)
        StoppedEvolving = false;
    }

    // If all entries in CurrentIterVals == NextIterVals then we can stop
    // iterating, the loop can't continue to change.
    if (StoppedEvolving)
      return RetVal = CurrentIterVals[PN];

    CurrentIterVals.swap(NextIterVals);
  }
}
10043
10044const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
10045 Value *Cond,
10046 bool ExitWhen) {
10047 PHINode *PN = getConstantEvolvingPHI(Cond, L);
10048 if (!PN) return getCouldNotCompute();
10049
10050 // If the loop is canonicalized, the PHI will have exactly two entries.
10051 // That's the only form we support here.
10052 if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
10053
10054 DenseMap<Instruction *, Constant *> CurrentIterVals;
10055 BasicBlock *Header = L->getHeader();
10056 assert(PN->getParent() == Header && "Can't evaluate PHI not in loop header!");
10057
10058 BasicBlock *Latch = L->getLoopLatch();
10059 assert(Latch && "Should follow from NumIncomingValues == 2!");
10060
10061 for (PHINode &PHI : Header->phis()) {
10062 if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
10063 CurrentIterVals[&PHI] = StartCST;
10064 }
10065 if (!CurrentIterVals.count(PN))
10066 return getCouldNotCompute();
10067
10068 // Okay, we find a PHI node that defines the trip count of this loop. Execute
10069 // the loop symbolically to determine when the condition gets a value of
10070 // "ExitWhen".
10071 unsigned MaxIterations = MaxBruteForceIterations; // Limit analysis.
10072 const DataLayout &DL = getDataLayout();
10073 for (unsigned IterationNum = 0; IterationNum != MaxIterations;++IterationNum){
10074 auto *CondVal = dyn_cast_or_null<ConstantInt>(
10075 EvaluateExpression(Cond, L, CurrentIterVals, DL, &TLI));
10076
10077 // Couldn't symbolically evaluate.
10078 if (!CondVal) return getCouldNotCompute();
10079
10080 if (CondVal->getValue() == uint64_t(ExitWhen)) {
10081 ++NumBruteForceTripCountsComputed;
10082 return getConstant(Type::getInt32Ty(getContext()), IterationNum);
10083 }
10084
10085 // Update all the PHI nodes for the next iteration.
10086 DenseMap<Instruction *, Constant *> NextIterVals;
10087
10088 // Create a list of which PHIs we need to compute. We want to do this before
10089 // calling EvaluateExpression on them because that may invalidate iterators
10090 // into CurrentIterVals.
10091 SmallVector<PHINode *, 8> PHIsToCompute;
10092 for (const auto &I : CurrentIterVals) {
10093 PHINode *PHI = dyn_cast<PHINode>(I.first);
10094 if (!PHI || PHI->getParent() != Header) continue;
10095 PHIsToCompute.push_back(PHI);
10096 }
10097 for (PHINode *PHI : PHIsToCompute) {
10098 Constant *&NextPHI = NextIterVals[PHI];
10099 if (NextPHI) continue; // Already computed!
10100
10101 Value *BEValue = PHI->getIncomingValueForBlock(Latch);
10102 NextPHI = EvaluateExpression(BEValue, L, CurrentIterVals, DL, &TLI);
10103 }
10104 CurrentIterVals.swap(NextIterVals);
10105 }
10106
10107 // Too many iterations were needed to evaluate.
10108 return getCouldNotCompute();
10109}
10110
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
  // NOTE(review): the declaration binding 'Values' to the cache entry for V
  // is elided in this excerpt — confirm against the full source.
      ValuesAtScopes[V];
  // Check to see if we've folded this expression at this loop before.
  for (auto &LS : Values)
    if (LS.first == L)
      // A cached null second means "no simplification found"; return V.
      return LS.second ? LS.second : V;

  // Reserve the slot before computing so recursive queries for (V, L)
  // terminate instead of recursing forever.
  Values.emplace_back(L, nullptr);

  // Otherwise compute it.
  const SCEV *C = computeSCEVAtScope(V, L);
  // computeSCEVAtScope may have grown/reallocated the cache vector, so
  // re-look it up and search from the back, where our entry was pushed.
  for (auto &LS : reverse(ValuesAtScopes[V]))
    if (LS.first == L) {
      LS.second = C;
      // Record the reverse mapping so invalidating C can purge this entry.
      if (!isa<SCEVConstant>(C))
        ValuesAtScopesUsers[C].push_back({L, V});
      break;
    }
  return C;
}
10132
/// This builds up a Constant using the ConstantExpr interface. That way, we
/// will return Constants for objects which aren't represented by a
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant.
/// NOTE(review): the function signature and several dyn_cast lines binding
/// P2I/ST/OpC are elided in this excerpt — confirm against the full source.
  switch (V->getSCEVType()) {
  case scCouldNotCompute:
  case scAddRecExpr:
  case scVScale:
    // Recurrences and vscale have no ConstantExpr representation.
    return nullptr;
  case scConstant:
    return cast<SCEVConstant>(V)->getValue();
  case scUnknown:
    // Only unknowns wrapping an IR Constant can be materialized.
    return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
  case scPtrToAddr: {
    if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
      return ConstantExpr::getPtrToAddr(CastOp, P2I->getType());

    return nullptr;
  }
  case scPtrToInt: {
    if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
      return ConstantExpr::getPtrToInt(CastOp, P2I->getType());

    return nullptr;
  }
  case scTruncate: {
    if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
      return ConstantExpr::getTrunc(CastOp, ST->getType());
    return nullptr;
  }
  case scAddExpr: {
    const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
    // Fold the operands left-to-right; every operand must itself be
    // representable as a Constant.
    Constant *C = nullptr;
    for (const SCEV *Op : SA->operands()) {
      if (!OpC)
        return nullptr;
      if (!C) {
        C = OpC;
        continue;
      }
      assert(!C->getType()->isPointerTy() &&
             "Can only have one pointer, and it must be last");
      if (OpC->getType()->isPointerTy()) {
        // The offsets have been converted to bytes. We can add bytes using
        // an i8 GEP.
        C = ConstantExpr::getPtrAdd(OpC, C);
      } else {
        C = ConstantExpr::getAdd(C, OpC);
      }
    }
    return C;
  }
  case scMulExpr:
  case scSignExtend:
  case scZeroExtend:
  case scUDivExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr:
    // These kinds have no (or no longer a) ConstantExpr form.
    return nullptr;
  }
  llvm_unreachable("Unknown SCEV kind!");
}
10203
/// Rebuild S with NewOps substituted for its operands, dispatching on the
/// expression kind so flags (no-wrap, loop) are preserved.
const SCEV *ScalarEvolution::getWithOperands(const SCEV *S,
                                             SmallVectorImpl<SCEVUse> &NewOps) {
  switch (S->getSCEVType()) {
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
    // Unary casts: one operand, same destination type as S.
    return getCastExpr(S->getSCEVType(), NewOps[0], S->getType());
  case scAddRecExpr: {
    auto *AddRec = cast<SCEVAddRecExpr>(S);
    return getAddRecExpr(NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags());
  }
  case scAddExpr:
    return getAddExpr(NewOps, cast<SCEVAddExpr>(S)->getNoWrapFlags());
  case scMulExpr:
    return getMulExpr(NewOps, cast<SCEVMulExpr>(S)->getNoWrapFlags());
  case scUDivExpr:
    return getUDivExpr(NewOps[0], NewOps[1]);
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return getMinMaxExpr(S->getSCEVType(), NewOps);
    // NOTE(review): the scSequentialUMinExpr case label for the return below
    // is elided in this excerpt — confirm against the full source.
    return getSequentialMinMaxExpr(S->getSCEVType(), NewOps);
  case scConstant:
  case scVScale:
  case scUnknown:
    // Leaves have no operands to substitute.
    return S;
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}
10239
/// Compute the value of V at loop scope L: fold loop-variant pieces to their
/// values on exit from inner loops where possible. Returns V unchanged when
/// no simplification is found.
/// NOTE(review): a few declaration lines (the NewOps vectors, the dyn_cast
/// producing 'I', and the BuildConstantFromSCEV call binding 'C') are elided
/// in this excerpt — confirm against the full source.
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
  switch (V->getSCEVType()) {
  case scConstant:
  case scVScale:
    // Already loop-invariant by construction.
    return V;
  case scAddRecExpr: {
    // If this is a loop recurrence for a loop that does not contain L, then we
    // are dealing with the final value computed by the loop.
    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(V);
    // First, attempt to evaluate each operand.
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
      if (OpAtScope == AddRec->getOperand(i))
        continue;

      // Okay, at least one of these operands is loop variant but might be
      // foldable. Build a new instance of the folded commutative expression.
      NewOps.reserve(AddRec->getNumOperands());
      append_range(NewOps, AddRec->operands().take_front(i));
      NewOps.push_back(OpAtScope);
      for (++i; i != e; ++i)
        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));

      const SCEV *FoldedRec = getAddRecExpr(
          NewOps, AddRec->getLoop(), AddRec->getNoWrapFlags(SCEV::FlagNW));
      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
      // The addrec may be folded to a nonrecurrence, for example, if the
      // induction variable is multiplied by zero after constant folding. Go
      // ahead and return the folded value.
      if (!AddRec)
        return FoldedRec;
      break;
    }

    // If the scope is outside the addrec's loop, evaluate it by using the
    // loop exit value of the addrec.
    if (!AddRec->getLoop()->contains(L)) {
      // To evaluate this recurrence, we need to know how many times the AddRec
      // loop iterates. Compute this now.
      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
      if (BackedgeTakenCount == getCouldNotCompute())
        return AddRec;

      // Then, evaluate the AddRec.
      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
    }

    return AddRec;
  }
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
  case scPtrToAddr:
  case scPtrToInt:
  case scAddExpr:
  case scMulExpr:
  case scUDivExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
  case scSequentialUMinExpr: {
    // N-ary case: fold each operand at scope L and rebuild only if something
    // changed.
    ArrayRef<SCEVUse> Ops = V->operands();
    // Avoid performing the look-up in the common case where the specified
    // expression has no loop-variant portions.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      const SCEV *OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
      if (OpAtScope != Ops[i].getPointer()) {
        // Okay, at least one of these operands is loop variant but might be
        // foldable. Build a new instance of the folded commutative expression.
        NewOps.reserve(Ops.size());
        append_range(NewOps, Ops.take_front(i));
        NewOps.push_back(OpAtScope);

        for (++i; i != e; ++i) {
          OpAtScope = getSCEVAtScope(Ops[i].getPointer(), L);
          NewOps.push_back(OpAtScope);
        }

        return getWithOperands(V, NewOps);
      }
    }
    // If we got here, all operands are loop invariant.
    return V;
  }
  case scUnknown: {
    // If this instruction is evolved from a constant-evolving PHI, compute the
    // exit value from the loop without using SCEVs.
    const SCEVUnknown *SU = cast<SCEVUnknown>(V);
    if (!I)
      return V; // This is some other type of SCEVUnknown, just return it.

    if (PHINode *PN = dyn_cast<PHINode>(I)) {
      const Loop *CurrLoop = this->LI[I->getParent()];
      // Looking for loop exit value.
      if (CurrLoop && CurrLoop->getParentLoop() == L &&
          PN->getParent() == CurrLoop->getHeader()) {
        // Okay, there is no closed form solution for the PHI node. Check
        // to see if the loop that contains it has a known backedge-taken
        // count. If so, we may be able to force computation of the exit
        // value.
        const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
        // This trivial case can show up in some degenerate cases where
        // the incoming IR has not yet been fully simplified.
        if (BackedgeTakenCount->isZero()) {
          // Zero backedges: the exit value is just the (unique) initial value.
          Value *InitValue = nullptr;
          bool MultipleInitValues = false;
          for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
            if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
              if (!InitValue)
                InitValue = PN->getIncomingValue(i);
              else if (InitValue != PN->getIncomingValue(i)) {
                MultipleInitValues = true;
                break;
              }
            }
          }
          if (!MultipleInitValues && InitValue)
            return getSCEV(InitValue);
        }
        // Do we have a loop invariant value flowing around the backedge
        // for a loop which must execute the backedge?
        if (!isa<SCEVCouldNotCompute>(BackedgeTakenCount) &&
            isKnownNonZero(BackedgeTakenCount) &&
            PN->getNumIncomingValues() == 2) {

          unsigned InLoopPred =
              CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
          Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
          if (CurrLoop->isLoopInvariant(BackedgeVal))
            return getSCEV(BackedgeVal);
        }
        if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
          // Okay, we know how many times the containing loop executes. If
          // this is a constant evolving PHI node, get the final value at
          // the specified iteration number.
          Constant *RV =
              getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), CurrLoop);
          if (RV)
            return getSCEV(RV);
        }
      }
    }

    // Okay, this is an expression that we cannot symbolically evaluate
    // into a SCEV. Check to see if it's possible to symbolically evaluate
    // the arguments into constants, and if so, try to constant propagate the
    // result. This is particularly useful for computing loop exit values.
    if (!CanConstantFold(I))
      return V; // This is some other type of SCEVUnknown, just return it.

    SmallVector<Constant *, 4> Operands;
    Operands.reserve(I->getNumOperands());
    bool MadeImprovement = false;
    for (Value *Op : I->operands()) {
      if (Constant *C = dyn_cast<Constant>(Op)) {
        Operands.push_back(C);
        continue;
      }

      // If any of the operands is non-constant and if they are
      // non-integer and non-pointer, don't even try to analyze them
      // with scev techniques.
      if (!isSCEVable(Op->getType()))
        return V;

      const SCEV *OrigV = getSCEV(Op);
      const SCEV *OpV = getSCEVAtScope(OrigV, L);
      MadeImprovement |= OrigV != OpV;

      if (!C)
        return V;
      assert(C->getType() == Op->getType() && "Type mismatch");
      Operands.push_back(C);
    }

    // Check to see if getSCEVAtScope actually made an improvement.
    if (!MadeImprovement)
      return V; // This is some other type of SCEVUnknown, just return it.

    Constant *C = nullptr;
    const DataLayout &DL = getDataLayout();
    C = ConstantFoldInstOperands(I, Operands, DL, &TLI,
                                 /*AllowNonDeterministic=*/false);
    if (!C)
      return V;
    return getSCEV(C);
  }
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV type!");
}
10439
// Convenience overload: analyze V to a SCEV first, then evaluate at scope L.
// NOTE(review): the signature line is elided in this excerpt — confirm
// against the full source.
  return getSCEVAtScope(getSCEV(V), L);
}
10443
const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
  // Recursively peel zero/sign-extension wrappers: they are injective, so an
  // equation solved on the stripped operand has the same solution set.
  // NOTE(review): the dyn_cast guards introducing ZExt/SExt for the two
  // returns below are elided in this excerpt — confirm against the full
  // source.
    return stripInjectiveFunctions(ZExt->getOperand());
    return stripInjectiveFunctions(SExt->getOperand());
  return S;
}
10451
/// Finds the minimum unsigned root of the following equation:
///
/// A * X = B (mod N)
///
/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
/// A and B isn't important.
///
/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
/// NOTE(review): the parameter lines (the APInt A, the SCEV B, and an
/// optional predicate output vector) are elided in this excerpt — confirm
/// against the full source.
static const SCEV *
                             ScalarEvolution &SE, const Loop *L) {
  uint32_t BW = A.getBitWidth();
  assert(BW == SE.getTypeSizeInBits(B->getType()));
  assert(A != 0 && "A must be non-zero.");

  // 1. D = gcd(A, N)
  //
  // The gcd of A and N may have only one prime factor: 2. The number of
  // trailing zeros in A is its multiplicity
  uint32_t Mult2 = A.countr_zero();
  // D = 2^Mult2

  // 2. Check if B is divisible by D.
  //
  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
  // is not less than multiplicity of this prime factor for D.
  unsigned MinTZ = SE.getMinTrailingZeros(B);
  // Try again with the terminator of the loop predecessor for context-specific
  // result, if MinTZ is too small.
  if (MinTZ < Mult2 && L->getLoopPredecessor())
    MinTZ = SE.getMinTrailingZeros(B, L->getLoopPredecessor()->getTerminator());
  if (MinTZ < Mult2) {
    // Check if we can prove there's no remainder using URem.
    const SCEV *URem =
        SE.getURemExpr(B, SE.getConstant(APInt::getOneBitSet(BW, Mult2)));
    const SCEV *Zero = SE.getZero(B->getType());
    if (!SE.isKnownPredicate(CmpInst::ICMP_EQ, URem, Zero)) {
      // Try to add a predicate ensuring B is a multiple of 1 << Mult2.
      if (!Predicates)
        return SE.getCouldNotCompute();

      // Avoid adding a predicate that is known to be false.
      if (SE.isKnownPredicate(CmpInst::ICMP_NE, URem, Zero))
        return SE.getCouldNotCompute();
      Predicates->push_back(SE.getEqualPredicate(URem, Zero));
    }
  }

  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
  // modulo (N / D).
  //
  // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
  // (N / D) in general. The inverse itself always fits into BW bits, though,
  // so we immediately truncate it.
  APInt AD = A.lshr(Mult2).trunc(BW - Mult2); // AD = A / D
  APInt I = AD.multiplicativeInverse().zext(BW);

  // 4. Compute the minimum unsigned root of the equation:
  // I * (B / D) mod (N / D)
  // To simplify the computation, we factor out the divide by D:
  // (I * B mod N) / D
  const SCEV *D = SE.getConstant(APInt::getOneBitSet(BW, Mult2));
  return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}
10517
/// For a given quadratic addrec, generate coefficients of the corresponding
/// quadratic equation, multiplied by a common value to ensure that they are
/// integers.
/// The returned value is a tuple { A, B, C, M, BitWidth }, where
/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
/// were multiplied by, and BitWidth is the bit width of the original addrec
/// coefficients.
/// This function returns std::nullopt if the addrec coefficients are not
/// compile- time constants.
/// NOTE(review): the signature line (taking the SCEVAddRecExpr *AddRec) is
/// elided in this excerpt — confirm against the full source.
static std::optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
  LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
                    << *AddRec << '\n');

  // We currently can only solve this if the coefficients are constants.
  if (!LC || !MC || !NC) {
    LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
    return std::nullopt;
  }

  APInt L = LC->getAPInt();
  APInt M = MC->getAPInt();
  APInt N = NC->getAPInt();
  assert(!N.isZero() && "This is not a quadratic addrec");

  unsigned BitWidth = LC->getAPInt().getBitWidth();
  unsigned NewWidth = BitWidth + 1;
  LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
                    << BitWidth << '\n');
  // The sign-extension (as opposed to a zero-extension) here matches the
  // extension used in SolveQuadraticEquationWrap (with the same motivation).
  N = N.sext(NewWidth);
  M = M.sext(NewWidth);
  L = L.sext(NewWidth);

  // The increments are M, M+N, M+2N, ..., so the accumulated values are
  // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
  // L+M, L+2M+N, L+3M+3N, ...
  // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
  //
  // The equation Acc = 0 is then
  // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
  // In a quadratic form it becomes:
  // N n^2 + (2M-N) n + 2L = 0.

  APInt A = N;
  APInt B = 2 * M - A;
  APInt C = 2 * L;
  APInt T = APInt(NewWidth, 2);
  LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
                    << "x + " << C << ", coeff bw: " << NewWidth
                    << ", multiplied by " << T << '\n');
  return std::make_tuple(A, B, C, T, BitWidth);
}
10576
10577/// Helper function to compare optional APInts:
10578/// (a) if X and Y both exist, return min(X, Y),
10579/// (b) if neither X nor Y exist, return std::nullopt,
10580/// (c) if exactly one of X and Y exists, return that value.
10581static std::optional<APInt> MinOptional(std::optional<APInt> X,
10582 std::optional<APInt> Y) {
10583 if (X && Y) {
10584 unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
10585 APInt XW = X->sext(W);
10586 APInt YW = Y->sext(W);
10587 return XW.slt(YW) ? *X : *Y;
10588 }
10589 if (!X && !Y)
10590 return std::nullopt;
10591 return X ? *X : *Y;
10592}
10593
/// Helper function to truncate an optional APInt to a given BitWidth.
/// When solving addrec-related equations, it is preferable to return a value
/// that has the same bit width as the original addrec's coefficients. If the
/// solution fits in the original bit width, truncate it (except for i1).
/// Returning a value of a different bit width may inhibit some optimizations.
///
/// In general, a solution to a quadratic equation generated from an addrec
/// may require BW+1 bits, where BW is the bit width of the addrec's
/// coefficients. The reason is that the coefficients of the quadratic
/// equation are BW+1 bits wide (to avoid truncation when converting from
/// the addrec to the equation).
static std::optional<APInt> TruncIfPossible(std::optional<APInt> X,
                                            unsigned BitWidth) {
  if (!X)
    return std::nullopt;
  unsigned W = X->getBitWidth();
  // NOTE(review): the guard deciding when truncation to BitWidth is safe
  // (comparing W with BitWidth and checking the value fits) is elided in
  // this excerpt — confirm against the full source.
    return X->trunc(BitWidth);
  return X;
}
10614
/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
/// iterations. The values L, M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
/// where BW is the bit width of the addrec's coefficients.
/// If the calculated value is a BW-bit integer (for BW > 1), it will be
/// returned as such, otherwise the bit width of the returned value may
/// be greater than BW.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
/// like x^2 = 5, no integer solutions exist, in other cases an integer
/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
/// NOTE(review): the signature line (taking the addrec and ScalarEvolution&)
/// and the call producing X are elided in this excerpt — confirm against the
/// full source.
static std::optional<APInt>
  APInt A, B, C, M;
  unsigned BitWidth;
  // Extract the constant coefficients of the quadratic; bail if any is
  // non-constant.
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  std::tie(A, B, C, M, BitWidth) = *T;
  LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
  std::optional<APInt> X =
  if (!X)
    return std::nullopt;

  // Verify the candidate root actually zeroes the chrec (the solver works in
  // extended width and may produce a spurious root modulo 2^BW).
  ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
  ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
  if (!V->isZero())
    return std::nullopt;

  // Prefer returning the answer at the addrec's original bit width.
  return TruncIfPossible(X, BitWidth);
}
10651
/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
/// iterations. The values M, N are assumed to be signed, and they
/// should all have the same bit widths.
/// Find the least n such that c(n) does not belong to the given range,
/// while c(n-1) does.
///
/// This function returns std::nullopt if
/// (a) the addrec coefficients are not constant, or
/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
///     bounds of the range.
static std::optional<APInt>
                          const ConstantRange &Range, ScalarEvolution &SE) {
  assert(AddRec->getOperand(0)->isZero() &&
         "Starting value of addrec should be 0");
  LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
                    << Range << ", addrec " << *AddRec << '\n');
  // This case is handled in getNumIterationsInRange. Here we can assume that
  // we start in the range.
  assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
         "Addrec's initial value should be in range");

  APInt A, B, C, M;
  unsigned BitWidth;
  // Extract the quadratic coefficients; bail out if any of them is not a
  // compile-time constant.
  auto T = GetQuadraticEquation(AddRec);
  if (!T)
    return std::nullopt;

  // Be careful about the return value: there can be two reasons for not
  // returning an actual number. First, if no solutions to the equations
  // were found, and second, if the solutions don't leave the given range.
  // The first case means that the actual solution is "unknown", the second
  // means that it's known, but not valid. If the solution is unknown, we
  // cannot make any conclusions.
  // Return a pair: the optional solution and a flag indicating if the
  // solution was found.
  auto SolveForBoundary =
      [&](APInt Bound) -> std::pair<std::optional<APInt>, bool> {
    // Solve for signed overflow and unsigned overflow, pick the lower
    // solution.
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
                      << Bound << " (before multiplying by " << M << ")\n");
    Bound *= M; // The quadratic equation multiplier.

    std::optional<APInt> SO;
    if (BitWidth > 1) {
      LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                           "signed overflow\n");
    }
    LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
                         "unsigned overflow\n");
    std::optional<APInt> UO =

    // A solution X "leaves the range" when c(X) is outside the range while
    // c(X-1) is still inside it, i.e. X is the first boundary crossing.
    auto LeavesRange = [&] (const APInt &X) {
      ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
      ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
      if (Range.contains(V0->getValue()))
        return false;
      // X should be at least 1, so X-1 is non-negative.
      ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
      ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
      if (Range.contains(V1->getValue()))
        return true;
      return false;
    };

    // If SolveQuadraticEquationWrap returns std::nullopt, it means that there
    // can be a solution, but the function failed to find it. We cannot treat it
    // as "no solution".
    if (!SO || !UO)
      return {std::nullopt, false};

    // Check the smaller value first to see if it leaves the range.
    // At this point, both SO and UO must have values.
    std::optional<APInt> Min = MinOptional(SO, UO);
    if (LeavesRange(*Min))
      return { Min, true };
    std::optional<APInt> Max = Min == SO ? UO : SO;
    if (LeavesRange(*Max))
      return { Max, true };

    // Solutions were found, but were eliminated, hence the "true".
    return {std::nullopt, true};
  };

  std::tie(A, B, C, M, BitWidth) = *T;
  // Lower bound is inclusive, subtract 1 to represent the exiting value.
  APInt Lower = Range.getLower().sext(A.getBitWidth()) - 1;
  APInt Upper = Range.getUpper().sext(A.getBitWidth());
  auto SL = SolveForBoundary(Lower);
  auto SU = SolveForBoundary(Upper);
  // If any of the solutions was unknown, no meaningful conclusions can
  // be made.
  if (!SL.second || !SU.second)
    return std::nullopt;

  // Claim: The correct solution is not some value between Min and Max.
  //
  // Justification: Assuming that Min and Max are different values, one of
  // them is when the first signed overflow happens, the other is when the
  // first unsigned overflow happens. Crossing the range boundary is only
  // possible via an overflow (treating 0 as a special case of it, modeling
  // an overflow as crossing k*2^W for some k).
  //
  // The interesting case here is when Min was eliminated as an invalid
  // solution, but Max was not. The argument is that if there was another
  // overflow between Min and Max, it would also have been eliminated if
  // it was considered.
  //
  // For a given boundary, it is possible to have two overflows of the same
  // type (signed/unsigned) without having the other type in between: this
  // can happen when the vertex of the parabola is between the iterations
  // corresponding to the overflows. This is only possible when the two
  // overflows cross k*2^W for the same k. In such case, if the second one
  // left the range (and was the first one to do so), the first overflow
  // would have to enter the range, which would mean that either we had left
  // the range before or that we started outside of it. Both of these cases
  // are contradictions.
  //
  // Claim: In the case where SolveForBoundary returns std::nullopt, the correct
  // solution is not some value between the Max for this boundary and the
  // Min of the other boundary.
  //
  // Justification: Assume that we had such Max_A and Min_B corresponding
  // to range boundaries A and B and such that Max_A < Min_B. If there was
  // a solution between Max_A and Min_B, it would have to be caused by an
  // overflow corresponding to either A or B. It cannot correspond to B,
  // since Min_B is the first occurrence of such an overflow. If it
  // corresponded to A, it would have to be either a signed or an unsigned
  // overflow that is larger than both eliminated overflows for A. But
  // between the eliminated overflows and this overflow, the values would
  // cover the entire value space, thus crossing the other boundary, which
  // is a contradiction.

  return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}
10790
/// Compute the number of backedges taken before the expression \p V becomes
/// zero in loop \p L, i.e. the trip count of a loop with an "x != y" exit
/// test rewritten as "V = x - y != 0".
ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
                                                         const Loop *L,
                                                         bool ControlsOnlyExit,
                                                         bool AllowPredicates) {

  // This is only used for loops with a "x != y" exit test. The exit condition
  // is now expressed as a single expression, V = x-y. So the exit test is
  // effectively V != 0. We know and take advantage of the fact that this
  // expression is only ever used in a comparison-by-zero context.

  // If the value is a constant
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
    if (C->getValue()->isZero()) return C;
    return getCouldNotCompute(); // Otherwise it will loop infinitely.
  }

  const SCEVAddRecExpr *AddRec =
      dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));

  if (!AddRec && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);

  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
  // the quadratic equation to solve it.
  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
    // We can only use this value if the chrec ends up with an exact zero
    // value at this index. When solving for "X*X != 5", for example, we
    // should not accept a root of 2.
    if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
      const auto *R = cast<SCEVConstant>(getConstant(*S));
      return ExitLimit(R, R, R, false, Predicates);
    }
    return getCouldNotCompute();
  }

  // Otherwise we can only handle this if it is affine.
  if (!AddRec->isAffine())
    return getCouldNotCompute();

  // If this is an affine expression, the execution count of this branch is
  // the minimum unsigned root of the following equation:
  //
  //     Start + Step*N = 0 (mod 2^BW)
  //
  // equivalent to:
  //
  //             Step*N = -Start (mod 2^BW)
  //
  // where BW is the common bit width of Start and Step.

  // Get the initial value for the loop.
  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());

  if (!isLoopInvariant(Step, L))
    return getCouldNotCompute();

  LoopGuards Guards = LoopGuards::collect(L, *this);
  // Specialize step for this loop so we get context sensitive facts below.
  const SCEV *StepWLG = applyLoopGuards(Step, Guards);

  // For positive steps (counting up until unsigned overflow):
  //   N = -Start/Step (as unsigned)
  // For negative steps (counting down to zero):
  //   N = Start/-Step
  // First compute the unsigned distance from zero in the direction of Step.
  bool CountDown = isKnownNegative(StepWLG);
  if (!CountDown && !isKnownNonNegative(StepWLG))
    return getCouldNotCompute();

  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
  // Handle unitary steps, which cannot wraparound.
  //   1*N = -Start; -1*N = Start (mod 2^BW), so:
  //   N = Distance (as unsigned)

  if (match(Step, m_CombineOr(m_scev_One(), m_scev_AllOnes()))) {
    // Take the tighter of the guard-refined and plain unsigned range maxima.
    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
    MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));

    // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
    // we end up with a loop whose backedge-taken count is n - 1.  Detect this
    // case, and see if we can improve the bound.
    //
    // Explicitly handling this here is necessary because getUnsignedRange
    // isn't context-sensitive; it doesn't know that we only care about the
    // range inside the loop.
    const SCEV *Zero = getZero(Distance->getType());
    const SCEV *One = getOne(Distance->getType());
    const SCEV *DistancePlusOne = getAddExpr(Distance, One);
    if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
      // If Distance + 1 doesn't overflow, we can compute the maximum distance
      // as "unsigned_max(Distance + 1) - 1".
      ConstantRange CR = getUnsignedRange(DistancePlusOne);
      MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
    }
    return ExitLimit(Distance, getConstant(MaxBECount), Distance, false,
                     Predicates);
  }

  // If the condition controls loop exit (the loop exits only if the expression
  // is true) and the addition is no-wrap we can use unsigned divide to
  // compute the backedge count.  In this case, the step may not divide the
  // distance, but we don't care because if the condition is "missed" the loop
  // will have undefined behavior due to wrapping.
  if (ControlsOnlyExit && AddRec->hasNoSelfWrap() &&
      loopHasNoAbnormalExits(AddRec->getLoop())) {

    // If the stride is zero and the start is non-zero, the loop must be
    // infinite. In C++, most loops are finite by assumption, in which case the
    // step being zero implies UB must execute if the loop is entered.
    if (!(loopIsFiniteByAssumption(L) && isKnownNonZero(Start)) &&
        !isKnownNonZero(StepWLG))
      return getCouldNotCompute();

    const SCEV *Exact =
        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
    const SCEV *ConstantMax = getCouldNotCompute();
    if (Exact != getCouldNotCompute()) {
      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
      ConstantMax =
    }
    const SCEV *SymbolicMax =
        isa<SCEVCouldNotCompute>(Exact) ? ConstantMax : Exact;
    return ExitLimit(Exact, ConstantMax, SymbolicMax, false, Predicates);
  }

  // Solve the general equation.
  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
  if (!StepC || StepC->getValue()->isZero())
    return getCouldNotCompute();
  const SCEV *E = SolveLinEquationWithOverflow(
      StepC->getAPInt(), getNegativeSCEV(Start),
      AllowPredicates ? &Predicates : nullptr, *this, L);

  // Refine the constant max with loop-guard information when available.
  const SCEV *M = E;
  if (E != getCouldNotCompute()) {
    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
    M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
  }
  auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
  return ExitLimit(E, M, S, false, Predicates);
}
10942
10943ScalarEvolution::ExitLimit
10944ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
10945 // Loops that look like: while (X == 0) are very strange indeed. We don't
10946 // handle them yet except for the trivial case. This could be expanded in the
10947 // future as needed.
10948
10949 // If the value is a constant, check to see if it is known to be non-zero
10950 // already. If so, the backedge will execute zero times.
10951 if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
10952 if (!C->getValue()->isZero())
10953 return getZero(C->getType());
10954 return getCouldNotCompute(); // Otherwise it will loop infinitely.
10955 }
10956
10957 // We could implement others, but I really doubt anyone writes loops like
10958 // this, and if they did, they would already be constant folded.
10959 return getCouldNotCompute();
10960}
10961
10962std::pair<const BasicBlock *, const BasicBlock *>
10963ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
10964 const {
10965 // If the block has a unique predecessor, then there is no path from the
10966 // predecessor to the block that does not go through the direct edge
10967 // from the predecessor to the block.
10968 if (const BasicBlock *Pred = BB->getSinglePredecessor())
10969 return {Pred, BB};
10970
10971 // A loop's header is defined to be a block that dominates the loop.
10972 // If the header has a unique predecessor outside the loop, it must be
10973 // a block that has exactly one successor that can reach the loop.
10974 if (const Loop *L = LI.getLoopFor(BB))
10975 return {L->getLoopPredecessor(), L->getHeader()};
10976
10977 return {nullptr, BB};
10978}
10979
10980/// SCEV structural equivalence is usually sufficient for testing whether two
10981/// expressions are equal, however for the purposes of looking for a condition
10982/// guarding a loop, it can be useful to be a little more general, since a
10983/// front-end may have replicated the controlling expression.
10984static bool HasSameValue(const SCEV *A, const SCEV *B) {
10985 // Quick check to see if they are the same SCEV.
10986 if (A == B) return true;
10987
10988 auto ComputesEqualValues = [](const Instruction *A, const Instruction *B) {
10989 // Not all instructions that are "identical" compute the same value. For
10990 // instance, two distinct alloca instructions allocating the same type are
10991 // identical and do not read memory; but compute distinct values.
10992 return A->isIdenticalTo(B) && (isa<BinaryOperator>(A) || isa<GetElementPtrInst>(A));
10993 };
10994
10995 // Otherwise, if they're both SCEVUnknown, it's possible that they hold
10996 // two different instructions with the same value. Check for this case.
10997 if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
10998 if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
10999 if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
11000 if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
11001 if (ComputesEqualValues(AI, BI))
11002 return true;
11003
11004 // Otherwise assume they may have a different value.
11005 return false;
11006}
11007
11008static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS) {
11009 const SCEV *Op0, *Op1;
11010 if (!match(S, m_scev_Add(m_SCEV(Op0), m_SCEV(Op1))))
11011 return false;
11012 if (match(Op0, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11013 LHS = Op1;
11014 return true;
11015 }
11016 if (match(Op1, m_scev_Mul(m_scev_AllOnes(), m_SCEV(RHS)))) {
11017 LHS = Op0;
11018 return true;
11019 }
11020 return false;
11021}
11022
                                           SCEVUse &RHS, unsigned Depth) {
  // Tracks whether any canonicalization below rewrote Pred/LHS/RHS; this is
  // also the return value.
  bool Changed = false;
  // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
  // '0 != 0'.
  auto TrivialCase = [&](bool TriviallyTrue) {
    Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
    return true;
  };
  // If we hit the max recursion limit bail out.
  if (Depth >= 3)
    return false;

  // Strip a common vscale factor from both sides when the wrap flags allow
  // it (see the conditions below).
  const SCEV *NewLHS, *NewRHS;
  if (match(LHS, m_scev_c_Mul(m_SCEV(NewLHS), m_SCEVVScale())) &&
      match(RHS, m_scev_c_Mul(m_SCEV(NewRHS), m_SCEVVScale()))) {
    const SCEVMulExpr *LMul = cast<SCEVMulExpr>(LHS);
    const SCEVMulExpr *RMul = cast<SCEVMulExpr>(RHS);

    // (X * vscale) pred (Y * vscale) ==> X pred Y
    // when both multiples are NSW.
    // (X * vscale) uicmp/eq/ne (Y * vscale) ==> X uicmp/eq/ne Y
    // when both multiples are NUW.
    if ((LMul->hasNoSignedWrap() && RMul->hasNoSignedWrap()) ||
        (LMul->hasNoUnsignedWrap() && RMul->hasNoUnsignedWrap() &&
         !ICmpInst::isSigned(Pred))) {
      LHS = NewLHS;
      RHS = NewRHS;
      Changed = true;
    }
  }

  // Canonicalize a constant to the right side.
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
    // Check for both operands constant.
    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
      // Fully constant comparison: fold to a trivial true/false form.
      if (!ICmpInst::compare(LHSC->getAPInt(), RHSC->getAPInt(), Pred))
        return TrivialCase(false);
      return TrivialCase(true);
    }
    // Otherwise swap the operands to put the constant on the right.
    std::swap(LHS, RHS);
    Changed = true;
  }

  // If we're comparing an addrec with a value which is loop-invariant in the
  // addrec's loop, put the addrec on the left. Also make a dominance check,
  // as both operands could be addrecs loop-invariant in each other's loop.
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
    const Loop *L = AR->getLoop();
    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
      std::swap(LHS, RHS);
      Changed = true;
    }
  }

  // If there's a constant operand, canonicalize comparisons with boundary
  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
    const APInt &RA = RC->getAPInt();

    bool SimplifiedByConstantRange = false;

    if (!ICmpInst::isEquality(Pred)) {
      if (ExactCR.isFullSet())
        return TrivialCase(true);
      if (ExactCR.isEmptySet())
        return TrivialCase(false);

      APInt NewRHS;
      CmpInst::Predicate NewPred;
      if (ExactCR.getEquivalentICmp(NewPred, NewRHS) &&
          ICmpInst::isEquality(NewPred)) {
        // We were able to convert an inequality to an equality.
        Pred = NewPred;
        RHS = getConstant(NewRHS);
        Changed = SimplifiedByConstantRange = true;
      }
    }

    if (!SimplifiedByConstantRange) {
      switch (Pred) {
      default:
        break;
      case ICmpInst::ICMP_EQ:
      case ICmpInst::ICMP_NE:
        // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
        if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
          Changed = true;
        break;

        // The "Should have been caught earlier!" messages refer to the fact
        // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
        // should have fired on the corresponding cases, and canonicalized the
        // check to trivial case.

      case ICmpInst::ICMP_UGE:
        assert(!RA.isMinValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_UGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_ULE:
        assert(!RA.isMaxValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_ULT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SGE:
        assert(!RA.isMinSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SGT;
        RHS = getConstant(RA - 1);
        Changed = true;
        break;
      case ICmpInst::ICMP_SLE:
        assert(!RA.isMaxSignedValue() && "Should have been caught earlier!");
        Pred = ICmpInst::ICMP_SLT;
        RHS = getConstant(RA + 1);
        Changed = true;
        break;
      }
    }
  }

  // Check for obvious equality.
  if (HasSameValue(LHS, RHS)) {
    if (ICmpInst::isTrueWhenEqual(Pred))
      return TrivialCase(true);
      return TrivialCase(false);
  }

  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
  // adding or subtracting 1 from one of the operands.
  switch (Pred) {
  case ICmpInst::ICMP_SLE:
    if (!getSignedRangeMax(RHS).isMaxSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    } else if (!getSignedRangeMin(LHS).isMinSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
      Pred = ICmpInst::ICMP_SLT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_SGE:
    if (!getSignedRangeMin(RHS).isMinSignedValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    } else if (!getSignedRangeMax(LHS).isMaxSignedValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
      Pred = ICmpInst::ICMP_SGT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_ULE:
    if (!getUnsignedRangeMax(RHS).isMaxValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    } else if (!getUnsignedRangeMin(LHS).isMinValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS);
      Pred = ICmpInst::ICMP_ULT;
      Changed = true;
    }
    break;
  case ICmpInst::ICMP_UGE:
    // If RHS is an op we can fold the -1, try that first.
    // Otherwise prefer LHS to preserve the nuw flag.
    if ((isa<SCEVConstant>(RHS) ||
         isa<SCEVConstant>(cast<SCEVNAryExpr>(RHS)->getOperand(0)))) &&
        !getUnsignedRangeMin(RHS).isMinValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    } else if (!getUnsignedRangeMax(LHS).isMaxValue()) {
      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    } else if (!getUnsignedRangeMin(RHS).isMinValue()) {
      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS);
      Pred = ICmpInst::ICMP_UGT;
      Changed = true;
    }
    break;
  default:
    break;
  }

  // TODO: More simplifications are possible here.

  // Recursively simplify until we either hit a recursion limit or nothing
  // changes.
  if (Changed)
    (void)SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1);

  return Changed;
}
11234
  // S is known negative iff even the largest value of its signed range is
  // negative.
  return getSignedRangeMax(S).isNegative();
}
11238
11242
  // S is known non-negative iff the smallest value of its signed range is
  // not negative.
  return !getSignedRangeMin(S).isNegative();
}
11246
11250
  // Push the query down through sign extensions: a sign-extended value is
  // zero exactly when its operand is zero, and the narrower operand may have
  // a more precise unsigned range than the extended expression.
  if (const auto *SExt = dyn_cast<SCEVSignExtendExpr>(S))
    return isKnownNonZero(SExt->getOperand(0));
  // Otherwise, S is known non-zero iff its unsigned range excludes zero.
  return getUnsignedRangeMin(S) != 0;
}
11258
                                             bool OrNegative) {
  // Helper: is S a power of two by itself (no recursion)? With OrNegative,
  // negated powers of two are also accepted for constants.
  auto NonRecursive = [OrNegative](const SCEV *S) {
    if (auto *C = dyn_cast<SCEVConstant>(S))
      return C->getAPInt().isPowerOf2() ||
             (OrNegative && C->getAPInt().isNegatedPowerOf2());

    // vscale is a power-of-two.
    return isa<SCEVVScale>(S);
  };

  if (NonRecursive(S))
    return true;

  // A product of such factors is itself a power of two, provided the product
  // is known non-zero (or the caller accepts zero via OrZero).
  auto *Mul = dyn_cast<SCEVMulExpr>(S);
  if (!Mul)
    return false;
  return all_of(Mul->operands(), NonRecursive) && (OrZero || isKnownNonZero(S));
}
11278
    const SCEV *S, uint64_t M,
  // Conservatively, nothing is a known multiple of 0; everything is a
  // multiple of 1.
  if (M == 0)
    return false;
  if (M == 1)
    return true;

  // Recursively check AddRec operands. An AddRecExpr S is a multiple of M if S
  // starts with a multiple of M and at every iteration step S only adds
  // multiples of M.
  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(S))
    return isKnownMultipleOf(AddRec->getStart(), M, Assumptions) &&
           isKnownMultipleOf(AddRec->getStepRecurrence(*this), M, Assumptions);

  // For a constant, check that "S % M == 0".
  if (auto *Cst = dyn_cast<SCEVConstant>(S)) {
    APInt C = Cst->getAPInt();
    return C.urem(M) == 0;
  }

  // TODO: Also check other SCEV expressions, i.e., SCEVAddRecExpr, etc.

  // Basic tests have failed.
  // Check "S % M == 0" at compile time and record runtime Assumptions.
  auto *STy = dyn_cast<IntegerType>(S->getType());
  const SCEV *SmodM =
      getURemExpr(S, getConstant(ConstantInt::get(STy, M, false)));
  const SCEV *Zero = getZero(STy);

  // Check whether "S % M == 0" is known at compile time.
  if (isKnownPredicate(ICmpInst::ICMP_EQ, SmodM, Zero))
    return true;

  // Check whether "S % M != 0" is known at compile time.
  if (isKnownPredicate(ICmpInst::ICMP_NE, SmodM, Zero))
    return false;


  // Detect redundant predicates.
  for (auto *A : Assumptions)
    if (A->implies(P, *this))
      return true;

  // Only record non-redundant predicates.
  Assumptions.push_back(P);
  return true;
}
11328
11330 return ((isKnownNonNegative(S1) && isKnownNonNegative(S2)) ||
11332}
11333
std::pair<const SCEV *, const SCEV *>
  // Compute SCEV on entry of loop L.
  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
  // If the entry value cannot be computed, signal failure through both pair
  // elements so callers only need to check the first.
  if (Start == getCouldNotCompute())
    return { Start, Start };
  // Compute post increment SCEV for loop L.
  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
  return { Start, PostInc };
}
11345
                                          SCEVUse RHS) {
  // First collect all loops.
  getUsedLoops(LHS, LoopsUsed);
  getUsedLoops(RHS, LoopsUsed);

  // Without any used loops there is no induction argument to be made.
  if (LoopsUsed.empty())
    return false;

  // Domination relationship must be a linear order on collected loops.
#ifndef NDEBUG
  for (const auto *L1 : LoopsUsed)
    for (const auto *L2 : LoopsUsed)
      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
              DT.dominates(L2->getHeader(), L1->getHeader())) &&
             "Domination relationship is not a linear order");
#endif

  // MDL is the maximum of the linear order above: the used loop whose header
  // is dominated by every other used loop's header.
  const Loop *MDL =
      *llvm::max_element(LoopsUsed, [&](const Loop *L1, const Loop *L2) {
        return DT.properlyDominates(L1->getHeader(), L2->getHeader());
      });

  // Get init and post increment value for LHS.
  auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
  // if LHS contains unknown non-invariant SCEV then bail out.
  if (SplitLHS.first == getCouldNotCompute())
    return false;
  assert (SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
  // Get init and post increment value for RHS.
  auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
  // if RHS contains unknown non-invariant SCEV then bail out.
  if (SplitRHS.first == getCouldNotCompute())
    return false;
  assert (SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
  // It is possible that init SCEV contains an invariant load but it does
  // not dominate MDL and is not available at MDL loop entry, so we should
  // check it here.
  if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
      !isAvailableAtLoopEntry(SplitRHS.first, MDL))
    return false;

  // It seems backedge guard check is faster than entry one so in some cases
  // it can speed up whole estimation by short circuit
  return isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
                                     SplitRHS.second) &&
         isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first);
}
11395
                                       SCEVUse RHS) {
  // Canonicalize the inputs first.
  (void)SimplifyICmpOperands(Pred, LHS, RHS);

  // Try the induction-based proof over the loops used by LHS/RHS.
  if (isKnownViaInduction(Pred, LHS, RHS))
    return true;

  // Try proving the predicate by splitting it into simpler sub-queries.
  if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
    return true;

  // Otherwise see what can be done with some simple reasoning.
  return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
}
11410
                                                       const SCEV *LHS,
                                                       const SCEV *RHS) {
  // Evaluate to true/false when the predicate (or its inverse) is provable;
  // return std::nullopt when neither direction can be established.
  if (isKnownPredicate(Pred, LHS, RHS))
    return true;
    return false;
  return std::nullopt;
}
11420
                                         const SCEV *RHS,
                                         const Instruction *CtxI) {
  // TODO: Analyze guards and assumes from Context's block.
  // First try a context-free proof, then conditions known to hold on entry
  // to the block containing CtxI.
  return isKnownPredicate(Pred, LHS, RHS) ||
         isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS);
}
11428
std::optional<bool>
                                     const SCEV *RHS, const Instruction *CtxI) {
  // Prefer a context-free answer when one exists.
  std::optional<bool> KnownWithoutContext = evaluatePredicate(Pred, LHS, RHS);
  if (KnownWithoutContext)
    return KnownWithoutContext;

  // Otherwise, consult conditions guarding entry to CtxI's block, for both
  // the predicate and its inverse.
  if (isBasicBlockEntryGuardedByCond(CtxI->getParent(), Pred, LHS, RHS))
    return true;
      CtxI->getParent(), ICmpInst::getInverseCmpPredicate(Pred), LHS, RHS))
    return false;
  return std::nullopt;
}
11443
                                              const SCEVAddRecExpr *LHS,
                                              const SCEV *RHS) {
  // Induction proof: the predicate holds for the start value (base case,
  // via the loop entry guard) and is maintained across the backedge for the
  // post-increment value (inductive step).
  const Loop *L = LHS->getLoop();
  return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
         isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}
11451
std::optional<ScalarEvolution::MonotonicPredicateType>
                                           ICmpInst::Predicate Pred) {
  auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);

#ifndef NDEBUG
  // Verify an invariant: swapping the predicate (e.g. slt -> sgt) should turn
  // a monotonically increasing change to a monotonically decreasing one, and
  // vice versa.
  if (Result) {
    auto ResultSwapped =
        getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));

    assert(*ResultSwapped != *Result &&
           "monotonicity should flip as we flip the predicate");
  }
#endif

  return Result;
}
11471
std::optional<ScalarEvolution::MonotonicPredicateType>
ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
                                               ICmpInst::Predicate Pred) {
  // A zero step value for LHS means the induction variable is essentially a
  // loop invariant value. We don't really depend on the predicate actually
  // flipping from false to true (for increasing predicates, and the other way
  // around for decreasing predicates), all we care about is that *if* the
  // predicate changes then it only changes from false to true.
  //
  // A zero step value in itself is not very useful, but there may be places
  // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
  // as general as possible.

  // Only handle LE/LT/GE/GT predicates.
  if (!ICmpInst::isRelational(Pred))
    return std::nullopt;

  bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
  assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
         "Should be greater or less!");

  // Check that AR does not wrap.
  if (ICmpInst::isUnsigned(Pred)) {
    // Unsigned predicates require the addrec to be NUW.
    if (!LHS->hasNoUnsignedWrap())
      return std::nullopt;
  }
  assert(ICmpInst::isSigned(Pred) &&
         "Relational predicate is either signed or unsigned!");
  // Signed predicates require the addrec to be NSW.
  if (!LHS->hasNoSignedWrap())
    return std::nullopt;

  const SCEV *Step = LHS->getStepRecurrence(*this);

  // A non-negative step means the addrec never decreases across iterations;
  // a non-positive step means it never increases.
  if (isKnownNonNegative(Step))

  if (isKnownNonPositive(Step))

  return std::nullopt;
}
11514
// Attempts to replace a loop-varying predicate "LHS Pred RHS" with an
// equivalent loop-invariant predicate on the recurrence's start value,
// using monotonicity of the predicate over the iteration space.
// NOTE(review): rendering artifact — the line with the function name and
// first parameter (original line 11516) is missing from this dump.
11515std::optional<ScalarEvolution::LoopInvariantPredicate>
11517 const SCEV *RHS, const Loop *L,
11518 const Instruction *CtxI) {
11519 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11520 if (!isLoopInvariant(RHS, L)) {
11521 if (!isLoopInvariant(LHS, L))
11522 return std::nullopt;
11523
11524 std::swap(LHS, RHS);
// NOTE(review): rendering gap — original line 11525 (predicate adjustment
// after the operand swap) is missing from this dump.
11526 }
11527
11528 const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
11529 if (!ArLHS || ArLHS->getLoop() != L)
11530 return std::nullopt;
11531
11532 auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
11533 if (!MonotonicType)
11534 return std::nullopt;
11535 // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
11536 // true as the loop iterates, and the backedge is control dependent on
11537 // "ArLHS `Pred` RHS" == true then we can reason as follows:
11538 //
11539 // * if the predicate was false in the first iteration then the predicate
11540 // is never evaluated again, since the loop exits without taking the
11541 // backedge.
11542 // * if the predicate was true in the first iteration then it will
11543 // continue to be true for all future iterations since it is
11544 // monotonically increasing.
11545 //
11546 // For both the above possibilities, we can replace the loop varying
11547 // predicate with its value on the first iteration of the loop (which is
11548 // loop invariant).
11549 //
11550 // A similar reasoning applies for a monotonically decreasing predicate, by
11551 // replacing true with false and false with true in the above two bullets.
// NOTE(review): rendering gap — original line 11552 (the definition of
// `Increasing` from *MonotonicType) is missing from this dump.
11553 auto P = Increasing ? Pred : ICmpInst::getInverseCmpPredicate(Pred);
11554
11555 if (isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
// NOTE(review): rendering gap — original line 11556 (construction of the
// returned LoopInvariantPredicate) is missing from this dump.
11557 RHS);
11558
11559 if (!CtxI)
11560 return std::nullopt;
11561 // Try to prove via context.
11562 // TODO: Support other cases.
11563 switch (Pred) {
11564 default:
11565 break;
11566 case ICmpInst::ICMP_ULE:
11567 case ICmpInst::ICMP_ULT: {
11568 assert(ArLHS->hasNoUnsignedWrap() && "Is a requirement of monotonicity!");
11569 // Given preconditions
11570 // (1) ArLHS does not cross the border of positive and negative parts of
11571 // range because of:
11572 // - Positive step; (TODO: lift this limitation)
11573 // - nuw - does not cross zero boundary;
11574 // - nsw - does not cross SINT_MAX boundary;
11575 // (2) ArLHS <s RHS
11576 // (3) RHS >=s 0
11577 // we can replace the loop variant ArLHS <u RHS condition with loop
11578 // invariant Start(ArLHS) <u RHS.
11579 //
11580 // Because of (1) there are two options:
11581 // - ArLHS is always negative. It means that ArLHS <u RHS is always false;
11582 // - ArLHS is always non-negative. Because of (3) RHS is also non-negative.
11583 // It means that ArLHS <s RHS <=> ArLHS <u RHS.
11584 // Because of (2) ArLHS <u RHS is trivially true.
11585 // All together it means that ArLHS <u RHS <=> Start(ArLHS) >=s 0.
11586 // We can strengthen this to Start(ArLHS) <u RHS.
11587 auto SignFlippedPred = ICmpInst::getFlippedSignednessPredicate(Pred);
11588 if (ArLHS->hasNoSignedWrap() && ArLHS->isAffine() &&
11589 isKnownPositive(ArLHS->getStepRecurrence(*this)) &&
11590 isKnownNonNegative(RHS) &&
11591 isKnownPredicateAt(SignFlippedPred, ArLHS, RHS, CtxI))
// NOTE(review): rendering gap — original line 11592 (construction of the
// returned LoopInvariantPredicate) is missing from this dump.
11593 RHS);
11594 }
11595 }
11596
11597 return std::nullopt;
11598}
11599
// Tries the direct proof first, then — if MaxIter is a umin — retries with
// each umin operand, since invariance for X implies invariance for
// umin(X, ...).
// NOTE(review): rendering artifact — the function-name line (original line
// 11601) and the call lines 11604/11614 (presumably delegating to the Impl
// variant — TODO confirm) are missing from this dump.
11600std::optional<ScalarEvolution::LoopInvariantPredicate>
11602 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11603 const Instruction *CtxI, const SCEV *MaxIter) {
11605 Pred, LHS, RHS, L, CtxI, MaxIter))
11606 return LIP;
11607 if (auto *UMin = dyn_cast<SCEVUMinExpr>(MaxIter))
11608 // Number of iterations expressed as UMIN isn't always great for expressing
11609 // the value on the last iteration. If the straightforward approach didn't
11610 // work, try the following trick: if the a predicate is invariant for X, it
11611 // is also invariant for umin(X, ...). So try to find something that works
11612 // among subexpressions of MaxIter expressed as umin.
11613 for (SCEVUse Op : UMin->operands())
11615 Pred, LHS, RHS, L, CtxI, Op))
11616 return LIP;
11617 return std::nullopt;
11618}
11619
// Proves an exit condition loop-invariant for the first MaxIter iterations
// of a unit-step recurrence: the condition must hold on the MaxIter'th
// iteration and no wrap may occur in between.
// NOTE(review): rendering artifact — the function-name line (original line
// 11621) is missing from this dump.
11620std::optional<ScalarEvolution::LoopInvariantPredicate>
11622 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
11623 const Instruction *CtxI, const SCEV *MaxIter) {
11624 // Try to prove the following set of facts:
11625 // - The predicate is monotonic in the iteration space.
11626 // - If the check does not fail on the 1st iteration:
11627 // - No overflow will happen during first MaxIter iterations;
11628 // - It will not fail on the MaxIter'th iteration.
11629 // If the check does fail on the 1st iteration, we leave the loop and no
11630 // other checks matter.
11631
11632 // If there is a loop-invariant, force it into the RHS, otherwise bail out.
11633 if (!isLoopInvariant(RHS, L)) {
11634 if (!isLoopInvariant(LHS, L))
11635 return std::nullopt;
11636
11637 std::swap(LHS, RHS);
// NOTE(review): rendering gap — original line 11638 (predicate adjustment
// after the operand swap) is missing from this dump.
11639 }
11640
11641 auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
11642 if (!AR || AR->getLoop() != L)
11643 return std::nullopt;
11644
11645 // Even if both are valid, we need to consistently chose the unsigned or the
11646 // signed predicate below, not mixtures of both. For now, prefer the unsigned
11647 // predicate.
11648 Pred = Pred.dropSameSign();
11649
11650 // The predicate must be relational (i.e. <, <=, >=, >).
11651 if (!ICmpInst::isRelational(Pred))
11652 return std::nullopt;
11653
11654 // TODO: Support steps other than +/- 1.
11655 const SCEV *Step = AR->getStepRecurrence(*this);
11656 auto *One = getOne(Step->getType());
11657 auto *MinusOne = getNegativeSCEV(One);
11658 if (Step != One && Step != MinusOne)
11659 return std::nullopt;
11660
11661 // Type mismatch here means that MaxIter is potentially larger than max
11662 // unsigned value in start type, which mean we cannot prove no wrap for the
11663 // indvar.
11664 if (AR->getType() != MaxIter->getType())
11665 return std::nullopt;
11666
11667 // Value of IV on suggested last iteration.
11668 const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
11669 // Does it still meet the requirement?
11670 if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
11671 return std::nullopt;
11672 // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
11673 // not exceed max unsigned value of this type), this effectively proves
11674 // that there is no wrap during the iteration. To prove that there is no
11675 // signed/unsigned wrap, we need to check that
11676 // Start <= Last for step = 1 or Start >= Last for step = -1.
11677 ICmpInst::Predicate NoOverflowPred =
// NOTE(review): rendering gap — original line 11678 (initializer choosing
// the signed/unsigned <= predicate — TODO confirm) is missing from this dump.
11679 if (Step == MinusOne)
11680 NoOverflowPred = ICmpInst::getSwappedPredicate(NoOverflowPred);
11681 const SCEV *Start = AR->getStart();
11682 if (!isKnownPredicateAt(NoOverflowPred, Start, Last, CtxI))
11683 return std::nullopt;
11684
11685 // Everything is fine.
11686 return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
11687}
11688
11689bool ScalarEvolution::isKnownPredicateViaConstantRanges(CmpPredicate Pred,
11690 SCEVUse LHS,
11691 SCEVUse RHS) {
11692 if (HasSameValue(LHS, RHS))
11693 return ICmpInst::isTrueWhenEqual(Pred);
11694
11695 auto CheckRange = [&](bool IsSigned) {
11696 auto RangeLHS = IsSigned ? getSignedRange(LHS) : getUnsignedRange(LHS);
11697 auto RangeRHS = IsSigned ? getSignedRange(RHS) : getUnsignedRange(RHS);
11698 return RangeLHS.icmp(Pred, RangeRHS);
11699 };
11700
11701 // The check at the top of the function catches the case where the values are
11702 // known to be equal.
11703 if (Pred == CmpInst::ICMP_EQ)
11704 return false;
11705
11706 if (Pred == CmpInst::ICMP_NE) {
11707 if (CheckRange(true) || CheckRange(false))
11708 return true;
11709 auto *Diff = getMinusSCEV(LHS, RHS);
11710 return !isa<SCEVCouldNotCompute>(Diff) && isKnownNonZero(Diff);
11711 }
11712
11713 return CheckRange(CmpInst::isSigned(Pred));
11714}
11715
// Proves "LHS Pred RHS" by matching both sides to (A + C1)/(A + C2) with
// the no-wrap flags appropriate to the predicate's signedness, then
// comparing the constants C1 and C2 directly.
// NOTE(review): rendering artifact — the second signature line (original
// line 11717, the LHS/RHS parameters) is missing from this dump.
11716bool ScalarEvolution::isKnownPredicateViaNoOverflow(CmpPredicate Pred,
11718 // Match X to (A + C1)<ExpectedFlags> and Y to (A + C2)<ExpectedFlags>, where
11719 // C1 and C2 are constant integers. If either X or Y are not add expressions,
11720 // consider them as X + 0 and Y + 0 respectively. C1 and C2 are returned via
11721 // OutC1 and OutC2.
11722 auto MatchBinaryAddToConst = [this](SCEVUse X, SCEVUse Y, APInt &OutC1,
11723 APInt &OutC2,
11724 SCEV::NoWrapFlags ExpectedFlags) {
11725 SCEVUse XNonConstOp, XConstOp;
11726 SCEVUse YNonConstOp, YConstOp;
11727 SCEV::NoWrapFlags XFlagsPresent;
11728 SCEV::NoWrapFlags YFlagsPresent;
11729
// Treat a non-add X as "X + 0" with the expected flags trivially present.
11730 if (!splitBinaryAdd(X, XConstOp, XNonConstOp, XFlagsPresent)) {
11731 XConstOp = getZero(X->getType());
11732 XNonConstOp = X;
11733 XFlagsPresent = ExpectedFlags;
11734 }
11735 if (!isa<SCEVConstant>(XConstOp))
11736 return false;
11737
11738 if (!splitBinaryAdd(Y, YConstOp, YNonConstOp, YFlagsPresent)) {
11739 YConstOp = getZero(Y->getType());
11740 YNonConstOp = Y;
11741 YFlagsPresent = ExpectedFlags;
11742 }
11743
// Both sides must share the same non-constant operand A.
11744 if (YNonConstOp != XNonConstOp)
11745 return false;
11746
11747 if (!isa<SCEVConstant>(YConstOp))
11748 return false;
11749
11750 // When matching ADDs with NUW flags (and unsigned predicates), only the
11751 // second ADD (with the larger constant) requires NUW.
11752 if ((YFlagsPresent & ExpectedFlags) != ExpectedFlags)
11753 return false;
11754 if (ExpectedFlags != SCEV::FlagNUW &&
11755 (XFlagsPresent & ExpectedFlags) != ExpectedFlags) {
11756 return false;
11757 }
11758
11759 OutC1 = cast<SCEVConstant>(XConstOp)->getAPInt();
11760 OutC2 = cast<SCEVConstant>(YConstOp)->getAPInt();
11761
11762 return true;
11763 };
11764
11765 APInt C1;
11766 APInt C2;
11767
// Canonicalize GE/GT to LE/LT by swapping operands, then compare the
// matched constants with the signedness matching the predicate.
11768 switch (Pred) {
11769 default:
11770 break;
11771
11772 case ICmpInst::ICMP_SGE:
11773 std::swap(LHS, RHS);
11774 [[fallthrough]];
11775 case ICmpInst::ICMP_SLE:
11776 // (X + C1)<nsw> s<= (X + C2)<nsw> if C1 s<= C2.
11777 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.sle(C2))
11778 return true;
11779
11780 break;
11781
11782 case ICmpInst::ICMP_SGT:
11783 std::swap(LHS, RHS);
11784 [[fallthrough]];
11785 case ICmpInst::ICMP_SLT:
11786 // (X + C1)<nsw> s< (X + C2)<nsw> if C1 s< C2.
11787 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNSW) && C1.slt(C2))
11788 return true;
11789
11790 break;
11791
11792 case ICmpInst::ICMP_UGE:
11793 std::swap(LHS, RHS);
11794 [[fallthrough]];
11795 case ICmpInst::ICMP_ULE:
11796 // (X + C1) u<= (X + C2)<nuw> for C1 u<= C2.
11797 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ule(C2))
11798 return true;
11799
11800 break;
11801
11802 case ICmpInst::ICMP_UGT:
11803 std::swap(LHS, RHS);
11804 [[fallthrough]];
11805 case ICmpInst::ICMP_ULT:
11806 // (X + C1) u< (X + C2)<nuw> if C1 u< C2.
11807 if (MatchBinaryAddToConst(LHS, RHS, C1, C2, SCEV::FlagNUW) && C1.ult(C2))
11808 return true;
11809 break;
11810 }
11811
11812 return false;
11813}
11814
// Splits an unsigned-less-than query into sign and signed-comparison parts,
// with a recursion guard to avoid exponential blowup.
// NOTE(review): rendering artifact — the second signature line (original
// line 11816) and the tail of the final return expression (lines
// 11832-11833) are missing from this dump.
11815bool ScalarEvolution::isKnownPredicateViaSplitting(CmpPredicate Pred,
11817 if (Pred != ICmpInst::ICMP_ULT || ProvingSplitPredicate)
11818 return false;
11819
11820 // Allowing arbitrary number of activations of isKnownPredicateViaSplitting on
11821 // the stack can result in exponential time complexity.
11822 SaveAndRestore Restore(ProvingSplitPredicate, true);
11823
11824 // If L >= 0 then I `ult` L <=> I >= 0 && I `slt` L
11825 //
11826 // To prove L >= 0 we use isKnownNonNegative whereas to prove I >= 0 we use
11827 // isKnownPredicate. isKnownPredicate is more powerful, but also more
11828 // expensive; and using isKnownNonNegative(RHS) is sufficient for most of the
11829 // interesting cases seen in practice. We can consider "upgrading" L >= 0 to
11830 // use isKnownPredicate later if needed.
11831 return isKnownNonNegative(RHS) &&
11834}
11835
// Scans block BB for a guard whose condition implies "LHS Pred RHS".
// Bails out early when the module is known to contain no guards.
// NOTE(review): rendering artifact — original line 11846 (the start of the
// match expression inside the lambda, presumably matching the guard
// intrinsic — TODO confirm) is missing from this dump.
11836bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, CmpPredicate Pred,
11837 const SCEV *LHS, const SCEV *RHS) {
11838 // No need to even try if we know the module has no guards.
11839 if (!HasGuards)
11840 return false;
11841
11842 return any_of(*BB, [&](const Instruction &I) {
11843 using namespace llvm::PatternMatch;
11844
11845 Value *Condition;
11847 m_Value(Condition))) &&
11848 isImpliedCond(Pred, LHS, RHS, Condition, false);
11849 });
11850}
11851
11852/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
11853/// protected by a conditional between LHS and RHS. This is used to
11854/// to eliminate casts.
// NOTE(review): rendering artifact — the signature line naming this member
// function (original line 11855) is missing from this dump.
11856 CmpPredicate Pred,
11857 const SCEV *LHS,
11858 const SCEV *RHS) {
11859 // Interpret a null as meaning no loop, where there is obviously no guard
11860 // (interprocedural conditions notwithstanding). Do not bother about
11861 // unreachable loops.
11862 if (!L || !DT.isReachableFromEntry(L->getHeader()))
11863 return true;
11864
11865 if (VerifyIR)
11866 assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
11867 "This cannot be done on broken IR!");
11868
11869
// First try predicate reasoning that needs no control-flow walk at all.
11870 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
11871 return true;
11872
11873 BasicBlock *Latch = L->getLoopLatch();
11874 if (!Latch)
11875 return false;
11876
11877 CondBrInst *LoopContinuePredicate =
// NOTE(review): rendering gap — original line 11878 (the initializer,
// presumably a dyn_cast of the latch terminator — TODO confirm) is missing
// from this dump.
11879 if (LoopContinuePredicate &&
11880 isImpliedCond(Pred, LHS, RHS, LoopContinuePredicate->getCondition(),
11881 LoopContinuePredicate->getSuccessor(0) != L->getHeader()))
11882 return true;
11883
11884 // We don't want more than one activation of the following loops on the stack
11885 // -- that can lead to O(n!) time complexity.
11886 if (WalkingBEDominatingConds)
11887 return false;
11888
11889 SaveAndRestore ClearOnExit(WalkingBEDominatingConds, true);
11890
11891 // See if we can exploit a trip count to prove the predicate.
11892 const auto &BETakenInfo = getBackedgeTakenInfo(L);
11893 const SCEV *LatchBECount = BETakenInfo.getExact(Latch, this);
11894 if (LatchBECount != getCouldNotCompute()) {
11895 // We know that Latch branches back to the loop header exactly
11896 // LatchBECount times. This means the backdege condition at Latch is
11897 // equivalent to "{0,+,1} u< LatchBECount".
11898 Type *Ty = LatchBECount->getType();
11899 auto NoWrapFlags = SCEV::NoWrapFlags(SCEV::FlagNUW | SCEV::FlagNW);
11900 const SCEV *LoopCounter =
11901 getAddRecExpr(getZero(Ty), getOne(Ty), L, NoWrapFlags);
11902 if (isImpliedCond(Pred, LHS, RHS, ICmpInst::ICMP_ULT, LoopCounter,
11903 LatchBECount))
11904 return true;
11905 }
11906
11907 // Check conditions due to any @llvm.assume intrinsics.
11908 for (auto &AssumeVH : AC.assumptions()) {
11909 if (!AssumeVH)
11910 continue;
11911 auto *CI = cast<CallInst>(AssumeVH);
// Only assumptions that dominate the latch can guard the backedge.
11912 if (!DT.dominates(CI, Latch->getTerminator()))
11913 continue;
11914
11915 if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
11916 return true;
11917 }
11918
11919 if (isImpliedViaGuard(Latch, Pred, LHS, RHS))
11920 return true;
11921
// Walk the dominator tree upward from the latch to the header, testing the
// guarding condition of every dominating in-loop edge.
11922 for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
11923 DTN != HeaderDTN; DTN = DTN->getIDom()) {
11924 assert(DTN && "should reach the loop header before reaching the root!");
11925
11926 BasicBlock *BB = DTN->getBlock();
11927 if (isImpliedViaGuard(BB, Pred, LHS, RHS))
11928 return true;
11929
11930 BasicBlock *PBB = BB->getSinglePredecessor();
11931 if (!PBB)
11932 continue;
11933
// NOTE(review): rendering gap — original line 11934 (the definition of
// `ContBr`, presumably from PBB's terminator — TODO confirm) is missing
// from this dump.
11935 if (!ContBr || ContBr->getSuccessor(0) == ContBr->getSuccessor(1))
11936 continue;
11937
11938 // If we have an edge `E` within the loop body that dominates the only
11939 // latch, the condition guarding `E` also guards the backedge. This
11940 // reasoning works only for loops with a single latch.
11941 // We're constructively (and conservatively) enumerating edges within the
11942 // loop body that dominate the latch. The dominator tree better agree
11943 // with us on this:
11944 assert(DT.dominates(BasicBlockEdge(PBB, BB), Latch) && "should be!");
11945 if (isImpliedCond(Pred, LHS, RHS, ContBr->getCondition(),
11946 BB != ContBr->getSuccessor(0)))
11947 return true;
11948 }
11949
11950 return false;
11951}
11952
// Proves "LHS Pred RHS" holds on entry to block BB by examining dominating
// branches, @llvm.assume calls, and @llvm.experimental.guard intrinsics,
// optionally splitting a strict predicate into non-strict + non-equality.
// NOTE(review): rendering artifact — the signature line naming this member
// function (original line 11953) is missing from this dump.
11954 CmpPredicate Pred,
11955 const SCEV *LHS,
11956 const SCEV *RHS) {
11957 // Do not bother proving facts for unreachable code.
11958 if (!DT.isReachableFromEntry(BB))
11959 return true;
11960 if (VerifyIR)
11961 assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
11962 "This cannot be done on broken IR!");
11963
11964 // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
11965 // the facts (a >= b && a != b) separately. A typical situation is when the
11966 // non-strict comparison is known from ranges and non-equality is known from
11967 // dominating predicates. If we are proving strict comparison, we always try
11968 // to prove non-equality and non-strict comparison separately.
11969 CmpPredicate NonStrictPredicate = ICmpInst::getNonStrictCmpPredicate(Pred);
11970 const bool ProvingStrictComparison =
11971 Pred != NonStrictPredicate.dropSameSign();
11972 bool ProvedNonStrictComparison = false;
11973 bool ProvedNonEquality = false;
11974
// Accumulates the two halves of the split proof across multiple attempts;
// succeeds once both the non-strict part and the non-equality are proven.
11975 auto SplitAndProve = [&](std::function<bool(CmpPredicate)> Fn) -> bool {
11976 if (!ProvedNonStrictComparison)
11977 ProvedNonStrictComparison = Fn(NonStrictPredicate);
11978 if (!ProvedNonEquality)
11979 ProvedNonEquality = Fn(ICmpInst::ICMP_NE);
11980 if (ProvedNonStrictComparison && ProvedNonEquality)
11981 return true;
11982 return false;
11983 };
11984
11985 if (ProvingStrictComparison) {
11986 auto ProofFn = [&](CmpPredicate P) {
11987 return isKnownViaNonRecursiveReasoning(P, LHS, RHS);
11988 };
11989 if (SplitAndProve(ProofFn))
11990 return true;
11991 }
11992
11993 // Try to prove (Pred, LHS, RHS) using isImpliedCond.
11994 auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
11995 const Instruction *CtxI = &BB->front();
11996 if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, CtxI))
11997 return true;
11998 if (ProvingStrictComparison) {
11999 auto ProofFn = [&](CmpPredicate P) {
12000 return isImpliedCond(P, LHS, RHS, Condition, Inverse, CtxI);
12001 };
12002 if (SplitAndProve(ProofFn))
12003 return true;
12004 }
12005 return false;
12006 };
12007
12008 // Starting at the block's predecessor, climb up the predecessor chain, as long
12009 // as there are predecessors that can be found that have unique successors
12010 // leading to the original block.
12011 const Loop *ContainingLoop = LI.getLoopFor(BB);
12012 const BasicBlock *PredBB;
12013 if (ContainingLoop && ContainingLoop->getHeader() == BB)
12014 PredBB = ContainingLoop->getLoopPredecessor();
12015 else
12016 PredBB = BB->getSinglePredecessor();
12017 for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
12018 Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
12019 const CondBrInst *BlockEntryPredicate =
12020 dyn_cast<CondBrInst>(Pair.first->getTerminator());
12021 if (!BlockEntryPredicate)
12022 continue;
12023
12024 if (ProveViaCond(BlockEntryPredicate->getCondition(),
12025 BlockEntryPredicate->getSuccessor(0) != Pair.second))
12026 return true;
12027 }
12028
12029 // Check conditions due to any @llvm.assume intrinsics.
12030 for (auto &AssumeVH : AC.assumptions()) {
12031 if (!AssumeVH)
12032 continue;
12033 auto *CI = cast<CallInst>(AssumeVH)
12034 if (!DT.dominates(CI, BB))
12035 continue;
12036
12037 if (ProveViaCond(CI->getArgOperand(0), false))
12038 return true;
12039 }
12040
12041 // Check conditions due to any @llvm.experimental.guard intrinsics.
12042 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
12043 F.getParent(), Intrinsic::experimental_guard);
12044 if (GuardDecl)
12045 for (const auto *GU : GuardDecl->users())
12046 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
12047 if (Guard->getFunction() == BB->getParent() && DT.dominates(Guard, BB))
12048 if (ProveViaCond(Guard->getArgOperand(0), false))
12049 return true;
12050 return false;
12051}
12052
// Proves "LHS Pred RHS" at loop entry: tries non-recursive reasoning first,
// then delegates to the block-entry proof at the loop header.
// NOTE(review): rendering artifact — the signature line naming this member
// function (original line 12053) and the two assertion-start lines
// 12062/12064 (availability asserts) are missing from this dump.
12054 const SCEV *LHS,
12055 const SCEV *RHS) {
12056 // Interpret a null as meaning no loop, where there is obviously no guard
12057 // (interprocedural conditions notwithstanding).
12058 if (!L)
12059 return false;
12060
12061 // Both LHS and RHS must be available at loop entry.
12063 "LHS is not available at Loop Entry");
12065 "RHS is not available at Loop Entry");
12066
12067 if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
12068 return true;
12069
12070 return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
12071}
12072
12073bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12074 const SCEV *RHS,
12075 const Value *FoundCondValue, bool Inverse,
12076 const Instruction *CtxI) {
12077 // False conditions implies anything. Do not bother analyzing it further.
12078 if (FoundCondValue ==
12079 ConstantInt::getBool(FoundCondValue->getContext(), Inverse))
12080 return true;
12081
12082 if (!PendingLoopPredicates.insert(FoundCondValue).second)
12083 return false;
12084
12085 llvm::scope_exit ClearOnExit(
12086 [&]() { PendingLoopPredicates.erase(FoundCondValue); });
12087
12088 // Recursively handle And and Or conditions.
12089 const Value *Op0, *Op1;
12090 if (match(FoundCondValue, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) {
12091 if (!Inverse)
12092 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12093 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12094 } else if (match(FoundCondValue, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) {
12095 if (Inverse)
12096 return isImpliedCond(Pred, LHS, RHS, Op0, Inverse, CtxI) ||
12097 isImpliedCond(Pred, LHS, RHS, Op1, Inverse, CtxI);
12098 }
12099
12100 const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
12101 if (!ICI) return false;
12102
12103 // Now that we found a conditional branch that dominates the loop or controls
12104 // the loop latch. Check to see if it is the comparison we are looking for.
12105 CmpPredicate FoundPred;
12106 if (Inverse)
12107 FoundPred = ICI->getInverseCmpPredicate();
12108 else
12109 FoundPred = ICI->getCmpPredicate();
12110
12111 const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
12112 const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
12113
12114 return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, CtxI);
12115}
12116
// Widens or narrows the two predicate pairs to a common bit width (sign- or
// zero-extending per predicate signedness; pointers are never extended),
// then defers to isImpliedCondBalancedTypes.
12117bool ScalarEvolution::isImpliedCond(CmpPredicate Pred, const SCEV *LHS,
12118 const SCEV *RHS, CmpPredicate FoundPred,
12119 const SCEV *FoundLHS, const SCEV *FoundRHS,
12120 const Instruction *CtxI) {
12121 // Balance the types.
12122 if (getTypeSizeInBits(LHS->getType()) <
12123 getTypeSizeInBits(FoundLHS->getType())) {
12124 // For unsigned and equality predicates, try to prove that both found
12125 // operands fit into narrow unsigned range. If so, try to prove facts in
12126 // narrow types.
12127 if (!CmpInst::isSigned(FoundPred) && !FoundLHS->getType()->isPointerTy() &&
12128 !FoundRHS->getType()->isPointerTy()) {
12129 auto *NarrowType = LHS->getType();
12130 auto *WideType = FoundLHS->getType();
12131 auto BitWidth = getTypeSizeInBits(NarrowType);
12132 const SCEV *MaxValue = getZeroExtendExpr(
// NOTE(review): rendering gap — original line 12133 (the continuation of
// the getZeroExtendExpr call building the narrow-type max value) is
// missing from this dump.
12134 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundLHS,
12135 MaxValue) &&
12136 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, FoundRHS,
12137 MaxValue)) {
12138 const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
12139 const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
12140 // We cannot preserve samesign after truncation.
12141 if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred.dropSameSign(),
12142 TruncFoundLHS, TruncFoundRHS, CtxI))
12143 return true;
12144 }
12145 }
12146
12147 if (LHS->getType()->isPointerTy() || RHS->getType()->isPointerTy())
12148 return false;
12149 if (CmpInst::isSigned(Pred)) {
12150 LHS = getSignExtendExpr(LHS, FoundLHS->getType());
12151 RHS = getSignExtendExpr(RHS, FoundLHS->getType());
12152 } else {
12153 LHS = getZeroExtendExpr(LHS, FoundLHS->getType());
12154 RHS = getZeroExtendExpr(RHS, FoundLHS->getType());
12155 }
12156 } else if (getTypeSizeInBits(LHS->getType()) >
12157 getTypeSizeInBits(FoundLHS->getType())) {
12158 if (FoundLHS->getType()->isPointerTy() || FoundRHS->getType()->isPointerTy())
12159 return false;
12160 if (CmpInst::isSigned(FoundPred)) {
12161 FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
12162 FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
12163 } else {
12164 FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
12165 FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
12166 }
12167 }
12168 return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
12169 FoundRHS, CtxI);
12170}
12171
// Core implication engine for same-width operands: tries predicate
// matching, operand/predicate swapping, signedness flipping, and range
// sharpening to show that "FoundLHS FoundPred FoundRHS" implies
// "LHS Pred RHS".
12172bool ScalarEvolution::isImpliedCondBalancedTypes(
12173 CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS, CmpPredicate FoundPred,
12174 SCEVUse FoundLHS, SCEVUse FoundRHS, const Instruction *CtxI) {
// NOTE(review): rendering gap — original line 12175 (the start of the
// assert comparing the two operand bit widths) is missing from this dump.
12176 getTypeSizeInBits(FoundLHS->getType()) &&
12177 "Types should be balanced!");
12178 // Canonicalize the query to match the way instcombine will have
12179 // canonicalized the comparison.
12180 if (SimplifyICmpOperands(Pred, LHS, RHS))
12181 if (LHS == RHS)
12182 return CmpInst::isTrueWhenEqual(Pred);
12183 if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
12184 if (FoundLHS == FoundRHS)
12185 return CmpInst::isFalseWhenEqual(FoundPred);
12186
12187 // Check to see if we can make the LHS or RHS match.
12188 if (LHS == FoundRHS || RHS == FoundLHS) {
12189 if (isa<SCEVConstant>(RHS)) {
12190 std::swap(FoundLHS, FoundRHS);
12191 FoundPred = ICmpInst::getSwappedCmpPredicate(FoundPred);
12192 } else {
12193 std::swap(LHS, RHS);
// NOTE(review): rendering gap — original line 12194 (predicate adjustment
// matching the swap above) is missing from this dump.
12195 }
12196 }
12197
12198 // Check whether the found predicate is the same as the desired predicate.
12199 if (auto P = CmpPredicate::getMatching(FoundPred, Pred))
12200 return isImpliedCondOperands(*P, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12201
12202 // Check whether swapping the found predicate makes it the same as the
12203 // desired predicate.
12204 if (auto P = CmpPredicate::getMatching(
12205 ICmpInst::getSwappedCmpPredicate(FoundPred), Pred)) {
12206 // We can write the implication
12207 // 0. LHS Pred RHS <- FoundLHS SwapPred FoundRHS
12208 // using one of the following ways:
12209 // 1. LHS Pred RHS <- FoundRHS Pred FoundLHS
12210 // 2. RHS SwapPred LHS <- FoundLHS SwapPred FoundRHS
12211 // 3. LHS Pred RHS <- ~FoundLHS Pred ~FoundRHS
12212 // 4. ~LHS SwapPred ~RHS <- FoundLHS SwapPred FoundRHS
12213 // Forms 1. and 2. require swapping the operands of one condition. Don't
12214 // do this if it would break canonical constant/addrec ordering.
// NOTE(review): rendering gap — original line 12215 (the condition
// selecting between forms 1 and 2) is missing from this dump.
12216 return isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P), RHS,
12217 LHS, FoundLHS, FoundRHS, CtxI);
12218 if (!isa<SCEVConstant>(FoundRHS) && !isa<SCEVAddRecExpr>(FoundLHS))
12219 return isImpliedCondOperands(*P, LHS, RHS, FoundRHS, FoundLHS, CtxI);
12220
12221 // There's no clear preference between forms 3. and 4., try both. Avoid
12222 // forming getNotSCEV of pointer values as the resulting subtract is
12223 // not legal.
12224 if (!LHS->getType()->isPointerTy() && !RHS->getType()->isPointerTy() &&
12225 isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(*P),
12226 getNotSCEV(LHS), getNotSCEV(RHS), FoundLHS,
12227 FoundRHS, CtxI))
12228 return true;
12229
12230 if (!FoundLHS->getType()->isPointerTy() &&
12231 !FoundRHS->getType()->isPointerTy() &&
12232 isImpliedCondOperands(*P, LHS, RHS, getNotSCEV(FoundLHS),
12233 getNotSCEV(FoundRHS), CtxI))
12234 return true;
12235
12236 return false;
12237 }
12238
12239 auto IsSignFlippedPredicate = [](CmpInst::Predicate P1,
// NOTE(review): rendering gap — original lines 12240 and 12243 (the second
// lambda parameter and the continuation of its return expression) are
// missing from this dump.
12241 assert(P1 != P2 && "Handled earlier!");
12242 return CmpInst::isRelational(P2) &&
12244 };
12245 if (IsSignFlippedPredicate(Pred, FoundPred)) {
12246 // Unsigned comparison is the same as signed comparison when both the
12247 // operands are non-negative or negative.
12248 if (haveSameSign(FoundLHS, FoundRHS))
12249 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI);
12250 // Create local copies that we can freely swap and canonicalize our
12251 // conditions to "le/lt".
12252 CmpPredicate CanonicalPred = Pred, CanonicalFoundPred = FoundPred;
12253 const SCEV *CanonicalLHS = LHS, *CanonicalRHS = RHS,
12254 *CanonicalFoundLHS = FoundLHS, *CanonicalFoundRHS = FoundRHS;
12255 if (ICmpInst::isGT(CanonicalPred) || ICmpInst::isGE(CanonicalPred)) {
12256 CanonicalPred = ICmpInst::getSwappedCmpPredicate(CanonicalPred);
12257 CanonicalFoundPred = ICmpInst::getSwappedCmpPredicate(CanonicalFoundPred);
12258 std::swap(CanonicalLHS, CanonicalRHS);
12259 std::swap(CanonicalFoundLHS, CanonicalFoundRHS);
12260 }
12261 assert((ICmpInst::isLT(CanonicalPred) || ICmpInst::isLE(CanonicalPred)) &&
12262 "Must be!");
12263 assert((ICmpInst::isLT(CanonicalFoundPred) ||
12264 ICmpInst::isLE(CanonicalFoundPred)) &&
12265 "Must be!");
12266 if (ICmpInst::isSigned(CanonicalPred) && isKnownNonNegative(CanonicalRHS))
12267 // Use implication:
12268 // x <u y && y >=s 0 --> x <s y.
12269 // If we can prove the left part, the right part is also proven.
12270 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12271 CanonicalRHS, CanonicalFoundLHS,
12272 CanonicalFoundRHS);
12273 if (ICmpInst::isUnsigned(CanonicalPred) && isKnownNegative(CanonicalRHS))
12274 // Use implication:
12275 // x <s y && y <s 0 --> x <u y.
12276 // If we can prove the left part, the right part is also proven.
12277 return isImpliedCondOperands(CanonicalFoundPred, CanonicalLHS,
12278 CanonicalRHS, CanonicalFoundLHS,
12279 CanonicalFoundRHS);
12280 }
12281
12282 // Check if we can make progress by sharpening ranges.
12283 if (FoundPred == ICmpInst::ICMP_NE &&
12284 (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) {
12285
12286 const SCEVConstant *C = nullptr;
12287 const SCEV *V = nullptr;
12288
12289 if (isa<SCEVConstant>(FoundLHS)) {
12290 C = cast<SCEVConstant>(FoundLHS);
12291 V = FoundRHS;
12292 } else {
12293 C = cast<SCEVConstant>(FoundRHS);
12294 V = FoundLHS;
12295 }
12296
12297 // The guarding predicate tells us that C != V. If the known range
12298 // of V is [C, t), we can sharpen the range to [C + 1, t). The
12299 // range we consider has to correspond to same signedness as the
12300 // predicate we're interested in folding.
12301
12302 APInt Min = ICmpInst::isSigned(Pred) ?
// NOTE(review): rendering gap — original line 12303 (the continuation of
// the signed/unsigned minimum selection) is missing from this dump.
12304
12305 if (Min == C->getAPInt()) {
12306 // Given (V >= Min && V != Min) we conclude V >= (Min + 1).
12307 // This is true even if (Min + 1) wraps around -- in case of
12308 // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)).
12309
12310 APInt SharperMin = Min + 1;
12311
12312 switch (Pred) {
12313 case ICmpInst::ICMP_SGE:
12314 case ICmpInst::ICMP_UGE:
12315 // We know V `Pred` SharperMin. If this implies LHS `Pred`
12316 // RHS, we're done.
12317 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
12318 CtxI))
12319 return true;
12320 [[fallthrough]];
12321
12322 case ICmpInst::ICMP_SGT:
12323 case ICmpInst::ICMP_UGT:
12324 // We know from the range information that (V `Pred` Min ||
12325 // V == Min). We know from the guarding condition that !(V
12326 // == Min). This gives us
12327 //
12328 // V `Pred` Min || V == Min && !(V == Min)
12329 // => V `Pred` Min
12330 //
12331 // If V `Pred` Min implies LHS `Pred` RHS, we're done.
12332
12333 if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), CtxI))
12334 return true;
12335 break;
12336
12337 // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
12338 case ICmpInst::ICMP_SLE:
12339 case ICmpInst::ICMP_ULE:
12340 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12341 LHS, V, getConstant(SharperMin), CtxI))
12342 return true;
12343 [[fallthrough]];
12344
12345 case ICmpInst::ICMP_SLT:
12346 case ICmpInst::ICMP_ULT:
12347 if (isImpliedCondOperands(ICmpInst::getSwappedCmpPredicate(Pred), RHS,
12348 LHS, V, getConstant(Min), CtxI))
12349 return true;
12350 break;
12351
12352 default:
12353 // No change
12354 break;
12355 }
12356 }
12357 }
12358
12359 // Check whether the actual condition is beyond sufficient.
12360 if (FoundPred == ICmpInst::ICMP_EQ)
12361 if (ICmpInst::isTrueWhenEqual(Pred))
12362 if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12363 return true;
12364 if (Pred == ICmpInst::ICMP_NE)
12365 if (!ICmpInst::isTrueWhenEqual(FoundPred))
12366 if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, CtxI))
12367 return true;
12368
12369 if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS))
12370 return true;
12371
12372 // Otherwise assume the worst.
12373 return false;
12374}
12375
12376bool ScalarEvolution::splitBinaryAdd(SCEVUse Expr, SCEVUse &L, SCEVUse &R,
12377 SCEV::NoWrapFlags &Flags) {
12378 if (!match(Expr, m_scev_Add(m_SCEV(L), m_SCEV(R))))
12379 return false;
12380
12381 Flags = cast<SCEVAddExpr>(Expr)->getNoWrapFlags();
12382 return true;
12383}
12384
/// Compute (More - Less) as a constant APInt if the difference can be reduced
/// to a constant via cheap rewrites, or std::nullopt otherwise.
std::optional<APInt>
ScalarEvolution::computeConstantDifference(const SCEV *More, const SCEV *Less) {
  // We avoid subtracting expressions here because this function is usually
  // fairly deep in the call stack (i.e. is called many times).

  unsigned BW = getTypeSizeInBits(More->getType());
  // Accumulated constant part of the difference.
  APInt Diff(BW, 0);
  // Common constant multiplier peeled off both sides so far; constant
  // operands found later are scaled by it before folding into Diff.
  APInt DiffMul(BW, 1);
  // Try various simplifications to reduce the difference to a constant. Limit
  // the number of allowed simplifications to keep compile-time low.
  for (unsigned I = 0; I < 8; ++I) {
    if (More == Less)
      return Diff;

    // Reduce addrecs with identical steps to their start value.
    if (isa<SCEVAddRecExpr>(Less) && isa<SCEVAddRecExpr>(More)) {
      const auto *LAR = cast<SCEVAddRecExpr>(Less);
      const auto *MAR = cast<SCEVAddRecExpr>(More);

      if (LAR->getLoop() != MAR->getLoop())
        return std::nullopt;

      // We look at affine expressions only; not for correctness but to keep
      // getStepRecurrence cheap.
      if (!LAR->isAffine() || !MAR->isAffine())
        return std::nullopt;

      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this))
        return std::nullopt;

      Less = LAR->getStart();
      More = MAR->getStart();
      continue;
    }

    // Try to match a common constant multiply.
    auto MatchConstMul =
        [](const SCEV *S) -> std::optional<std::pair<const SCEV *, APInt>> {
      const APInt *C;
      const SCEV *Op;
      if (match(S, m_scev_Mul(m_scev_APInt(C), m_SCEV(Op))))
        return {{Op, *C}};
      return std::nullopt;
    };
    // Strip the multiplier from both sides only when it is identical, and
    // remember it in DiffMul.
    if (auto MatchedMore = MatchConstMul(More)) {
      if (auto MatchedLess = MatchConstMul(Less)) {
        if (MatchedMore->second == MatchedLess->second) {
          More = MatchedMore->first;
          Less = MatchedLess->first;
          DiffMul *= MatchedMore->second;
          continue;
        }
      }
    }

    // Try to cancel out common factors in two add expressions.
    SmallDenseMap<const SCEV *, int> Multiplicity;
    // Fold constant operands directly into Diff (scaled by DiffMul); count
    // every other operand with sign Mul so shared terms cancel to zero.
    auto Add = [&](const SCEV *S, int Mul) {
      if (auto *C = dyn_cast<SCEVConstant>(S)) {
        if (Mul == 1) {
          Diff += C->getAPInt() * DiffMul;
        } else {
          assert(Mul == -1);
          Diff -= C->getAPInt() * DiffMul;
        }
      } else
        Multiplicity[S] += Mul;
    };
    // Flatten a top-level add into its operands; other expressions are
    // treated as a single term.
    auto Decompose = [&](const SCEV *S, int Mul) {
      if (isa<SCEVAddExpr>(S)) {
        for (const SCEV *Op : S->operands())
          Add(Op, Mul);
      } else
        Add(S, Mul);
    };
    Decompose(More, 1);
    Decompose(Less, -1);

    // Check whether all the non-constants cancel out, or reduce to new
    // More/Less values.
    const SCEV *NewMore = nullptr, *NewLess = nullptr;
    for (const auto &[S, Mul] : Multiplicity) {
      if (Mul == 0)
        continue;
      if (Mul == 1) {
        if (NewMore)
          return std::nullopt;
        NewMore = S;
      } else if (Mul == -1) {
        if (NewLess)
          return std::nullopt;
        NewLess = S;
      } else
        return std::nullopt;
    }

    // Values stayed the same, no point in trying further.
    if (NewMore == More || NewLess == Less)
      return std::nullopt;

    More = NewMore;
    Less = NewLess;

    // Reduced to constant.
    if (!More && !Less)
      return Diff;

    // Left with variable on only one side, bail out.
    if (!More || !Less)
      return std::nullopt;
  }

  // Did not reduce to constant.
  return std::nullopt;
}
12500
12501bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
12502 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS,
12503 const SCEV *FoundRHS, const Instruction *CtxI) {
12504 // Try to recognize the following pattern:
12505 //
12506 // FoundRHS = ...
12507 // ...
12508 // loop:
12509 // FoundLHS = {Start,+,W}
12510 // context_bb: // Basic block from the same loop
12511 // known(Pred, FoundLHS, FoundRHS)
12512 //
12513 // If some predicate is known in the context of a loop, it is also known on
12514 // each iteration of this loop, including the first iteration. Therefore, in
12515 // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
12516 // prove the original pred using this fact.
12517 if (!CtxI)
12518 return false;
12519 const BasicBlock *ContextBB = CtxI->getParent();
12520 // Make sure AR varies in the context block.
12521 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
12522 const Loop *L = AR->getLoop();
12523 const auto *Latch = L->getLoopLatch();
12524 // Make sure that context belongs to the loop and executes on 1st iteration
12525 // (if it ever executes at all).
12526 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12527 return false;
12528 if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
12529 return false;
12530 return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
12531 }
12532
12533 if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
12534 const Loop *L = AR->getLoop();
12535 const auto *Latch = L->getLoopLatch();
12536 // Make sure that context belongs to the loop and executes on 1st iteration
12537 // (if it ever executes at all).
12538 if (!L->contains(ContextBB) || !Latch || !DT.dominates(ContextBB, Latch))
12539 return false;
12540 if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
12541 return false;
12542 return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
12543 }
12544
12545 return false;
12546}
12547
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(CmpPredicate Pred,
                                                         const SCEV *LHS,
                                                         const SCEV *RHS,
                                                         const SCEV *FoundLHS,
                                                         const SCEV *FoundRHS) {
  // Only strict less-than predicates are handled.
  if (Pred != CmpInst::ICMP_SLT && Pred != CmpInst::ICMP_ULT)
    return false;

  const auto *AddRecLHS = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRecLHS)
    return false;

  const auto *AddRecFoundLHS = dyn_cast<SCEVAddRecExpr>(FoundLHS);
  if (!AddRecFoundLHS)
    return false;

  // We'd like to let SCEV reason about control dependencies, so we constrain
  // both the inequalities to be about add recurrences on the same loop. This
  // way we can use isLoopEntryGuardedByCond later.

  const Loop *L = AddRecFoundLHS->getLoop();
  if (L != AddRecLHS->getLoop())
    return false;

  // FoundLHS u< FoundRHS u< -C => (FoundLHS + C) u< (FoundRHS + C) ... (1)
  //
  // FoundLHS s< FoundRHS s< INT_MIN - C => (FoundLHS + C) s< (FoundRHS + C)
  //                                                             ... (2)
  //
  // Informal proof for (2), assuming (1) [*]:
  //
  // We'll also assume (A s< B) <=> ((A + INT_MIN) u< (B + INT_MIN)) ... (3)[**]
  //
  // Then
  //
  //   FoundLHS s< FoundRHS s< INT_MIN - C
  //   <=> (FoundLHS + INT_MIN) u< (FoundRHS + INT_MIN) u< -C   [ using (3) ]
  //   <=> (FoundLHS + INT_MIN + C) u< (FoundRHS + INT_MIN + C) [ using (1) ]
  //   <=> (FoundLHS + INT_MIN + C + INT_MIN) s<
  //       (FoundRHS + INT_MIN + C + INT_MIN)                   [ using (3) ]
  //   <=> FoundLHS + C s< FoundRHS + C
  //
  // [*]: (1) can be proved by ruling out overflow.
  //
  // [**]: This can be proved by analyzing all the four possibilities:
  //    (A s< 0, B s< 0), (A s< 0, B s>= 0), (A s>= 0, B s< 0) and
  //    (A s>= 0, B s>= 0).
  //
  // Note:
  // Despite (2), "FoundRHS s< INT_MIN - C" does not mean that "FoundRHS + C"
  // will not sign underflow. For instance, say FoundLHS = (i8 -128), FoundRHS
  // = (i8 -127) and C = (i8 -100). Then INT_MIN - C = (i8 -28), and FoundRHS
  // s< (INT_MIN - C). Lack of sign overflow / underflow in "FoundRHS + C" is
  // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +
  // C)".

  // Require LHS == FoundLHS + C and RHS == FoundRHS + C for the same
  // constant C.
  std::optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS);
  if (!LDiff)
    return false;
  std::optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS);
  if (!RDiff || *LDiff != *RDiff)
    return false;

  // C == 0: LHS/RHS equal FoundLHS/FoundRHS, so the implication is trivial.
  if (LDiff->isMinValue())
    return true;

  // Bound FoundRHS must stay below so that adding C cannot wrap, per rule
  // (1) for unsigned or rule (2) for signed.
  APInt FoundRHSLimit;

  if (Pred == CmpInst::ICMP_ULT) {
    FoundRHSLimit = -(*RDiff);
  } else {
    assert(Pred == CmpInst::ICMP_SLT && "Checked above!");
    FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;
  }

  // Try to prove (1) or (2), as needed.
  return isAvailableAtLoopEntry(FoundRHS, L) &&
         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
                                  getConstant(FoundRHSLimit));
}
12628
bool ScalarEvolution::isImpliedViaMerge(CmpPredicate Pred, const SCEV *LHS,
                                        const SCEV *RHS, const SCEV *FoundLHS,
                                        const SCEV *FoundRHS, unsigned Depth) {
  const PHINode *LPhi = nullptr, *RPhi = nullptr;

  // PendingMerges guards against infinite recursion through mutually
  // referencing Phis; whatever we insert below must be removed on exit.
  llvm::scope_exit ClearOnExit([&]() {
    if (LPhi) {
      bool Erased = PendingMerges.erase(LPhi);
      assert(Erased && "Failed to erase LPhi!");
      (void)Erased;
    }
    if (RPhi) {
      bool Erased = PendingMerges.erase(RPhi);
      assert(Erased && "Failed to erase RPhi!");
      (void)Erased;
    }
  });

  // Find respective Phis and check that they are not being pending.
  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
      if (!PendingMerges.insert(Phi).second)
        return false;
      LPhi = Phi;
    }
  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
      // If we detect a loop of Phi nodes being processed by this method, for
      // example:
      //
      // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
      // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
      //
      // we don't want to deal with a case that complex, so return conservative
      // answer false.
      if (!PendingMerges.insert(Phi).second)
        return false;
      RPhi = Phi;
    }

  // If none of LHS, RHS is a Phi, nothing to do here.
  if (!LPhi && !RPhi)
    return false;

  // If there is a SCEVUnknown Phi we are interested in, make it left.
  if (!LPhi) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    std::swap(LPhi, RPhi);
    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
  }

  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
  const BasicBlock *LBB = LPhi->getParent();
  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);

  // Cheap, non-recursive-through-this-function ways to prove S1 Pred S2
  // under the found condition.
  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
           isImpliedCondOperandsViaRanges(Pred, S1, S2, Pred, FoundLHS, FoundRHS) ||
           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
  };

  if (RPhi && RPhi->getParent() == LBB) {
    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
    // If we compare two Phis from the same block, and for each entry block
    // the predicate is true for incoming values from this block, then the
    // predicate is also true for the Phis.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
      if (!ProvedEasily(L, R))
        return false;
    }
  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
    // Case two: RHS is also a Phi from the same basic block, and it is an
    // AddRec. It means that there is a loop which has both AddRec and Unknown
    // PHIs, for it we can compare incoming values of AddRec from above the loop
    // and latch with their respective incoming values of LPhi.
    // TODO: Generalize to handle loops with many inputs in a header.
    if (LPhi->getNumIncomingValues() != 2) return false;

    auto *RLoop = RAR->getLoop();
    auto *Predecessor = RLoop->getLoopPredecessor();
    assert(Predecessor && "Loop with AddRec with no predecessor?");
    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
    if (!ProvedEasily(L1, RAR->getStart()))
      return false;
    auto *Latch = RLoop->getLoopLatch();
    assert(Latch && "Loop with AddRec with no latch?");
    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
      return false;
  } else {
    // In all other cases go over inputs of LHS and compare each of them to RHS,
    // the predicate is true for (LHS, RHS) if it is true for all such pairs.
    // At this point RHS is either a non-Phi, or it is a Phi from some block
    // different from LBB.
    for (const BasicBlock *IncBB : predecessors(LBB)) {
      // Check that RHS is available in this block.
      if (!dominates(RHS, IncBB))
        return false;
      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
      // Make sure L does not refer to a value from a potentially previous
      // iteration of a loop.
      if (!properlyDominates(L, LBB))
        return false;
      // Addrecs are considered to properly dominate their loop, so are missed
      // by the previous check. Discard any values that have computable
      // evolution in this loop.
      if (auto *Loop = LI.getLoopFor(LBB))
        if (hasComputableLoopEvolution(L, Loop))
          return false;
      if (!ProvedEasily(L, RHS))
        return false;
    }
  }
  return true;
}
12747
bool ScalarEvolution::isImpliedCondOperandsViaShift(CmpPredicate Pred,
                                                    const SCEV *LHS,
                                                    const SCEV *RHS,
                                                    const SCEV *FoundLHS,
                                                    const SCEV *FoundRHS) {
  // We want to imply LHS < RHS from LHS < (RHS >> shiftvalue). First, make
  // sure that we are dealing with same LHS.
  if (RHS == FoundRHS) {
    std::swap(LHS, RHS);
    std::swap(FoundLHS, FoundRHS);
    Pred = ICmpInst::getSwappedCmpPredicate(Pred);
  }
  if (LHS != FoundLHS)
    return false;

  // The lshr pattern can only be recognized on an IR value, so FoundRHS must
  // be a SCEVUnknown wrapping one.
  auto *SUFoundRHS = dyn_cast<SCEVUnknown>(FoundRHS);
  if (!SUFoundRHS)
    return false;

  Value *Shiftee, *ShiftValue;

  using namespace PatternMatch;
  if (match(SUFoundRHS->getValue(),
            m_LShr(m_Value(Shiftee), m_Value(ShiftValue)))) {
    auto *ShifteeS = getSCEV(Shiftee);
    // Prove one of the following:
    // LHS <u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <u RHS
    // LHS <=u (shiftee >> shiftvalue) && shiftee <=u RHS ---> LHS <=u RHS
    // LHS <s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <s RHS
    // LHS <=s (shiftee >> shiftvalue) && shiftee <=s RHS && shiftee >=s 0
    //   ---> LHS <=s RHS
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      return isKnownPredicate(ICmpInst::ICMP_ULE, ShifteeS, RHS);
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      if (isKnownNonNegative(ShifteeS))
        return isKnownPredicate(ICmpInst::ICMP_SLE, ShifteeS, RHS);
  }

  return false;
}
12789
12790bool ScalarEvolution::isImpliedCondOperands(CmpPredicate Pred, const SCEV *LHS,
12791 const SCEV *RHS,
12792 const SCEV *FoundLHS,
12793 const SCEV *FoundRHS,
12794 const Instruction *CtxI) {
12795 return isImpliedCondOperandsViaRanges(Pred, LHS, RHS, Pred, FoundLHS,
12796 FoundRHS) ||
12797 isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS,
12798 FoundRHS) ||
12799 isImpliedCondOperandsViaShift(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
12800 isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
12801 CtxI) ||
12802 isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS);
12803}
12804
12805/// Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
12806template <typename MinMaxExprType>
12807static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
12808 const SCEV *Candidate) {
12809 const MinMaxExprType *MinMaxExpr = dyn_cast<MinMaxExprType>(MaybeMinMaxExpr);
12810 if (!MinMaxExpr)
12811 return false;
12812
12813 return is_contained(MinMaxExpr->operands(), Candidate);
12814}
12815
12817 CmpPredicate Pred, const SCEV *LHS,
12818 const SCEV *RHS) {
12819 // If both sides are affine addrecs for the same loop, with equal
12820 // steps, and we know the recurrences don't wrap, then we only
12821 // need to check the predicate on the starting values.
12822
12823 if (!ICmpInst::isRelational(Pred))
12824 return false;
12825
12826 const SCEV *LStart, *RStart, *Step;
12827 const Loop *L;
12828 if (!match(LHS,
12829 m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step), m_Loop(L))) ||
12831 m_SpecificLoop(L))))
12832 return false;
12837 if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
12838 return false;
12839
12840 return SE.isKnownPredicate(Pred, LStart, RStart);
12841}
12842
12843/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
12844/// expression?
12846 const SCEV *LHS, const SCEV *RHS) {
12847 switch (Pred) {
12848 default:
12849 return false;
12850
12851 case ICmpInst::ICMP_SGE:
12852 std::swap(LHS, RHS);
12853 [[fallthrough]];
12854 case ICmpInst::ICMP_SLE:
12855 return
12856 // min(A, ...) <= A
12858 // A <= max(A, ...)
12860
12861 case ICmpInst::ICMP_UGE:
12862 std::swap(LHS, RHS);
12863 [[fallthrough]];
12864 case ICmpInst::ICMP_ULE:
12865 return
12866 // min(A, ...) <= A
12867 // FIXME: what about umin_seq?
12869 // A <= max(A, ...)
12871 }
12872
12873 llvm_unreachable("covered switch fell through?!");
12874}
12875
12876bool ScalarEvolution::isImpliedViaOperations(CmpPredicate Pred, const SCEV *LHS,
12877 const SCEV *RHS,
12878 const SCEV *FoundLHS,
12879 const SCEV *FoundRHS,
12880 unsigned Depth) {
12883 "LHS and RHS have different sizes?");
12884 assert(getTypeSizeInBits(FoundLHS->getType()) ==
12885 getTypeSizeInBits(FoundRHS->getType()) &&
12886 "FoundLHS and FoundRHS have different sizes?");
12887 // We want to avoid hurting the compile time with analysis of too big trees.
12889 return false;
12890
12891 // We only want to work with GT comparison so far.
12892 if (ICmpInst::isLT(Pred)) {
12894 std::swap(LHS, RHS);
12895 std::swap(FoundLHS, FoundRHS);
12896 }
12897
12899
12900 // For unsigned, try to reduce it to corresponding signed comparison.
12901 if (P == ICmpInst::ICMP_UGT)
12902 // We can replace unsigned predicate with its signed counterpart if all
12903 // involved values are non-negative.
12904 // TODO: We could have better support for unsigned.
12905 if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
12906 // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
12907 // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
12908 // use this fact to prove that LHS and RHS are non-negative.
12909 const SCEV *MinusOne = getMinusOne(LHS->getType());
12910 if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
12911 FoundRHS) &&
12912 isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
12913 FoundRHS))
12915 }
12916
12917 if (P != ICmpInst::ICMP_SGT)
12918 return false;
12919
12920 auto GetOpFromSExt = [&](const SCEV *S) -> const SCEV * {
12921 if (auto *Ext = dyn_cast<SCEVSignExtendExpr>(S))
12922 return Ext->getOperand();
12923 // TODO: If S is a SCEVConstant then you can cheaply "strip" the sext off
12924 // the constant in some cases.
12925 return S;
12926 };
12927
12928 // Acquire values from extensions.
12929 auto *OrigLHS = LHS;
12930 auto *OrigFoundLHS = FoundLHS;
12931 LHS = GetOpFromSExt(LHS);
12932 FoundLHS = GetOpFromSExt(FoundLHS);
12933
12934 // Is the SGT predicate can be proved trivially or using the found context.
12935 auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
12936 return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
12937 isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
12938 FoundRHS, Depth + 1);
12939 };
12940
12941 if (auto *LHSAddExpr = dyn_cast<SCEVAddExpr>(LHS)) {
12942 // We want to avoid creation of any new non-constant SCEV. Since we are
12943 // going to compare the operands to RHS, we should be certain that we don't
12944 // need any size extensions for this. So let's decline all cases when the
12945 // sizes of types of LHS and RHS do not match.
12946 // TODO: Maybe try to get RHS from sext to catch more cases?
12948 return false;
12949
12950 // Should not overflow.
12951 if (!LHSAddExpr->hasNoSignedWrap())
12952 return false;
12953
12954 SCEVUse LL = LHSAddExpr->getOperand(0);
12955 SCEVUse LR = LHSAddExpr->getOperand(1);
12956 auto *MinusOne = getMinusOne(RHS->getType());
12957
12958 // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
12959 auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
12960 return IsSGTViaContext(S1, MinusOne) && IsSGTViaContext(S2, RHS);
12961 };
12962 // Try to prove the following rule:
12963 // (LHS = LL + LR) && (LL >= 0) && (LR > RHS) => (LHS > RHS).
12964 // (LHS = LL + LR) && (LR >= 0) && (LL > RHS) => (LHS > RHS).
12965 if (IsSumGreaterThanRHS(LL, LR) || IsSumGreaterThanRHS(LR, LL))
12966 return true;
12967 } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
12968 Value *LL, *LR;
12969 // FIXME: Once we have SDiv implemented, we can get rid of this matching.
12970
12971 using namespace llvm::PatternMatch;
12972
12973 if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
12974 // Rules for division.
12975 // We are going to perform some comparisons with Denominator and its
12976 // derivative expressions. In general case, creating a SCEV for it may
12977 // lead to a complex analysis of the entire graph, and in particular it
12978 // can request trip count recalculation for the same loop. This would
12979 // cache as SCEVCouldNotCompute to avoid the infinite recursion. To avoid
12980 // this, we only want to create SCEVs that are constants in this section.
12981 // So we bail if Denominator is not a constant.
12982 if (!isa<ConstantInt>(LR))
12983 return false;
12984
12985 auto *Denominator = cast<SCEVConstant>(getSCEV(LR));
12986
12987 // We want to make sure that LHS = FoundLHS / Denominator. If it is so,
12988 // then a SCEV for the numerator already exists and matches with FoundLHS.
12989 auto *Numerator = getExistingSCEV(LL);
12990 if (!Numerator || Numerator->getType() != FoundLHS->getType())
12991 return false;
12992
12993 // Make sure that the numerator matches with FoundLHS and the denominator
12994 // is positive.
12995 if (!HasSameValue(Numerator, FoundLHS) || !isKnownPositive(Denominator))
12996 return false;
12997
12998 auto *DTy = Denominator->getType();
12999 auto *FRHSTy = FoundRHS->getType();
13000 if (DTy->isPointerTy() != FRHSTy->isPointerTy())
13001 // One of types is a pointer and another one is not. We cannot extend
13002 // them properly to a wider type, so let us just reject this case.
13003 // TODO: Usage of getEffectiveSCEVType for DTy, FRHSTy etc should help
13004 // to avoid this check.
13005 return false;
13006
13007 // Given that:
13008 // FoundLHS > FoundRHS, LHS = FoundLHS / Denominator, Denominator > 0.
13009 auto *WTy = getWiderType(DTy, FRHSTy);
13010 auto *DenominatorExt = getNoopOrSignExtend(Denominator, WTy);
13011 auto *FoundRHSExt = getNoopOrSignExtend(FoundRHS, WTy);
13012
13013 // Try to prove the following rule:
13014 // (FoundRHS > Denominator - 2) && (RHS <= 0) => (LHS > RHS).
13015 // For example, given that FoundLHS > 2. It means that FoundLHS is at
13016 // least 3. If we divide it by Denominator < 4, we will have at least 1.
13017 auto *DenomMinusTwo = getMinusSCEV(DenominatorExt, getConstant(WTy, 2));
13018 if (isKnownNonPositive(RHS) &&
13019 IsSGTViaContext(FoundRHSExt, DenomMinusTwo))
13020 return true;
13021
13022 // Try to prove the following rule:
13023 // (FoundRHS > -1 - Denominator) && (RHS < 0) => (LHS > RHS).
13024 // For example, given that FoundLHS > -3. Then FoundLHS is at least -2.
13025 // If we divide it by Denominator > 2, then:
13026 // 1. If FoundLHS is negative, then the result is 0.
13027 // 2. If FoundLHS is non-negative, then the result is non-negative.
13028 // Anyways, the result is non-negative.
13029 auto *MinusOne = getMinusOne(WTy);
13030 auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
13031 if (isKnownNegative(RHS) &&
13032 IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
13033 return true;
13034 }
13035 }
13036
13037 // If our expression contained SCEVUnknown Phis, and we split it down and now
13038 // need to prove something for them, try to prove the predicate for every
13039 // possible incoming values of those Phis.
13040 if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
13041 return true;
13042
13043 return false;
13044}
13045
13047 const SCEV *RHS) {
13048 // zext x u<= sext x, sext x s<= zext x
13049 const SCEV *Op;
13050 switch (Pred) {
13051 case ICmpInst::ICMP_SGE:
13052 std::swap(LHS, RHS);
13053 [[fallthrough]];
13054 case ICmpInst::ICMP_SLE: {
13055 // If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
13056 return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
13058 }
13059 case ICmpInst::ICMP_UGE:
13060 std::swap(LHS, RHS);
13061 [[fallthrough]];
13062 case ICmpInst::ICMP_ULE: {
13063 // If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
13064 return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
13066 }
13067 default:
13068 return false;
13069 };
13070 llvm_unreachable("unhandled case");
13071}
13072
13073bool ScalarEvolution::isKnownViaNonRecursiveReasoning(CmpPredicate Pred,
13074 SCEVUse LHS,
13075 SCEVUse RHS) {
13076 return isKnownPredicateExtendIdiom(Pred, LHS, RHS) ||
13077 isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
13078 IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
13079 IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) ||
13080 isKnownPredicateViaNoOverflow(Pred, LHS, RHS);
13081}
13082
13083bool ScalarEvolution::isImpliedCondOperandsHelper(CmpPredicate Pred,
13084 const SCEV *LHS,
13085 const SCEV *RHS,
13086 const SCEV *FoundLHS,
13087 const SCEV *FoundRHS) {
13088 switch (Pred) {
13089 default:
13090 llvm_unreachable("Unexpected CmpPredicate value!");
13091 case ICmpInst::ICMP_EQ:
13092 case ICmpInst::ICMP_NE:
13093 if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
13094 return true;
13095 break;
13096 case ICmpInst::ICMP_SLT:
13097 case ICmpInst::ICMP_SLE:
13098 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
13099 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
13100 return true;
13101 break;
13102 case ICmpInst::ICMP_SGT:
13103 case ICmpInst::ICMP_SGE:
13104 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
13105 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
13106 return true;
13107 break;
13108 case ICmpInst::ICMP_ULT:
13109 case ICmpInst::ICMP_ULE:
13110 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
13111 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
13112 return true;
13113 break;
13114 case ICmpInst::ICMP_UGT:
13115 case ICmpInst::ICMP_UGE:
13116 if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
13117 isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
13118 return true;
13119 break;
13120 }
13121
13122 // Maybe it can be proved via operations?
13123 if (isImpliedViaOperations(Pred, LHS, RHS, FoundLHS, FoundRHS))
13124 return true;
13125
13126 return false;
13127}
13128
13129bool ScalarEvolution::isImpliedCondOperandsViaRanges(
13130 CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, CmpPredicate FoundPred,
13131 const SCEV *FoundLHS, const SCEV *FoundRHS) {
13132 if (!isa<SCEVConstant>(RHS) || !isa<SCEVConstant>(FoundRHS))
13133 // The restriction on `FoundRHS` be lifted easily -- it exists only to
13134 // reduce the compile time impact of this optimization.
13135 return false;
13136
13137 std::optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS);
13138 if (!Addend)
13139 return false;
13140
13141 const APInt &ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt();
13142
13143 // `FoundLHSRange` is the range we know `FoundLHS` to be in by virtue of the
13144 // antecedent "`FoundLHS` `FoundPred` `FoundRHS`".
13145 ConstantRange FoundLHSRange =
13146 ConstantRange::makeExactICmpRegion(FoundPred, ConstFoundRHS);
13147
13148 // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`:
13149 ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));
13150
13151 // We can also compute the range of values for `LHS` that satisfy the
13152 // consequent, "`LHS` `Pred` `RHS`":
13153 const APInt &ConstRHS = cast<SCEVConstant>(RHS)->getAPInt();
13154 // The antecedent implies the consequent if every value of `LHS` that
13155 // satisfies the antecedent also satisfies the consequent.
13156 return LHSRange.icmp(Pred, ConstRHS);
13157}
13158
13159bool ScalarEvolution::canIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,
13160 bool IsSigned) {
13161 assert(isKnownPositive(Stride) && "Positive stride expected!");
13162
13163 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13164 const SCEV *One = getOne(Stride->getType());
13165
13166 if (IsSigned) {
13167 APInt MaxRHS = getSignedRangeMax(RHS);
13168 APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
13169 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13170
13171 // SMaxRHS + SMaxStrideMinusOne > SMaxValue => overflow!
13172 return (std::move(MaxValue) - MaxStrideMinusOne).slt(MaxRHS);
13173 }
13174
13175 APInt MaxRHS = getUnsignedRangeMax(RHS);
13176 APInt MaxValue = APInt::getMaxValue(BitWidth);
13177 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13178
13179 // UMaxRHS + UMaxStrideMinusOne > UMaxValue => overflow!
13180 return (std::move(MaxValue) - MaxStrideMinusOne).ult(MaxRHS);
13181}
13182
13183bool ScalarEvolution::canIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride,
13184 bool IsSigned) {
13185
13186 unsigned BitWidth = getTypeSizeInBits(RHS->getType());
13187 const SCEV *One = getOne(Stride->getType());
13188
13189 if (IsSigned) {
13190 APInt MinRHS = getSignedRangeMin(RHS);
13191 APInt MinValue = APInt::getSignedMinValue(BitWidth);
13192 APInt MaxStrideMinusOne = getSignedRangeMax(getMinusSCEV(Stride, One));
13193
13194 // SMinRHS - SMaxStrideMinusOne < SMinValue => overflow!
13195 return (std::move(MinValue) + MaxStrideMinusOne).sgt(MinRHS);
13196 }
13197
13198 APInt MinRHS = getUnsignedRangeMin(RHS);
13199 APInt MinValue = APInt::getMinValue(BitWidth);
13200 APInt MaxStrideMinusOne = getUnsignedRangeMax(getMinusSCEV(Stride, One));
13201
13202 // UMinRHS - UMaxStrideMinusOne < UMinValue => overflow!
13203 return (std::move(MinValue) + MaxStrideMinusOne).ugt(MinRHS);
13204}
13205
  // NOTE(review): this function's signature line is not visible in this copy;
  // the body uses parameters N and D without visible declarations.  Judging by
  // the callers elsewhere in this file it is presumably
  // ScalarEvolution::getUDivCeilSCEV(const SCEV *N, const SCEV *D) — confirm
  // against upstream.
  //
  // Computes ceil(N / D) as:
  //   umin(N, 1) + floor((N - umin(N, 1)) / D)
  // This is equivalent to "1 + floor((N - 1) / D)" for N != 0. The umin
  // expression fixes the case of N=0.
  const SCEV *MinNOne = getUMinExpr(N, getOne(N->getType()));
  const SCEV *NMinusOne = getMinusSCEV(N, MinNOne);
  return getAddExpr(MinNOne, getUDivExpr(NMinusOne, D));
}
13214
13215const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
13216 const SCEV *Stride,
13217 const SCEV *End,
13218 unsigned BitWidth,
13219 bool IsSigned) {
13220 // The logic in this function assumes we can represent a positive stride.
13221 // If we can't, the backedge-taken count must be zero.
13222 if (IsSigned && BitWidth == 1)
13223 return getZero(Stride->getType());
13224
13225 // This code below only been closely audited for negative strides in the
13226 // unsigned comparison case, it may be correct for signed comparison, but
13227 // that needs to be established.
13228 if (IsSigned && isKnownNegative(Stride))
13229 return getCouldNotCompute();
13230
13231 // Calculate the maximum backedge count based on the range of values
13232 // permitted by Start, End, and Stride.
13233 APInt MinStart =
13234 IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
13235
13236 APInt MinStride =
13237 IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
13238
13239 // We assume either the stride is positive, or the backedge-taken count
13240 // is zero. So force StrideForMaxBECount to be at least one.
13241 APInt One(BitWidth, 1);
13242 APInt StrideForMaxBECount = IsSigned ? APIntOps::smax(One, MinStride)
13243 : APIntOps::umax(One, MinStride);
13244
13245 APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
13246 : APInt::getMaxValue(BitWidth);
13247 APInt Limit = MaxValue - (StrideForMaxBECount - 1);
13248
13249 // Although End can be a MAX expression we estimate MaxEnd considering only
13250 // the case End = RHS of the loop termination condition. This is safe because
13251 // in the other case (End - Start) is zero, leading to a zero maximum backedge
13252 // taken count.
13253 APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
13254 : APIntOps::umin(getUnsignedRangeMax(End), Limit);
13255
13256 // MaxBECount = ceil((max(MaxEnd, MinStart) - MinStart) / Stride)
13257 MaxEnd = IsSigned ? APIntOps::smax(MaxEnd, MinStart)
13258 : APIntOps::umax(MaxEnd, MinStart);
13259
13260 return getUDivCeilSCEV(getConstant(MaxEnd - MinStart) /* Delta */,
13261 getConstant(StrideForMaxBECount) /* Step */);
13262}
13263
ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool IsSigned,
                                  bool ControlsOnlyExit, bool AllowPredicates) {
  // NOTE(review): several lines appear to have been lost from this copy of
  // the function: the declarations of Predicates, IV and Cond (all used
  // below with no visible declaration), the first argument of a
  // getAddRecExpr call and the assignment to IV in the zext block, one
  // clause of the finiteness check, and the pointer-to-integer conversions
  // of Start and RHS.  Restore them from upstream before building.

  bool PredicatedIV = false;
  if (!IV) {
    // LHS was not directly an affine AddRec for L.  If it is a zext of an
    // AddRec in L, try to prove NUW on the inner recurrence so the zext can
    // be folded into an AddRec we can analyze.
    if (auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(LHS)) {
      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(ZExt->getOperand());
      if (AR && AR->getLoop() == L && AR->isAffine()) {
        auto canProveNUW = [&]() {
          // We can use the comparison to infer no-wrap flags only if it fully
          // controls the loop exit.
          if (!ControlsOnlyExit)
            return false;

          if (!isLoopInvariant(RHS, L))
            return false;

          if (!isKnownNonZero(AR->getStepRecurrence(*this)))
            // We need the sequence defined by AR to strictly increase in the
            // unsigned integer domain for the logic below to hold.
            return false;

          const unsigned InnerBitWidth = getTypeSizeInBits(AR->getType());
          const unsigned OuterBitWidth = getTypeSizeInBits(RHS->getType());
          // If RHS <=u Limit, then there must exist a value V in the sequence
          // defined by AR (e.g. {Start,+,Step}) such that V >u RHS, and
          // V <=u UINT_MAX. Thus, we must exit the loop before unsigned
          // overflow occurs. This limit also implies that a signed comparison
          // (in the wide bitwidth) is equivalent to an unsigned comparison as
          // the high bits on both sides must be zero.
          APInt StrideMax = getUnsignedRangeMax(AR->getStepRecurrence(*this));
          APInt Limit = APInt::getMaxValue(InnerBitWidth) - (StrideMax - 1);
          Limit = Limit.zext(OuterBitWidth);
          return getUnsignedRangeMax(applyLoopGuards(RHS, L)).ule(Limit);
        };
        auto Flags = AR->getNoWrapFlags();
        if (!hasFlags(Flags, SCEV::FlagNUW) && canProveNUW())
          Flags = setFlags(Flags, SCEV::FlagNUW);

        setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), Flags);
        if (AR->hasNoUnsignedWrap()) {
          // Emulate what getZeroExtendExpr would have done during construction
          // if we'd been able to infer the fact just above at that time.
          const SCEV *Step = AR->getStepRecurrence(*this);
          Type *Ty = ZExt->getType();
          auto *S = getAddRecExpr(
              // NOTE(review): the first argument of this call (the extended
              // start) and the subsequent assignment to IV are missing in
              // this copy.
              getZeroExtendExpr(Step, Ty, 0), L, AR->getNoWrapFlags());
        }
      }
    }
  }


  if (!IV && AllowPredicates) {
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);
    PredicatedIV = true;
  }

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  // A precondition of this method is that the condition being analyzed
  // reaches an exiting branch which dominates the latch. Given that, we can
  // assume that an increment which violates the nowrap specification and
  // produces poison must cause undefined behavior when the resulting poison
  // value is branched upon and thus we can conclude that the backedge is
  // taken no more often than would be required to produce that poison value.
  // Note that a well defined loop can exit on the iteration which violates
  // the nowrap specification if there is another exit (either explicit or
  // implicit/exceptional) which causes the loop to execute before the
  // exiting instruction we're analyzing would trigger UB.
  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));

  const SCEV *Stride = IV->getStepRecurrence(*this);

  bool PositiveStride = isKnownPositive(Stride);

  // Avoid negative or zero stride values.
  if (!PositiveStride) {
    // We can compute the correct backedge taken count for loops with unknown
    // strides if we can prove that the loop is not an infinite loop with side
    // effects. Here's the loop structure we are trying to handle -
    //
    // i = start
    // do {
    //   A[i] = i;
    //   i += s;
    // } while (i < end);
    //
    // The backedge taken count for such loops is evaluated as -
    // (max(end, start + stride) - start - 1) /u stride
    //
    // The additional preconditions that we need to check to prove correctness
    // of the above formula is as follows -
    //
    // a) IV is either nuw or nsw depending upon signedness (indicated by the
    //    NoWrap flag).
    // b) the loop is guaranteed to be finite (e.g. is mustprogress and has
    //    no side effects within the loop)
    // c) loop has a single static exit (with no abnormal exits)
    //
    // Precondition a) implies that if the stride is negative, this is a single
    // trip loop. The backedge taken count formula reduces to zero in this case.
    //
    // Precondition b) and c) combine to imply that if rhs is invariant in L,
    // then a zero stride means the backedge can't be taken without executing
    // undefined behavior.
    //
    // The positive stride case is the same as isKnownPositive(Stride) returning
    // true (original behavior of the function).
    //
    if (PredicatedIV || !NoWrap || !loopIsFiniteByAssumption(L) ||
        // NOTE(review): the trailing clause of this condition is missing in
        // this copy.
      return getCouldNotCompute();

    if (!isKnownNonZero(Stride)) {
      // If we have a step of zero, and RHS isn't invariant in L, we don't know
      // if it might eventually be greater than start and if so, on which
      // iteration. We can't even produce a useful upper bound.
      if (!isLoopInvariant(RHS, L))
        return getCouldNotCompute();

      // We allow a potentially zero stride, but we need to divide by stride
      // below. Since the loop can't be infinite and this check must control
      // the sole exit, we can infer the exit must be taken on the first
      // iteration (e.g. backedge count = 0) if the stride is zero. Given that,
      // we know the numerator in the divides below must be zero, so we can
      // pick an arbitrary non-zero value for the denominator (e.g. stride)
      // and produce the right result.
      // FIXME: Handle the case where Stride is poison?
      auto wouldZeroStrideBeUB = [&]() {
        // Proof by contradiction. Suppose the stride were zero. If we can
        // prove that the backedge *is* taken on the first iteration, then since
        // we know this condition controls the sole exit, we must have an
        // infinite loop. We can't have a (well defined) infinite loop per
        // check just above.
        // Note: The (Start - Stride) term is used to get the start' term from
        // (start' + stride,+,stride). Remember that we only care about the
        // result of this expression when stride == 0 at runtime.
        auto *StartIfZero = getMinusSCEV(IV->getStart(), Stride);
        return isLoopEntryGuardedByCond(L, Cond, StartIfZero, RHS);
      };
      if (!wouldZeroStrideBeUB()) {
        Stride = getUMaxExpr(Stride, getOne(Stride->getType()));
      }
    }
  } else if (!NoWrap) {
    // Avoid proven overflow cases: this will ensure that the backedge taken
    // count will not generate any unsigned overflow.
    if (canIVOverflowOnLT(RHS, Stride, IsSigned))
      return getCouldNotCompute();
  }

  // On all paths just preceding, we established the following invariant:
  // IV can be assumed not to overflow up to and including the exiting
  // iteration. We proved this in one of two ways:
  // 1) We can show overflow doesn't occur before the exiting iteration
  //    1a) canIVOverflowOnLT, and b) step of one
  // 2) We can show that if overflow occurs, the loop must execute UB
  //    before any possible exit.
  // Note that we have not yet proved RHS invariant (in general).

  const SCEV *Start = IV->getStart();

  // Preserve pointer-typed Start/RHS to pass to isLoopEntryGuardedByCond.
  // If we convert to integers, isLoopEntryGuardedByCond will miss some cases.
  // Use integer-typed versions for actual computation; we can't subtract
  // pointers in general.
  const SCEV *OrigStart = Start;
  const SCEV *OrigRHS = RHS;
  if (Start->getType()->isPointerTy()) {
    // NOTE(review): the pointer-to-integer conversion of Start is missing in
    // this copy; the check below expects Start to have been reassigned.
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (RHS->getType()->isPointerTy()) {
    // NOTE(review): the pointer-to-integer conversion of RHS and its
    // CouldNotCompute check are missing in this copy.
      return RHS;
  }

  const SCEV *End = nullptr, *BECount = nullptr,
             *BECountIfBackedgeTaken = nullptr;
  if (!isLoopInvariant(RHS, L)) {
    const auto *RHSAddRec = dyn_cast<SCEVAddRecExpr>(RHS);
    if (PositiveStride && RHSAddRec != nullptr && RHSAddRec->getLoop() == L &&
        any(RHSAddRec->getNoWrapFlags())) {
      // The structure of loop we are trying to calculate backedge count of:
      //
      // left = left_start
      // right = right_start
      //
      // while(left < right){
      //   ... do something here ...
      //   left += s1; // stride of left is s1 (s1 > 0)
      //   right += s2; // stride of right is s2 (s2 < 0)
      // }
      //

      const SCEV *RHSStart = RHSAddRec->getStart();
      const SCEV *RHSStride = RHSAddRec->getStepRecurrence(*this);

      // If Stride - RHSStride is positive and does not overflow, we can write
      // backedge count as ->
      //    ceil((End - Start) /u (Stride - RHSStride))
      //    Where, End = max(RHSStart, Start)

      // Check if RHSStride < 0 and Stride - RHSStride will not overflow.
      if (isKnownNegative(RHSStride) &&
          willNotOverflow(Instruction::Sub, /*Signed=*/true, Stride,
                          RHSStride)) {

        const SCEV *Denominator = getMinusSCEV(Stride, RHSStride);
        if (isKnownPositive(Denominator)) {
          End = IsSigned ? getSMaxExpr(RHSStart, Start)
                         : getUMaxExpr(RHSStart, Start);

          // We can do this because End >= Start, as End = max(RHSStart, Start)
          const SCEV *Delta = getMinusSCEV(End, Start);

          BECount = getUDivCeilSCEV(Delta, Denominator);
          BECountIfBackedgeTaken =
              getUDivCeilSCEV(getMinusSCEV(RHSStart, Start), Denominator);
        }
      }
    }
    if (BECount == nullptr) {
      // If we cannot calculate ExactBECount, we can calculate the MaxBECount,
      // given the start, stride and max value for the end bound of the
      // loop (RHS), and the fact that IV does not overflow (which is
      // checked above).
      const SCEV *MaxBECount = computeMaxBECountForLT(
          Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
      return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
                       MaxBECount, false /*MaxOrZero*/, Predicates);
    }
  } else {
    // We use the expression (max(End,Start)-Start)/Stride to describe the
    // backedge count, as if the backedge is taken at least once
    // max(End,Start) is End and so the result is as above, and if not
    // max(End,Start) is Start so we get a backedge count of zero.
    auto *OrigStartMinusStride = getMinusSCEV(OrigStart, Stride);
    assert(isAvailableAtLoopEntry(OrigStartMinusStride, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigStart, L) && "Must be!");
    assert(isAvailableAtLoopEntry(OrigRHS, L) && "Must be!");
    // Can we prove (max(RHS,Start) > Start - Stride?
    if (isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigStart) &&
        isLoopEntryGuardedByCond(L, Cond, OrigStartMinusStride, OrigRHS)) {
      // In this case, we can use a refined formula for computing backedge
      // taken count. The general formula remains:
      //   "End-Start /uceiling Stride" where "End = max(RHS,Start)"
      // We want to use the alternate formula:
      //   "((End - 1) - (Start - Stride)) /u Stride"
      // Let's do a quick case analysis to show these are equivalent under
      // our precondition that max(RHS,Start) > Start - Stride.
      // * For RHS <= Start, the backedge-taken count must be zero.
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((Start - 1) - (Start - Stride)) /u Stride" which simplifies to
      //   "Stride - 1 /u Stride" which is indeed zero for all non-zero values
      //   of Stride. For 0 stride, we've used umax(Stride, 1) above,
      //   reducing this to the stride of 1 case.
      // * For RHS >= Start, the backedge count must be "RHS-Start /uceil
      //   Stride".
      //   "((End - 1) - (Start - Stride)) /u Stride" reduces to
      //   "((RHS - 1) - (Start - Stride)) /u Stride" reassociates to
      //   "((RHS - (Start - Stride) - 1) /u Stride".
      //   Our preconditions trivially imply no overflow in that form.
      const SCEV *MinusOne = getMinusOne(Stride->getType());
      const SCEV *Numerator =
          getMinusSCEV(getAddExpr(RHS, MinusOne), getMinusSCEV(Start, Stride));
      BECount = getUDivExpr(Numerator, Stride);
    }

    if (!BECount) {
      auto canProveRHSGreaterThanEqualStart = [&]() {
        auto CondGE = IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
        const SCEV *GuardedRHS = applyLoopGuards(OrigRHS, L);
        const SCEV *GuardedStart = applyLoopGuards(OrigStart, L);

        if (isLoopEntryGuardedByCond(L, CondGE, OrigRHS, OrigStart) ||
            isKnownPredicate(CondGE, GuardedRHS, GuardedStart))
          return true;

        // (RHS > Start - 1) implies RHS >= Start.
        // * "RHS >= Start" is trivially equivalent to "RHS > Start - 1" if
        //   "Start - 1" doesn't overflow.
        // * For signed comparison, if Start - 1 does overflow, it's equal
        //   to INT_MAX, and "RHS >s INT_MAX" is trivially false.
        // * For unsigned comparison, if Start - 1 does overflow, it's equal
        //   to UINT_MAX, and "RHS >u UINT_MAX" is trivially false.
        //
        // FIXME: Should isLoopEntryGuardedByCond do this for us?
        auto CondGT = IsSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
        auto *StartMinusOne =
            getAddExpr(OrigStart, getMinusOne(OrigStart->getType()));
        return isLoopEntryGuardedByCond(L, CondGT, OrigRHS, StartMinusOne);
      };

      // If we know that RHS >= Start in the context of loop, then we know
      // that max(RHS, Start) = RHS at this point.
      if (canProveRHSGreaterThanEqualStart()) {
        End = RHS;
      } else {
        // If RHS < Start, the backedge will be taken zero times. So in
        // general, we can write the backedge-taken count as:
        //
        //     RHS >= Start ? ceil(RHS - Start) / Stride : 0
        //
        // We convert it to the following to make it more convenient for SCEV:
        //
        //     ceil(max(RHS, Start) - Start) / Stride
        End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);

        // See what would happen if we assume the backedge is taken. This is
        // used to compute MaxBECount.
        BECountIfBackedgeTaken =
            getUDivCeilSCEV(getMinusSCEV(RHS, Start), Stride);
      }

      // At this point, we know:
      //
      // 1. If IsSigned, Start <=s End; otherwise, Start <=u End
      // 2. The index variable doesn't overflow.
      //
      // Therefore, we know N exists such that
      // (Start + Stride * N) >= End, and computing "(Start + Stride * N)"
      // doesn't overflow.
      //
      // Using this information, try to prove whether the addition in
      // "(Start - End) + (Stride - 1)" has unsigned overflow.
      const SCEV *One = getOne(Stride->getType());
      bool MayAddOverflow = [&] {
        if (isKnownToBeAPowerOfTwo(Stride)) {
          // Suppose Stride is a power of two, and Start/End are unsigned
          // integers. Let UMAX be the largest representable unsigned
          // integer.
          //
          // By the preconditions of this function, we know
          // "(Start + Stride * N) >= End", and this doesn't overflow.
          // As a formula:
          //
          //   End <= (Start + Stride * N) <= UMAX
          //
          // Subtracting Start from all the terms:
          //
          //   End - Start <= Stride * N <= UMAX - Start
          //
          // Since Start is unsigned, UMAX - Start <= UMAX. Therefore:
          //
          //   End - Start <= Stride * N <= UMAX
          //
          // Stride * N is a multiple of Stride. Therefore,
          //
          //   End - Start <= Stride * N <= UMAX - (UMAX mod Stride)
          //
          // Since Stride is a power of two, UMAX + 1 is divisible by
          // Stride. Therefore, UMAX mod Stride == Stride - 1. So we can
          // write:
          //
          //   End - Start <= Stride * N <= UMAX - Stride - 1
          //
          // Dropping the middle term:
          //
          //   End - Start <= UMAX - Stride - 1
          //
          // Adding Stride - 1 to both sides:
          //
          //   (End - Start) + (Stride - 1) <= UMAX
          //
          // In other words, the addition doesn't have unsigned overflow.
          //
          // A similar proof works if we treat Start/End as signed values.
          // Just rewrite steps before "End - Start <= Stride * N <= UMAX"
          // to use signed max instead of unsigned max. Note that we're
          // trying to prove a lack of unsigned overflow in either case.
          return false;
        }
        if (Start == Stride || Start == getMinusSCEV(Stride, One)) {
          // If Start is equal to Stride, (End - Start) + (Stride - 1) == End
          // - 1. If !IsSigned, 0 <u Stride == Start <=u End; so 0 <u End - 1
          // <u End. If IsSigned, 0 <s Stride == Start <=s End; so 0 <s End -
          // 1 <s End.
          //
          // If Start is equal to Stride - 1, (End - Start) + Stride - 1 ==
          // End.
          return false;
        }
        return true;
      }();

      const SCEV *Delta = getMinusSCEV(End, Start);
      if (!MayAddOverflow) {
        // floor((D + (S - 1)) / S)
        // We prefer this formulation if it's legal because it's fewer
        // operations.
        BECount =
            getUDivExpr(getAddExpr(Delta, getMinusSCEV(Stride, One)), Stride);
      } else {
        BECount = getUDivCeilSCEV(Delta, Stride);
      }
    }
  }

  const SCEV *ConstantMaxBECount;
  bool MaxOrZero = false;
  if (isa<SCEVConstant>(BECount)) {
    ConstantMaxBECount = BECount;
  } else if (BECountIfBackedgeTaken &&
             isa<SCEVConstant>(BECountIfBackedgeTaken)) {
    // If we know exactly how many times the backedge will be taken if it's
    // taken at least once, then the backedge count will either be that or
    // zero.
    ConstantMaxBECount = BECountIfBackedgeTaken;
    MaxOrZero = true;
  } else {
    ConstantMaxBECount = computeMaxBECountForLT(
        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
  }

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount) &&
      !isa<SCEVCouldNotCompute>(BECount))
    ConstantMaxBECount = getConstant(getUnsignedRangeMax(BECount));

  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;
  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, MaxOrZero,
                   Predicates);
}
13704
ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(
    const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned,
    bool ControlsOnlyExit, bool AllowPredicates) {
  // NOTE(review): the declarations of Predicates and Cond (used below with
  // no visible declaration) appear to have been lost from this copy; restore
  // them from upstream before building.
  // We handle only IV > Invariant
  if (!isLoopInvariant(RHS, L))
    return getCouldNotCompute();

  const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!IV && AllowPredicates)
    // Try to make this an AddRec using runtime tests, in the first X
    // iterations of this loop, where X is the SCEV expression found by the
    // algorithm below.
    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);

  // Avoid weird loops
  if (!IV || IV->getLoop() != L || !IV->isAffine())
    return getCouldNotCompute();

  auto WrapType = IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW;
  bool NoWrap = ControlsOnlyExit && any(IV->getNoWrapFlags(WrapType));

  // The IV counts down; negate the step so the reasoning below can work
  // with a positive stride.
  const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this));

  // Avoid negative or zero stride values
  if (!isKnownPositive(Stride))
    return getCouldNotCompute();

  // Avoid proven overflow cases: this will ensure that the backedge taken count
  // will not generate any unsigned overflow. Relaxed no-overflow conditions
  // exploit NoWrapFlags, allowing to optimize in presence of undefined
  // behaviors like the case of C language.
  if (!Stride->isOne() && !NoWrap)
    if (canIVOverflowOnGT(RHS, Stride, IsSigned))
      return getCouldNotCompute();

  const SCEV *Start = IV->getStart();
  const SCEV *End = RHS;
  if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
    // If we know that Start >= RHS in the context of loop, then we know that
    // min(RHS, Start) = RHS at this point.
    // NOTE(review): the opening "if (isLoopEntryGuardedByCond(" of the
    // condition below is missing in this copy.
        L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
      End = RHS;
    else
      End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
  }

  if (Start->getType()->isPointerTy()) {
    // NOTE(review): the pointer-to-integer conversion of Start is missing in
    // this copy; the check below expects Start to have been reassigned.
    if (isa<SCEVCouldNotCompute>(Start))
      return Start;
  }
  if (End->getType()->isPointerTy()) {
    End = getLosslessPtrToIntExpr(End);
    if (isa<SCEVCouldNotCompute>(End))
      return End;
  }

  // Compute ((Start - End) + (Stride - 1)) / Stride.
  // FIXME: This can overflow. Holding off on fixing this for now;
  // howManyGreaterThans will hopefully be gone soon.
  const SCEV *One = getOne(Stride->getType());
  const SCEV *BECount = getUDivExpr(
      getAddExpr(getMinusSCEV(Start, End), getMinusSCEV(Stride, One)), Stride);

  APInt MaxStart = IsSigned ? getSignedRangeMax(Start)
  // NOTE(review): the unsigned arm of this conditional expression is missing
  // in this copy.

  APInt MinStride = IsSigned ? getSignedRangeMin(Stride)
                             : getUnsignedRangeMin(Stride);

  unsigned BitWidth = getTypeSizeInBits(LHS->getType());
  APInt Limit = IsSigned ? APInt::getSignedMinValue(BitWidth) + (MinStride - 1)
                         : APInt::getMinValue(BitWidth) + (MinStride - 1);

  // Although End can be a MIN expression we estimate MinEnd considering only
  // the case End = RHS. This is safe because in the other case (Start - End)
  // is zero, leading to a zero maximum backedge taken count.
  APInt MinEnd =
      IsSigned ? APIntOps::smax(getSignedRangeMin(RHS), Limit)
               : APIntOps::umax(getUnsignedRangeMin(RHS), Limit);

  const SCEV *ConstantMaxBECount =
      isa<SCEVConstant>(BECount)
          ? BECount
          : getUDivCeilSCEV(getConstant(MaxStart - MinEnd),
                            getConstant(MinStride));

  if (isa<SCEVCouldNotCompute>(ConstantMaxBECount))
    ConstantMaxBECount = BECount;
  const SCEV *SymbolicMaxBECount =
      isa<SCEVCouldNotCompute>(BECount) ? ConstantMaxBECount : BECount;

  return ExitLimit(BECount, ConstantMaxBECount, SymbolicMaxBECount, false,
                   Predicates);
}
13803
                                                    ScalarEvolution &SE) const {
  // NOTE(review): the first line of this definition's signature is not
  // visible in this copy; it presumably declares
  // SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ...)
  // given the uses of Range and the AddRec accessors below — confirm
  // against upstream.
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      // NOTE(review): the declaration of Operands (used below without a
      // visible declaration) is missing in this copy.
      Operands[0] = SE.getZero(SC->getType());
      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
      // NOTE(review): the final argument of this call is missing in this
      // copy.
      if (const auto *ShiftedAddRec = dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
            Range.subtract(SC->getAPInt()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  if (any_of(operands(), [](const SCEV *Op) { return !isa<SCEVConstant>(Op); }))
    return SE.getCouldNotCompute();

  // Okay at this point we know that all elements of the chrec are constants and
  // that the start element is zero.

  // First check to see if the range contains zero. If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getZero(getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range === Ax in Range

    // We know that zero is in the range. If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value. Also note that we already checked for a full range.
    APInt A = cast<SCEVConstant>(getOperand(1))->getAPInt();
    APInt End = A.sge(1) ? (Range.getUpper() - 1) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);

    // Evaluate at the exit value. If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute(); // Something strange happened

    // Ensure that the previous value is in the range.
    assert(Range.contains(
        // NOTE(review): a line of this assertion (the call evaluating the
        // chrec at ExitVal - 1) is missing in this copy.
        ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
        "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  }

  if (isQuadratic()) {
    if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
      return SE.getConstant(*S);
  }

  return SE.getCouldNotCompute();
}
13874
const SCEVAddRecExpr *
// NOTE(review): the line naming this definition is missing in this copy;
// given the return type above and the body below it is presumably
// SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const — confirm
// against upstream.
  assert(getNumOperands() > 1 && "AddRec with zero step?");
  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
  // but in this case we cannot guarantee that the value returned will be an
  // AddRec because SCEV does not have a fixed point where it stops
  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
  // may happen if we reach arithmetic depth limit while simplifying. So we
  // construct the returned value explicitly.
  // NOTE(review): the declaration of Ops (used below without a visible
  // declaration) is missing in this copy.
  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
  // (this + Step) is {A+B,+,B+C,+...,+,N}.
  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
  // We know that the last operand is not a constant zero (otherwise it would
  // have been popped out earlier). This guarantees us that if the result has
  // the same last operand, then it will also not be popped out, meaning that
  // the returned value will be an AddRec.
  const SCEV *Last = getOperand(getNumOperands() - 1);
  assert(!Last->isZero() && "Recurrency with zero step?");
  Ops.push_back(Last);
  // NOTE(review): the trailing return statement building the AddRec from Ops
  // is missing in this copy.
}
13899
// Return true when S contains at least an undef value.
// NOTE(review): this function's signature line is missing in this copy; the
// body uses a parameter S without a visible declaration — restore from
// upstream.
  return SCEVExprContains(
      S, [](const SCEV *S) { return match(S, m_scev_UndefOrPoison()); });
}
13905
// Return true when S contains a value that is a nullptr.
// NOTE(review): this function's signature line is missing in this copy; the
// body uses a parameter S without a visible declaration — restore from
// upstream.
  return SCEVExprContains(S, [](const SCEV *S) {
    if (const auto *SU = dyn_cast<SCEVUnknown>(S))
      // A SCEVUnknown whose underlying Value has been nulled out.
      return SU->getValue() == nullptr;
    return false;
  });
}
13914
/// Return the size of an element read or written by Inst.
// NOTE(review): this function's signature line is missing in this copy; the
// body uses a parameter Inst without a visible declaration — restore from
// upstream.
  Type *Ty;
  if (StoreInst *Store = dyn_cast<StoreInst>(Inst))
    Ty = Store->getValueOperand()->getType();
  else if (LoadInst *Load = dyn_cast<LoadInst>(Inst))
    Ty = Load->getType();
  else
    // Only loads and stores have a well-defined accessed element type here.
    return nullptr;

  // NOTE(review): the line computing ETy (used below without a visible
  // declaration) is missing in this copy.
  return getSizeOfExpr(ETy, Ty);
}
13928
13929//===----------------------------------------------------------------------===//
13930// SCEVCallbackVH Class Implementation
13931//===----------------------------------------------------------------------===//
13932
  // NOTE(review): the signature line of this definition (the ValueHandle
  // deletion callback of ScalarEvolution::SCEVCallbackVH) is not visible in
  // this copy — restore from upstream.
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  // Drop the constant-evolution cache entry keyed on the dying PHI, if any.
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  // Remove the value's SCEV mapping before the IR object goes away.
  SE->eraseValueFromMap(getValPtr());
  // this now dangles!
}
13940
// ValueHandle RAUW callback: invoked when every use of the tracked value has
// been replaced with V.
void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.  Note that the replacement value V itself is not needed here;
  // invalidation is keyed on the old value.
  SE->forgetValue(getValPtr());
  // this now dangles!
}
13950
// Bind the callback handle to V and remember the owning ScalarEvolution so
// the deleted()/allUsesReplacedWith() callbacks can reach its caches.
ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
    : CallbackVH(V), SE(se) {}
13953
13954//===----------------------------------------------------------------------===//
13955// ScalarEvolution Class Implementation
13956//===----------------------------------------------------------------------===//
13957
                                 LoopInfo &LI)
    : F(F), DL(F.getDataLayout()), TLI(TLI), AC(AC), DT(DT), LI(LI),
      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
      LoopDispositions(64), BlockDispositions(64) {
  // NOTE(review): the opening line(s) of this constructor's parameter list
  // are not visible in this copy — restore from upstream.
  //
  // To use guards for proving predicates, we need to scan every instruction in
  // relevant basic blocks, and not just terminators. Doing this is a waste of
  // time if the IR does not actually contain any calls to
  // @llvm.experimental.guard, so do a quick check and remember this beforehand.
  //
  // This pessimizes the case where a pass that preserves ScalarEvolution wants
  // to _add_ guards to the module when there weren't any before, and wants
  // ScalarEvolution to optimize based on those guards. For now we prefer to be
  // efficient in lieu of being smart in that rather obscure case.

  auto *GuardDecl = Intrinsic::getDeclarationIfExists(
      F.getParent(), Intrinsic::experimental_guard);
  HasGuards = GuardDecl && !GuardDecl->use_empty();
}
13978
13980 : F(Arg.F), DL(Arg.DL), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC),
13981 DT(Arg.DT), LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
13982 ValueExprMap(std::move(Arg.ValueExprMap)),
13983 PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
13984 PendingMerges(std::move(Arg.PendingMerges)),
13985 ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
13986 BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
13987 PredicatedBackedgeTakenCounts(
13988 std::move(Arg.PredicatedBackedgeTakenCounts)),
13989 BECountUsers(std::move(Arg.BECountUsers)),
13990 ConstantEvolutionLoopExitValue(
13991 std::move(Arg.ConstantEvolutionLoopExitValue)),
13992 ValuesAtScopes(std::move(Arg.ValuesAtScopes)),
13993 ValuesAtScopesUsers(std::move(Arg.ValuesAtScopesUsers)),
13994 LoopDispositions(std::move(Arg.LoopDispositions)),
13995 LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),
13996 BlockDispositions(std::move(Arg.BlockDispositions)),
13997 SCEVUsers(std::move(Arg.SCEVUsers)),
13998 UnsignedRanges(std::move(Arg.UnsignedRanges)),
13999 SignedRanges(std::move(Arg.SignedRanges)),
14000 UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
14001 UniquePreds(std::move(Arg.UniquePreds)),
14002 SCEVAllocator(std::move(Arg.SCEVAllocator)),
14003 LoopUsers(std::move(Arg.LoopUsers)),
14004 PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
14005 FirstUnknown(Arg.FirstUnknown) {
14006 Arg.FirstUnknown = nullptr;
14007}
14008
14010 // Iterate through all the SCEVUnknown instances and call their
14011 // destructors, so that they release their references to their values.
14012 for (SCEVUnknown *U = FirstUnknown; U;) {
14013 SCEVUnknown *Tmp = U;
14014 U = U->Next;
14015 Tmp->~SCEVUnknown();
14016 }
14017 FirstUnknown = nullptr;
14018
14019 ExprValueMap.clear();
14020 ValueExprMap.clear();
14021 HasRecMap.clear();
14022 BackedgeTakenCounts.clear();
14023 PredicatedBackedgeTakenCounts.clear();
14024
14025 assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
14026 assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
14027 assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
14028 assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
14029}
14030
14034
14035/// When printing a top-level SCEV for trip counts, it's helpful to include
14036/// a type for constants which are otherwise hard to disambiguate.
14037static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV* S) {
14038 if (isa<SCEVConstant>(S))
14039 OS << *S->getType() << " ";
14040 OS << *S;
14041}
14042
14044 const Loop *L) {
14045 // Print all inner loops first
14046 for (Loop *I : *L)
14047 PrintLoopInfo(OS, SE, I);
14048
14049 OS << "Loop ";
14050 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14051 OS << ": ";
14052
14053 SmallVector<BasicBlock *, 8> ExitingBlocks;
14054 L->getExitingBlocks(ExitingBlocks);
14055 if (ExitingBlocks.size() != 1)
14056 OS << "<multiple exits> ";
14057
14058 auto *BTC = SE->getBackedgeTakenCount(L);
14059 if (!isa<SCEVCouldNotCompute>(BTC)) {
14060 OS << "backedge-taken count is ";
14061 PrintSCEVWithTypeHint(OS, BTC);
14062 } else
14063 OS << "Unpredictable backedge-taken count.";
14064 OS << "\n";
14065
14066 if (ExitingBlocks.size() > 1)
14067 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14068 OS << " exit count for " << ExitingBlock->getName() << ": ";
14069 const SCEV *EC = SE->getExitCount(L, ExitingBlock);
14070 PrintSCEVWithTypeHint(OS, EC);
14071 if (isa<SCEVCouldNotCompute>(EC)) {
14072 // Retry with predicates.
14074 EC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates);
14075 if (!isa<SCEVCouldNotCompute>(EC)) {
14076 OS << "\n predicated exit count for " << ExitingBlock->getName()
14077 << ": ";
14078 PrintSCEVWithTypeHint(OS, EC);
14079 OS << "\n Predicates:\n";
14080 for (const auto *P : Predicates)
14081 P->print(OS, 4);
14082 }
14083 }
14084 OS << "\n";
14085 }
14086
14087 OS << "Loop ";
14088 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14089 OS << ": ";
14090
14091 auto *ConstantBTC = SE->getConstantMaxBackedgeTakenCount(L);
14092 if (!isa<SCEVCouldNotCompute>(ConstantBTC)) {
14093 OS << "constant max backedge-taken count is ";
14094 PrintSCEVWithTypeHint(OS, ConstantBTC);
14096 OS << ", actual taken count either this or zero.";
14097 } else {
14098 OS << "Unpredictable constant max backedge-taken count. ";
14099 }
14100
14101 OS << "\n"
14102 "Loop ";
14103 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14104 OS << ": ";
14105
14106 auto *SymbolicBTC = SE->getSymbolicMaxBackedgeTakenCount(L);
14107 if (!isa<SCEVCouldNotCompute>(SymbolicBTC)) {
14108 OS << "symbolic max backedge-taken count is ";
14109 PrintSCEVWithTypeHint(OS, SymbolicBTC);
14111 OS << ", actual taken count either this or zero.";
14112 } else {
14113 OS << "Unpredictable symbolic max backedge-taken count. ";
14114 }
14115 OS << "\n";
14116
14117 if (ExitingBlocks.size() > 1)
14118 for (BasicBlock *ExitingBlock : ExitingBlocks) {
14119 OS << " symbolic max exit count for " << ExitingBlock->getName() << ": ";
14120 auto *ExitBTC = SE->getExitCount(L, ExitingBlock,
14122 PrintSCEVWithTypeHint(OS, ExitBTC);
14123 if (isa<SCEVCouldNotCompute>(ExitBTC)) {
14124 // Retry with predicates.
14126 ExitBTC = SE->getPredicatedExitCount(L, ExitingBlock, &Predicates,
14128 if (!isa<SCEVCouldNotCompute>(ExitBTC)) {
14129 OS << "\n predicated symbolic max exit count for "
14130 << ExitingBlock->getName() << ": ";
14131 PrintSCEVWithTypeHint(OS, ExitBTC);
14132 OS << "\n Predicates:\n";
14133 for (const auto *P : Predicates)
14134 P->print(OS, 4);
14135 }
14136 }
14137 OS << "\n";
14138 }
14139
14141 auto *PBT = SE->getPredicatedBackedgeTakenCount(L, Preds);
14142 if (PBT != BTC) {
14143 assert(!Preds.empty() && "Different predicated BTC, but no predicates");
14144 OS << "Loop ";
14145 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14146 OS << ": ";
14147 if (!isa<SCEVCouldNotCompute>(PBT)) {
14148 OS << "Predicated backedge-taken count is ";
14149 PrintSCEVWithTypeHint(OS, PBT);
14150 } else
14151 OS << "Unpredictable predicated backedge-taken count.";
14152 OS << "\n";
14153 OS << " Predicates:\n";
14154 for (const auto *P : Preds)
14155 P->print(OS, 4);
14156 }
14157 Preds.clear();
14158
14159 auto *PredConstantMax =
14161 if (PredConstantMax != ConstantBTC) {
14162 assert(!Preds.empty() &&
14163 "different predicated constant max BTC but no predicates");
14164 OS << "Loop ";
14165 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14166 OS << ": ";
14167 if (!isa<SCEVCouldNotCompute>(PredConstantMax)) {
14168 OS << "Predicated constant max backedge-taken count is ";
14169 PrintSCEVWithTypeHint(OS, PredConstantMax);
14170 } else
14171 OS << "Unpredictable predicated constant max backedge-taken count.";
14172 OS << "\n";
14173 OS << " Predicates:\n";
14174 for (const auto *P : Preds)
14175 P->print(OS, 4);
14176 }
14177 Preds.clear();
14178
14179 auto *PredSymbolicMax =
14181 if (SymbolicBTC != PredSymbolicMax) {
14182 assert(!Preds.empty() &&
14183 "Different predicated symbolic max BTC, but no predicates");
14184 OS << "Loop ";
14185 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14186 OS << ": ";
14187 if (!isa<SCEVCouldNotCompute>(PredSymbolicMax)) {
14188 OS << "Predicated symbolic max backedge-taken count is ";
14189 PrintSCEVWithTypeHint(OS, PredSymbolicMax);
14190 } else
14191 OS << "Unpredictable predicated symbolic max backedge-taken count.";
14192 OS << "\n";
14193 OS << " Predicates:\n";
14194 for (const auto *P : Preds)
14195 P->print(OS, 4);
14196 }
14197
14199 OS << "Loop ";
14200 L->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14201 OS << ": ";
14202 OS << "Trip multiple is " << SE->getSmallConstantTripMultiple(L) << "\n";
14203 }
14204}
14205
14206namespace llvm {
14207// Note: these overloaded operators need to be in the llvm namespace for them
14208// to be resolved correctly. If we put them outside the llvm namespace, the
14209//
14210// OS << ": " << SE.getLoopDisposition(SV, InnerL);
14211//
// code below "breaks" and starts printing raw enum values as opposed to the
14213// string values.
14216 switch (LD) {
14218 OS << "Variant";
14219 break;
14221 OS << "Invariant";
14222 break;
14224 OS << "Computable";
14225 break;
14226 }
14227 return OS;
14228}
14229
14232 switch (BD) {
14234 OS << "DoesNotDominate";
14235 break;
14237 OS << "Dominates";
14238 break;
14240 OS << "ProperlyDominates";
14241 break;
14242 }
14243 return OS;
14244}
14245} // namespace llvm
14246
14248 // ScalarEvolution's implementation of the print method is to print
14249 // out SCEV values of all instructions that are interesting. Doing
14250 // this potentially causes it to create new SCEV objects though,
14251 // which technically conflicts with the const qualifier. This isn't
14252 // observable from outside the class though, so casting away the
14253 // const isn't dangerous.
14254 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14255
14256 if (ClassifyExpressions) {
14257 OS << "Classifying expressions for: ";
14258 F.printAsOperand(OS, /*PrintType=*/false);
14259 OS << "\n";
14260 for (Instruction &I : instructions(F))
14261 if (isSCEVable(I.getType()) && !isa<CmpInst>(I)) {
14262 OS << I << '\n';
14263 OS << " --> ";
14264 const SCEV *SV = SE.getSCEV(&I);
14265 SV->print(OS);
14266 if (!isa<SCEVCouldNotCompute>(SV)) {
14267 OS << " U: ";
14268 SE.getUnsignedRange(SV).print(OS);
14269 OS << " S: ";
14270 SE.getSignedRange(SV).print(OS);
14271 }
14272
14273 const Loop *L = LI.getLoopFor(I.getParent());
14274
14275 const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
14276 if (AtUse != SV) {
14277 OS << " --> ";
14278 AtUse->print(OS);
14279 if (!isa<SCEVCouldNotCompute>(AtUse)) {
14280 OS << " U: ";
14281 SE.getUnsignedRange(AtUse).print(OS);
14282 OS << " S: ";
14283 SE.getSignedRange(AtUse).print(OS);
14284 }
14285 }
14286
14287 if (L) {
14288 OS << "\t\t" "Exits: ";
14289 const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
14290 if (!SE.isLoopInvariant(ExitValue, L)) {
14291 OS << "<<Unknown>>";
14292 } else {
14293 OS << *ExitValue;
14294 }
14295
14296 ListSeparator LS(", ", "\t\tLoopDispositions: { ");
14297 for (const auto *Iter = L; Iter; Iter = Iter->getParentLoop()) {
14298 OS << LS;
14299 Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14300 OS << ": " << SE.getLoopDisposition(SV, Iter);
14301 }
14302
14303 for (const auto *InnerL : depth_first(L)) {
14304 if (InnerL == L)
14305 continue;
14306 OS << LS;
14307 InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
14308 OS << ": " << SE.getLoopDisposition(SV, InnerL);
14309 }
14310
14311 OS << " }";
14312 }
14313
14314 OS << "\n";
14315 }
14316 }
14317
14318 OS << "Determining loop execution counts for: ";
14319 F.printAsOperand(OS, /*PrintType=*/false);
14320 OS << "\n";
14321 for (Loop *I : LI)
14322 PrintLoopInfo(OS, &SE, I);
14323}
14324
14327 auto &Values = LoopDispositions[S];
14328 for (auto &V : Values) {
14329 if (V.getPointer() == L)
14330 return V.getInt();
14331 }
14332 Values.emplace_back(L, LoopVariant);
14333 LoopDisposition D = computeLoopDisposition(S, L);
14334 auto &Values2 = LoopDispositions[S];
14335 for (auto &V : llvm::reverse(Values2)) {
14336 if (V.getPointer() == L) {
14337 V.setInt(D);
14338 break;
14339 }
14340 }
14341 return D;
14342}
14343
14345ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
14346 switch (S->getSCEVType()) {
14347 case scConstant:
14348 case scVScale:
14349 return LoopInvariant;
14350 case scAddRecExpr: {
14351 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14352
14353 // If L is the addrec's loop, it's computable.
14354 if (AR->getLoop() == L)
14355 return LoopComputable;
14356
14357 // Add recurrences are never invariant in the function-body (null loop).
14358 if (!L)
14359 return LoopVariant;
14360
14361 // Everything that is not defined at loop entry is variant.
14362 if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
14363 return LoopVariant;
14364 assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
14365 " dominate the contained loop's header?");
14366
14367 // This recurrence is invariant w.r.t. L if AR's loop contains L.
14368 if (AR->getLoop()->contains(L))
14369 return LoopInvariant;
14370
14371 // This recurrence is variant w.r.t. L if any of its operands
14372 // are variant.
14373 for (SCEVUse Op : AR->operands())
14374 if (!isLoopInvariant(Op, L))
14375 return LoopVariant;
14376
14377 // Otherwise it's loop-invariant.
14378 return LoopInvariant;
14379 }
14380 case scTruncate:
14381 case scZeroExtend:
14382 case scSignExtend:
14383 case scPtrToAddr:
14384 case scPtrToInt:
14385 case scAddExpr:
14386 case scMulExpr:
14387 case scUDivExpr:
14388 case scUMaxExpr:
14389 case scSMaxExpr:
14390 case scUMinExpr:
14391 case scSMinExpr:
14392 case scSequentialUMinExpr: {
14393 bool HasVarying = false;
14394 for (SCEVUse Op : S->operands()) {
14396 if (D == LoopVariant)
14397 return LoopVariant;
14398 if (D == LoopComputable)
14399 HasVarying = true;
14400 }
14401 return HasVarying ? LoopComputable : LoopInvariant;
14402 }
14403 case scUnknown:
14404 // All non-instruction values are loop invariant. All instructions are loop
14405 // invariant if they are not contained in the specified loop.
14406 // Instructions are never considered invariant in the function body
14407 // (null loop) because they are defined within the "loop".
14408 if (auto *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
14409 return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
14410 return LoopInvariant;
14411 case scCouldNotCompute:
14412 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14413 }
14414 llvm_unreachable("Unknown SCEV kind!");
14415}
14416
14418 return getLoopDisposition(S, L) == LoopInvariant;
14419}
14420
14422 return getLoopDisposition(S, L) == LoopComputable;
14423}
14424
14427 auto &Values = BlockDispositions[S];
14428 for (auto &V : Values) {
14429 if (V.getPointer() == BB)
14430 return V.getInt();
14431 }
14432 Values.emplace_back(BB, DoesNotDominateBlock);
14433 BlockDisposition D = computeBlockDisposition(S, BB);
14434 auto &Values2 = BlockDispositions[S];
14435 for (auto &V : llvm::reverse(Values2)) {
14436 if (V.getPointer() == BB) {
14437 V.setInt(D);
14438 break;
14439 }
14440 }
14441 return D;
14442}
14443
14445ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
14446 switch (S->getSCEVType()) {
14447 case scConstant:
14448 case scVScale:
14450 case scAddRecExpr: {
14451 // This uses a "dominates" query instead of "properly dominates" query
14452 // to test for proper dominance too, because the instruction which
14453 // produces the addrec's value is a PHI, and a PHI effectively properly
14454 // dominates its entire containing block.
14455 const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
14456 if (!DT.dominates(AR->getLoop()->getHeader(), BB))
14457 return DoesNotDominateBlock;
14458
14459 // Fall through into SCEVNAryExpr handling.
14460 [[fallthrough]];
14461 }
14462 case scTruncate:
14463 case scZeroExtend:
14464 case scSignExtend:
14465 case scPtrToAddr:
14466 case scPtrToInt:
14467 case scAddExpr:
14468 case scMulExpr:
14469 case scUDivExpr:
14470 case scUMaxExpr:
14471 case scSMaxExpr:
14472 case scUMinExpr:
14473 case scSMinExpr:
14474 case scSequentialUMinExpr: {
14475 bool Proper = true;
14476 for (const SCEV *NAryOp : S->operands()) {
14478 if (D == DoesNotDominateBlock)
14479 return DoesNotDominateBlock;
14480 if (D == DominatesBlock)
14481 Proper = false;
14482 }
14483 return Proper ? ProperlyDominatesBlock : DominatesBlock;
14484 }
14485 case scUnknown:
14486 if (Instruction *I =
14487 dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
14488 if (I->getParent() == BB)
14489 return DominatesBlock;
14490 if (DT.properlyDominates(I->getParent(), BB))
14492 return DoesNotDominateBlock;
14493 }
14495 case scCouldNotCompute:
14496 llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
14497 }
14498 llvm_unreachable("Unknown SCEV kind!");
14499}
14500
/// Return true if the value of \p S is available at the entry of \p BB,
/// i.e. its block disposition is DominatesBlock or ProperlyDominatesBlock.
/// The >= comparison relies on the BlockDisposition enum ordering.
bool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
  return getBlockDisposition(S, BB) >= DominatesBlock;
}
14504
14507}
14508
14509bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
14510 return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
14511}
14512
14513void ScalarEvolution::forgetBackedgeTakenCounts(const Loop *L,
14514 bool Predicated) {
14515 auto &BECounts =
14516 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14517 auto It = BECounts.find(L);
14518 if (It != BECounts.end()) {
14519 for (const ExitNotTakenInfo &ENT : It->second.ExitNotTaken) {
14520 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14521 if (!isa<SCEVConstant>(S)) {
14522 auto UserIt = BECountUsers.find(S);
14523 assert(UserIt != BECountUsers.end());
14524 UserIt->second.erase({L, Predicated});
14525 }
14526 }
14527 }
14528 BECounts.erase(It);
14529 }
14530}
14531
14532void ScalarEvolution::forgetMemoizedResults(ArrayRef<SCEVUse> SCEVs) {
14533 SmallPtrSet<const SCEV *, 8> ToForget(llvm::from_range, SCEVs);
14534 SmallVector<SCEVUse, 8> Worklist(ToForget.begin(), ToForget.end());
14535
14536 while (!Worklist.empty()) {
14537 const SCEV *Curr = Worklist.pop_back_val();
14538 auto Users = SCEVUsers.find(Curr);
14539 if (Users != SCEVUsers.end())
14540 for (const auto *User : Users->second)
14541 if (ToForget.insert(User).second)
14542 Worklist.push_back(User);
14543 }
14544
14545 for (const auto *S : ToForget)
14546 forgetMemoizedResultsImpl(S);
14547
14548 for (auto I = PredicatedSCEVRewrites.begin();
14549 I != PredicatedSCEVRewrites.end();) {
14550 std::pair<const SCEV *, const Loop *> Entry = I->first;
14551 if (ToForget.count(Entry.first))
14552 PredicatedSCEVRewrites.erase(I++);
14553 else
14554 ++I;
14555 }
14556}
14557
/// Remove every memoized fact about the single expression \p S from
/// ScalarEvolution's side caches. Callers are responsible for also
/// forgetting users of S (see forgetMemoizedResults).
void ScalarEvolution::forgetMemoizedResultsImpl(const SCEV *S) {
  // Per-expression analysis caches keyed directly on S.
  LoopDispositions.erase(S);
  BlockDispositions.erase(S);
  UnsignedRanges.erase(S);
  SignedRanges.erase(S);
  HasRecMap.erase(S);
  ConstantMultipleCache.erase(S);

  // Addrecs additionally cache whether wrap-flag strengthening was attempted.
  if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
    UnsignedWrapViaInductionTried.erase(AR);
    SignedWrapViaInductionTried.erase(AR);
  }

  // Keep the Value<->SCEV maps consistent: drop the reverse (expr -> values)
  // entry and every forward (value -> expr) entry that referenced it.
  auto ExprIt = ExprValueMap.find(S);
  if (ExprIt != ExprValueMap.end()) {
    for (Value *V : ExprIt->second) {
      auto ValueIt = ValueExprMap.find_as(V);
      if (ValueIt != ValueExprMap.end())
        ValueExprMap.erase(ValueIt);
    }
    ExprValueMap.erase(ExprIt);
  }

  // Unlink S from the values-at-scope cache in both directions (entries where
  // S is the queried value, and entries where S is the result at some scope).
  auto ScopeIt = ValuesAtScopes.find(S);
  if (ScopeIt != ValuesAtScopes.end()) {
    for (const auto &Pair : ScopeIt->second)
      if (!isa_and_nonnull<SCEVConstant>(Pair.second))
        llvm::erase(ValuesAtScopesUsers[Pair.second],
                    std::make_pair(Pair.first, S));
    ValuesAtScopes.erase(ScopeIt);
  }

  auto ScopeUserIt = ValuesAtScopesUsers.find(S);
  if (ScopeUserIt != ValuesAtScopesUsers.end()) {
    for (const auto &Pair : ScopeUserIt->second)
      llvm::erase(ValuesAtScopes[Pair.second], std::make_pair(Pair.first, S));
    ValuesAtScopesUsers.erase(ScopeUserIt);
  }

  // Backedge-taken counts that referenced S must be recomputed from scratch.
  auto BEUsersIt = BECountUsers.find(S);
  if (BEUsersIt != BECountUsers.end()) {
    // Work on a copy, as forgetBackedgeTakenCounts() will modify the original.
    auto Copy = BEUsersIt->second;
    for (const auto &Pair : Copy)
      forgetBackedgeTakenCounts(Pair.getPointer(), Pair.getInt());
    BECountUsers.erase(BEUsersIt);
  }

  // Drop fold-cache entries whose cached result is S, in both directions.
  auto FoldUser = FoldCacheUser.find(S);
  if (FoldUser != FoldCacheUser.end())
    for (auto &KV : FoldUser->second)
      FoldCache.erase(KV);
  FoldCacheUser.erase(S);
}
14612
14613void
14614ScalarEvolution::getUsedLoops(const SCEV *S,
14615 SmallPtrSetImpl<const Loop *> &LoopsUsed) {
14616 struct FindUsedLoops {
14617 FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
14618 : LoopsUsed(LoopsUsed) {}
14619 SmallPtrSetImpl<const Loop *> &LoopsUsed;
14620 bool follow(const SCEV *S) {
14621 if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
14622 LoopsUsed.insert(AR->getLoop());
14623 return true;
14624 }
14625
14626 bool isDone() const { return false; }
14627 };
14628
14629 FindUsedLoops F(LoopsUsed);
14630 SCEVTraversal<FindUsedLoops>(F).visitAll(S);
14631}
14632
14633void ScalarEvolution::getReachableBlocks(
14636 Worklist.push_back(&F.getEntryBlock());
14637 while (!Worklist.empty()) {
14638 BasicBlock *BB = Worklist.pop_back_val();
14639 if (!Reachable.insert(BB).second)
14640 continue;
14641
14642 Value *Cond;
14643 BasicBlock *TrueBB, *FalseBB;
14644 if (match(BB->getTerminator(), m_Br(m_Value(Cond), m_BasicBlock(TrueBB),
14645 m_BasicBlock(FalseBB)))) {
14646 if (auto *C = dyn_cast<ConstantInt>(Cond)) {
14647 Worklist.push_back(C->isOne() ? TrueBB : FalseBB);
14648 continue;
14649 }
14650
14651 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
14652 const SCEV *L = getSCEV(Cmp->getOperand(0));
14653 const SCEV *R = getSCEV(Cmp->getOperand(1));
14654 if (isKnownPredicateViaConstantRanges(Cmp->getCmpPredicate(), L, R)) {
14655 Worklist.push_back(TrueBB);
14656 continue;
14657 }
14658 if (isKnownPredicateViaConstantRanges(Cmp->getInverseCmpPredicate(), L,
14659 R)) {
14660 Worklist.push_back(FalseBB);
14661 continue;
14662 }
14663 }
14664 }
14665
14666 append_range(Worklist, successors(BB));
14667 }
14668}
14669
14671 ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
14672 ScalarEvolution SE2(F, TLI, AC, DT, LI);
14673
14674 SmallVector<Loop *, 8> LoopStack(LI.begin(), LI.end());
14675
14676 // Map's SCEV expressions from one ScalarEvolution "universe" to another.
14677 struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
14678 SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
14679
14680 const SCEV *visitConstant(const SCEVConstant *Constant) {
14681 return SE.getConstant(Constant->getAPInt());
14682 }
14683
14684 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
14685 return SE.getUnknown(Expr->getValue());
14686 }
14687
14688 const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
14689 return SE.getCouldNotCompute();
14690 }
14691 };
14692
14693 SCEVMapper SCM(SE2);
14694 SmallPtrSet<BasicBlock *, 16> ReachableBlocks;
14695 SE2.getReachableBlocks(ReachableBlocks, F);
14696
14697 auto GetDelta = [&](const SCEV *Old, const SCEV *New) -> const SCEV * {
14698 if (containsUndefs(Old) || containsUndefs(New)) {
14699 // SCEV treats "undef" as an unknown but consistent value (i.e. it does
14700 // not propagate undef aggressively). This means we can (and do) fail
14701 // verification in cases where a transform makes a value go from "undef"
14702 // to "undef+1" (say). The transform is fine, since in both cases the
14703 // result is "undef", but SCEV thinks the value increased by 1.
14704 return nullptr;
14705 }
14706
14707 // Unless VerifySCEVStrict is set, we only compare constant deltas.
14708 const SCEV *Delta = SE2.getMinusSCEV(Old, New);
14709 if (!VerifySCEVStrict && !isa<SCEVConstant>(Delta))
14710 return nullptr;
14711
14712 return Delta;
14713 };
14714
14715 while (!LoopStack.empty()) {
14716 auto *L = LoopStack.pop_back_val();
14717 llvm::append_range(LoopStack, *L);
14718
14719 // Only verify BECounts in reachable loops. For an unreachable loop,
14720 // any BECount is legal.
14721 if (!ReachableBlocks.contains(L->getHeader()))
14722 continue;
14723
14724 // Only verify cached BECounts. Computing new BECounts may change the
14725 // results of subsequent SCEV uses.
14726 auto It = BackedgeTakenCounts.find(L);
14727 if (It == BackedgeTakenCounts.end())
14728 continue;
14729
14730 auto *CurBECount =
14731 SCM.visit(It->second.getExact(L, const_cast<ScalarEvolution *>(this)));
14732 auto *NewBECount = SE2.getBackedgeTakenCount(L);
14733
14734 if (CurBECount == SE2.getCouldNotCompute() ||
14735 NewBECount == SE2.getCouldNotCompute()) {
14736 // NB! This situation is legal, but is very suspicious -- whatever pass
14737 // change the loop to make a trip count go from could not compute to
14738 // computable or vice-versa *should have* invalidated SCEV. However, we
14739 // choose not to assert here (for now) since we don't want false
14740 // positives.
14741 continue;
14742 }
14743
14744 if (SE.getTypeSizeInBits(CurBECount->getType()) >
14745 SE.getTypeSizeInBits(NewBECount->getType()))
14746 NewBECount = SE2.getZeroExtendExpr(NewBECount, CurBECount->getType());
14747 else if (SE.getTypeSizeInBits(CurBECount->getType()) <
14748 SE.getTypeSizeInBits(NewBECount->getType()))
14749 CurBECount = SE2.getZeroExtendExpr(CurBECount, NewBECount->getType());
14750
14751 const SCEV *Delta = GetDelta(CurBECount, NewBECount);
14752 if (Delta && !Delta->isZero()) {
14753 dbgs() << "Trip Count for " << *L << " Changed!\n";
14754 dbgs() << "Old: " << *CurBECount << "\n";
14755 dbgs() << "New: " << *NewBECount << "\n";
14756 dbgs() << "Delta: " << *Delta << "\n";
14757 std::abort();
14758 }
14759 }
14760
14761 // Collect all valid loops currently in LoopInfo.
14762 SmallPtrSet<Loop *, 32> ValidLoops;
14763 SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
14764 while (!Worklist.empty()) {
14765 Loop *L = Worklist.pop_back_val();
14766 if (ValidLoops.insert(L).second)
14767 Worklist.append(L->begin(), L->end());
14768 }
14769 for (const auto &KV : ValueExprMap) {
14770#ifndef NDEBUG
14771 // Check for SCEV expressions referencing invalid/deleted loops.
14772 if (auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second)) {
14773 assert(ValidLoops.contains(AR->getLoop()) &&
14774 "AddRec references invalid loop");
14775 }
14776#endif
14777
14778 // Check that the value is also part of the reverse map.
14779 auto It = ExprValueMap.find(KV.second);
14780 if (It == ExprValueMap.end() || !It->second.contains(KV.first)) {
14781 dbgs() << "Value " << *KV.first
14782 << " is in ValueExprMap but not in ExprValueMap\n";
14783 std::abort();
14784 }
14785
14786 if (auto *I = dyn_cast<Instruction>(&*KV.first)) {
14787 if (!ReachableBlocks.contains(I->getParent()))
14788 continue;
14789 const SCEV *OldSCEV = SCM.visit(KV.second);
14790 const SCEV *NewSCEV = SE2.getSCEV(I);
14791 const SCEV *Delta = GetDelta(OldSCEV, NewSCEV);
14792 if (Delta && !Delta->isZero()) {
14793 dbgs() << "SCEV for value " << *I << " changed!\n"
14794 << "Old: " << *OldSCEV << "\n"
14795 << "New: " << *NewSCEV << "\n"
14796 << "Delta: " << *Delta << "\n";
14797 std::abort();
14798 }
14799 }
14800 }
14801
14802 for (const auto &KV : ExprValueMap) {
14803 for (Value *V : KV.second) {
14804 const SCEV *S = ValueExprMap.lookup(V);
14805 if (!S) {
14806 dbgs() << "Value " << *V
14807 << " is in ExprValueMap but not in ValueExprMap\n";
14808 std::abort();
14809 }
14810 if (S != KV.first) {
14811 dbgs() << "Value " << *V << " mapped to " << *S << " rather than "
14812 << *KV.first << "\n";
14813 std::abort();
14814 }
14815 }
14816 }
14817
14818 // Verify integrity of SCEV users.
14819 for (const auto &S : UniqueSCEVs) {
14820 for (SCEVUse Op : S.operands()) {
14821 // We do not store dependencies of constants.
14822 if (isa<SCEVConstant>(Op))
14823 continue;
14824 auto It = SCEVUsers.find(Op);
14825 if (It != SCEVUsers.end() && It->second.count(&S))
14826 continue;
14827 dbgs() << "Use of operand " << *Op << " by user " << S
14828 << " is not being tracked!\n";
14829 std::abort();
14830 }
14831 }
14832
14833 // Verify integrity of ValuesAtScopes users.
14834 for (const auto &ValueAndVec : ValuesAtScopes) {
14835 const SCEV *Value = ValueAndVec.first;
14836 for (const auto &LoopAndValueAtScope : ValueAndVec.second) {
14837 const Loop *L = LoopAndValueAtScope.first;
14838 const SCEV *ValueAtScope = LoopAndValueAtScope.second;
14839 if (!isa<SCEVConstant>(ValueAtScope)) {
14840 auto It = ValuesAtScopesUsers.find(ValueAtScope);
14841 if (It != ValuesAtScopesUsers.end() &&
14842 is_contained(It->second, std::make_pair(L, Value)))
14843 continue;
14844 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14845 << *ValueAtScope << " missing in ValuesAtScopesUsers\n";
14846 std::abort();
14847 }
14848 }
14849 }
14850
14851 for (const auto &ValueAtScopeAndVec : ValuesAtScopesUsers) {
14852 const SCEV *ValueAtScope = ValueAtScopeAndVec.first;
14853 for (const auto &LoopAndValue : ValueAtScopeAndVec.second) {
14854 const Loop *L = LoopAndValue.first;
14855 const SCEV *Value = LoopAndValue.second;
14857 auto It = ValuesAtScopes.find(Value);
14858 if (It != ValuesAtScopes.end() &&
14859 is_contained(It->second, std::make_pair(L, ValueAtScope)))
14860 continue;
14861 dbgs() << "Value: " << *Value << ", Loop: " << *L << ", ValueAtScope: "
14862 << *ValueAtScope << " missing in ValuesAtScopes\n";
14863 std::abort();
14864 }
14865 }
14866
14867 // Verify integrity of BECountUsers.
14868 auto VerifyBECountUsers = [&](bool Predicated) {
14869 auto &BECounts =
14870 Predicated ? PredicatedBackedgeTakenCounts : BackedgeTakenCounts;
14871 for (const auto &LoopAndBEInfo : BECounts) {
14872 for (const ExitNotTakenInfo &ENT : LoopAndBEInfo.second.ExitNotTaken) {
14873 for (const SCEV *S : {ENT.ExactNotTaken, ENT.SymbolicMaxNotTaken}) {
14874 if (!isa<SCEVConstant>(S)) {
14875 auto UserIt = BECountUsers.find(S);
14876 if (UserIt != BECountUsers.end() &&
14877 UserIt->second.contains({ LoopAndBEInfo.first, Predicated }))
14878 continue;
14879 dbgs() << "Value " << *S << " for loop " << *LoopAndBEInfo.first
14880 << " missing from BECountUsers\n";
14881 std::abort();
14882 }
14883 }
14884 }
14885 }
14886 };
14887 VerifyBECountUsers(/* Predicated */ false);
14888 VerifyBECountUsers(/* Predicated */ true);
14889
  // Verify integrity of loop disposition cache.
14891 for (auto &[S, Values] : LoopDispositions) {
14892 for (auto [Loop, CachedDisposition] : Values) {
14893 const auto RecomputedDisposition = SE2.getLoopDisposition(S, Loop);
14894 if (CachedDisposition != RecomputedDisposition) {
14895 dbgs() << "Cached disposition of " << *S << " for loop " << *Loop
14896 << " is incorrect: cached " << CachedDisposition << ", actual "
14897 << RecomputedDisposition << "\n";
14898 std::abort();
14899 }
14900 }
14901 }
14902
14903 // Verify integrity of the block disposition cache.
14904 for (auto &[S, Values] : BlockDispositions) {
14905 for (auto [BB, CachedDisposition] : Values) {
14906 const auto RecomputedDisposition = SE2.getBlockDisposition(S, BB);
14907 if (CachedDisposition != RecomputedDisposition) {
14908 dbgs() << "Cached disposition of " << *S << " for block %"
14909 << BB->getName() << " is incorrect: cached " << CachedDisposition
14910 << ", actual " << RecomputedDisposition << "\n";
14911 std::abort();
14912 }
14913 }
14914 }
14915
14916 // Verify FoldCache/FoldCacheUser caches.
14917 for (auto [FoldID, Expr] : FoldCache) {
14918 auto I = FoldCacheUser.find(Expr);
14919 if (I == FoldCacheUser.end()) {
14920 dbgs() << "Missing entry in FoldCacheUser for cached expression " << *Expr
14921 << "!\n";
14922 std::abort();
14923 }
14924 if (!is_contained(I->second, FoldID)) {
14925 dbgs() << "Missing FoldID in cached users of " << *Expr << "!\n";
14926 std::abort();
14927 }
14928 }
14929 for (auto [Expr, IDs] : FoldCacheUser) {
14930 for (auto &FoldID : IDs) {
14931 const SCEV *S = FoldCache.lookup(FoldID);
14932 if (!S) {
14933 dbgs() << "Missing entry in FoldCache for expression " << *Expr
14934 << "!\n";
14935 std::abort();
14936 }
14937 if (S != Expr) {
14938 dbgs() << "Entry in FoldCache doesn't match FoldCacheUser: " << *S
14939 << " != " << *Expr << "!\n";
14940 std::abort();
14941 }
14942 }
14943 }
14944
14945 // Verify that ConstantMultipleCache computations are correct. We check that
14946 // cached multiples and recomputed multiples are multiples of each other to
14947 // verify correctness. It is possible that a recomputed multiple is different
14948 // from the cached multiple due to strengthened no wrap flags or changes in
14949 // KnownBits computations.
14950 for (auto [S, Multiple] : ConstantMultipleCache) {
14951 APInt RecomputedMultiple = SE2.getConstantMultiple(S);
14952 if ((Multiple != 0 && RecomputedMultiple != 0 &&
14953 Multiple.urem(RecomputedMultiple) != 0 &&
14954 RecomputedMultiple.urem(Multiple) != 0)) {
14955 dbgs() << "Incorrect cached computation in ConstantMultipleCache for "
14956 << *S << " : Computed " << RecomputedMultiple
14957 << " but cache contains " << Multiple << "!\n";
14958 std::abort();
14959 }
14960 }
14961}
14962
14964 Function &F, const PreservedAnalyses &PA,
14965 FunctionAnalysisManager::Invalidator &Inv) {
14966 // Invalidate the ScalarEvolution object whenever it isn't preserved or one
14967 // of its dependencies is invalidated.
14968 auto PAC = PA.getChecker<ScalarEvolutionAnalysis>();
14969 return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Function>>()) ||
14970 Inv.invalidate<AssumptionAnalysis>(F, PA) ||
14971 Inv.invalidate<DominatorTreeAnalysis>(F, PA) ||
14972 Inv.invalidate<LoopAnalysis>(F, PA);
14973}
14974
14975AnalysisKey ScalarEvolutionAnalysis::Key;
14976
14979 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
14980 auto &AC = AM.getResult<AssumptionAnalysis>(F);
14981 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
14982 auto &LI = AM.getResult<LoopAnalysis>(F);
14983 return ScalarEvolution(F, TLI, AC, DT, LI);
14984}
14985
14991
14994 // For compatibility with opt's -analyze feature under legacy pass manager
14995 // which was not ported to NPM. This keeps tests using
14996 // update_analyze_test_checks.py working.
14997 OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
14998 << F.getName() << "':\n";
15000 return PreservedAnalyses::all();
15001}
15002
15004 "Scalar Evolution Analysis", false, true)
15010 "Scalar Evolution Analysis", false, true)
15011
15013
15015
15017 SE.reset(new ScalarEvolution(
15019 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
15021 getAnalysis<LoopInfoWrapperPass>().getLoopInfo()));
15022 return false;
15023}
15024
15026
15028 SE->print(OS);
15029}
15030
15032 if (!VerifySCEV)
15033 return;
15034
15035 SE->verify();
15036}
15037
15045
15047 const SCEV *RHS) {
15048 return getComparePredicate(ICmpInst::ICMP_EQ, LHS, RHS);
15049}
15050
15051const SCEVPredicate *
15053 const SCEV *LHS, const SCEV *RHS) {
15055 assert(LHS->getType() == RHS->getType() &&
15056 "Type mismatch between LHS and RHS");
15057 // Unique this node based on the arguments
15058 ID.AddInteger(SCEVPredicate::P_Compare);
15059 ID.AddInteger(Pred);
15060 ID.AddPointer(LHS);
15061 ID.AddPointer(RHS);
15062 void *IP = nullptr;
15063 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15064 return S;
15065 SCEVComparePredicate *Eq = new (SCEVAllocator)
15066 SCEVComparePredicate(ID.Intern(SCEVAllocator), Pred, LHS, RHS);
15067 UniquePreds.InsertNode(Eq, IP);
15068 return Eq;
15069}
15070
15072 const SCEVAddRecExpr *AR,
15075 // Unique this node based on the arguments
15076 ID.AddInteger(SCEVPredicate::P_Wrap);
15077 ID.AddPointer(AR);
15078 ID.AddInteger(AddedFlags);
15079 void *IP = nullptr;
15080 if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP))
15081 return S;
15082 auto *OF = new (SCEVAllocator)
15083 SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags);
15084 UniquePreds.InsertNode(OF, IP);
15085 return OF;
15086}
15087
15088namespace {
15089
15090class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
15091public:
15092
15093 /// Rewrites \p S in the context of a loop L and the SCEV predication
15094 /// infrastructure.
15095 ///
15096 /// If \p Pred is non-null, the SCEV expression is rewritten to respect the
15097 /// equivalences present in \p Pred.
15098 ///
15099 /// If \p NewPreds is non-null, rewrite is free to add further predicates to
15100 /// \p NewPreds such that the result will be an AddRecExpr.
15101 static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
15103 const SCEVPredicate *Pred) {
15104 SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);
15105 return Rewriter.visit(S);
15106 }
15107
15108 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
15109 if (Pred) {
15110 if (auto *U = dyn_cast<SCEVUnionPredicate>(Pred)) {
15111 for (const auto *Pred : U->getPredicates())
15112 if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred))
15113 if (IPred->getLHS() == Expr &&
15114 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15115 return IPred->getRHS();
15116 } else if (const auto *IPred = dyn_cast<SCEVComparePredicate>(Pred)) {
15117 if (IPred->getLHS() == Expr &&
15118 IPred->getPredicate() == ICmpInst::ICMP_EQ)
15119 return IPred->getRHS();
15120 }
15121 }
15122 return convertToAddRecWithPreds(Expr);
15123 }
15124
15125 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
15126 const SCEV *Operand = visit(Expr->getOperand());
15127 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15128 if (AR && AR->getLoop() == L && AR->isAffine()) {
15129 // This couldn't be folded because the operand didn't have the nuw
15130 // flag. Add the nusw flag as an assumption that we could make.
15131 const SCEV *Step = AR->getStepRecurrence(SE);
15132 Type *Ty = Expr->getType();
15133 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW))
15134 return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty),
15135 SE.getSignExtendExpr(Step, Ty), L,
15136 AR->getNoWrapFlags());
15137 }
15138 return SE.getZeroExtendExpr(Operand, Expr->getType());
15139 }
15140
15141 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
15142 const SCEV *Operand = visit(Expr->getOperand());
15143 const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand);
15144 if (AR && AR->getLoop() == L && AR->isAffine()) {
15145 // This couldn't be folded because the operand didn't have the nsw
15146 // flag. Add the nssw flag as an assumption that we could make.
15147 const SCEV *Step = AR->getStepRecurrence(SE);
15148 Type *Ty = Expr->getType();
15149 if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW))
15150 return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty),
15151 SE.getSignExtendExpr(Step, Ty), L,
15152 AR->getNoWrapFlags());
15153 }
15154 return SE.getSignExtendExpr(Operand, Expr->getType());
15155 }
15156
15157private:
15158 explicit SCEVPredicateRewriter(
15159 const Loop *L, ScalarEvolution &SE,
15160 SmallVectorImpl<const SCEVPredicate *> *NewPreds,
15161 const SCEVPredicate *Pred)
15162 : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
15163
15164 bool addOverflowAssumption(const SCEVPredicate *P) {
15165 if (!NewPreds) {
15166 // Check if we've already made this assumption.
15167 return Pred && Pred->implies(P, SE);
15168 }
15169 NewPreds->push_back(P);
15170 return true;
15171 }
15172
15173 bool addOverflowAssumption(const SCEVAddRecExpr *AR,
15175 auto *A = SE.getWrapPredicate(AR, AddedFlags);
15176 return addOverflowAssumption(A);
15177 }
15178
15179 // If \p Expr represents a PHINode, we try to see if it can be represented
15180 // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
15181 // to add this predicate as a runtime overflow check, we return the AddRec.
15182 // If \p Expr does not meet these conditions (is not a PHI node, or we
15183 // couldn't create an AddRec for it, or couldn't add the predicate), we just
15184 // return \p Expr.
15185 const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
15186 if (!isa<PHINode>(Expr->getValue()))
15187 return Expr;
15188 std::optional<
15189 std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
15190 PredicatedRewrite = SE.createAddRecFromPHIWithCasts(Expr);
15191 if (!PredicatedRewrite)
15192 return Expr;
15193 for (const auto *P : PredicatedRewrite->second){
15194 // Wrap predicates from outer loops are not supported.
15195 if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
15196 if (L != WP->getExpr()->getLoop())
15197 return Expr;
15198 }
15199 if (!addOverflowAssumption(P))
15200 return Expr;
15201 }
15202 return PredicatedRewrite->first;
15203 }
15204
15205 SmallVectorImpl<const SCEVPredicate *> *NewPreds;
15206 const SCEVPredicate *Pred;
15207 const Loop *L;
15208};
15209
15210} // end anonymous namespace
15211
15212const SCEV *
15214 const SCEVPredicate &Preds) {
15215 return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);
15216}
15217
15219 const SCEV *S, const Loop *L,
15222 S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
15223 auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
15224
15225 if (!AddRec)
15226 return nullptr;
15227
15228 // Check if any of the transformed predicates is known to be false. In that
15229 // case, it doesn't make sense to convert to a predicated AddRec, as the
15230 // versioned loop will never execute.
15231 for (const SCEVPredicate *Pred : TransformPreds) {
15232 auto *WrapPred = dyn_cast<SCEVWrapPredicate>(Pred);
15233 if (!WrapPred || WrapPred->getFlags() != SCEVWrapPredicate::IncrementNSSW)
15234 continue;
15235
15236 const SCEVAddRecExpr *AddRecToCheck = WrapPred->getExpr();
15237 const SCEV *ExitCount = getBackedgeTakenCount(AddRecToCheck->getLoop());
15238 if (isa<SCEVCouldNotCompute>(ExitCount))
15239 continue;
15240
15241 const SCEV *Step = AddRecToCheck->getStepRecurrence(*this);
15242 if (!Step->isOne())
15243 continue;
15244
15245 ExitCount = getTruncateOrSignExtend(ExitCount, Step->getType());
15246 const SCEV *Add = getAddExpr(AddRecToCheck->getStart(), ExitCount);
15247 if (isKnownPredicate(CmpInst::ICMP_SLT, Add, AddRecToCheck->getStart()))
15248 return nullptr;
15249 }
15250
15251 // Since the transformation was successful, we can now transfer the SCEV
15252 // predicates.
15253 Preds.append(TransformPreds.begin(), TransformPreds.end());
15254
15255 return AddRec;
15256}
15257
15258/// SCEV predicates
15262
15264 const ICmpInst::Predicate Pred,
15265 const SCEV *LHS, const SCEV *RHS)
15266 : SCEVPredicate(ID, P_Compare), Pred(Pred), LHS(LHS), RHS(RHS) {
15267 assert(LHS->getType() == RHS->getType() && "LHS and RHS types don't match");
15268 assert(LHS != RHS && "LHS and RHS are the same SCEV");
15269}
15270
15272 ScalarEvolution &SE) const {
15273 const auto *Op = dyn_cast<SCEVComparePredicate>(N);
15274
15275 if (!Op)
15276 return false;
15277
15278 if (Pred != ICmpInst::ICMP_EQ)
15279 return false;
15280
15281 return Op->LHS == LHS && Op->RHS == RHS;
15282}
15283
// A compare predicate encodes an assumption that must be checked at
// runtime; it is never known to hold unconditionally.
bool SCEVComparePredicate::isAlwaysTrue() const { return false; }
15285
15287 if (Pred == ICmpInst::ICMP_EQ)
15288 OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
15289 else
15290 OS.indent(Depth) << "Compare predicate: " << *LHS << " " << Pred << ") "
15291 << *RHS << "\n";
15292
15293}
15294
15296 const SCEVAddRecExpr *AR,
15297 IncrementWrapFlags Flags)
15298 : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
15299
// Returns the add recurrence this wrap predicate constrains.
const SCEVAddRecExpr *SCEVWrapPredicate::getExpr() const { return AR; }
15301
15303 ScalarEvolution &SE) const {
15304 const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
15305 if (!Op || setFlags(Flags, Op->Flags) != Flags)
15306 return false;
15307
15308 if (Op->AR == AR)
15309 return true;
15310
15311 if (Flags != SCEVWrapPredicate::IncrementNSSW &&
15313 return false;
15314
15315 const SCEV *Start = AR->getStart();
15316 const SCEV *OpStart = Op->AR->getStart();
15317 if (Start->getType()->isPointerTy() != OpStart->getType()->isPointerTy())
15318 return false;
15319
15320 // Reject pointers to different address spaces.
15321 if (Start->getType()->isPointerTy() && Start->getType() != OpStart->getType())
15322 return false;
15323
15324 // NUSW/NSSW on a wider-type AddRec does not imply the same on a
15325 // narrower-type AddRec.
15326 if (SE.getTypeSizeInBits(AR->getType()) >
15327 SE.getTypeSizeInBits(Op->AR->getType()))
15328 return false;
15329
15330 const SCEV *Step = AR->getStepRecurrence(SE);
15331 const SCEV *OpStep = Op->AR->getStepRecurrence(SE);
15332 if (!SE.isKnownPositive(Step) || !SE.isKnownPositive(OpStep))
15333 return false;
15334
15335  // If both steps are positive, this predicate implies N if N's start and
15336  // step are ULE/SLE (for NUSW/NSSW respectively) compared to this one's.
15337 Type *WiderTy = SE.getWiderType(Step->getType(), OpStep->getType());
15338 Step = SE.getNoopOrZeroExtend(Step, WiderTy);
15339 OpStep = SE.getNoopOrZeroExtend(OpStep, WiderTy);
15340
15341 bool IsNUW = Flags == SCEVWrapPredicate::IncrementNUSW;
15342 OpStart = IsNUW ? SE.getNoopOrZeroExtend(OpStart, WiderTy)
15343 : SE.getNoopOrSignExtend(OpStart, WiderTy);
15344 Start = IsNUW ? SE.getNoopOrZeroExtend(Start, WiderTy)
15345 : SE.getNoopOrSignExtend(Start, WiderTy);
15347 return SE.isKnownPredicate(Pred, OpStep, Step) &&
15348 SE.isKnownPredicate(Pred, OpStart, Start);
15349}
15350
15352 SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
15353 IncrementWrapFlags IFlags = Flags;
15354
15355 if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
15356 IFlags = clearFlags(IFlags, IncrementNSSW);
15357
15358 return IFlags == IncrementAnyWrap;
15359}
15360
15361void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
15362 OS.indent(Depth) << *getExpr() << " Added Flags: ";
15364 OS << "<nusw>";
15366 OS << "<nssw>";
15367 OS << "\n";
15368}
15369
15372 ScalarEvolution &SE) {
15373 IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
15374 SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
15375
15376 // We can safely transfer the NSW flag as NSSW.
15377 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
15378 ImpliedFlags = IncrementNSSW;
15379
15380 if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
15381 // If the increment is positive, the SCEV NUW flag will also imply the
15382 // WrapPredicate NUSW flag.
15383 if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
15384 if (Step->getValue()->getValue().isNonNegative())
15385 ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
15386 }
15387
15388 return ImpliedFlags;
15389}
15390
15391/// Union predicates don't get cached so create a dummy set ID for it.
15393 ScalarEvolution &SE)
15394 : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {
15395 for (const auto *P : Preds)
15396 add(P, SE);
15397}
15398
15400 return all_of(Preds,
15401 [](const SCEVPredicate *I) { return I->isAlwaysTrue(); });
15402}
15403
15405 ScalarEvolution &SE) const {
15406 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N))
15407 return all_of(Set->Preds, [this, &SE](const SCEVPredicate *I) {
15408 return this->implies(I, SE);
15409 });
15410
15411 return any_of(Preds,
15412 [N, &SE](const SCEVPredicate *I) { return I->implies(N, SE); });
15413}
15414
15416 for (const auto *Pred : Preds)
15417 Pred->print(OS, Depth);
15418}
15419
15420void SCEVUnionPredicate::add(const SCEVPredicate *N, ScalarEvolution &SE) {
15421 if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) {
15422 for (const auto *Pred : Set->Preds)
15423 add(Pred, SE);
15424 return;
15425 }
15426
15427 // Implication checks are quadratic in the number of predicates. Stop doing
15428 // them if there are many predicates, as they should be too expensive to use
15429 // anyway at that point.
15430 bool CheckImplies = Preds.size() < 16;
15431
15432 // Only add predicate if it is not already implied by this union predicate.
15433 if (CheckImplies && implies(N, SE))
15434 return;
15435
15436 // Build a new vector containing the current predicates, except the ones that
15437 // are implied by the new predicate N.
15439 for (auto *P : Preds) {
15440 if (CheckImplies && N->implies(P, SE))
15441 continue;
15442 PrunedPreds.push_back(P);
15443 }
15444 Preds = std::move(PrunedPreds);
15445 Preds.push_back(N);
15446}
15447
15449 Loop &L)
15450 : SE(SE), L(L) {
15452 Preds = std::make_unique<SCEVUnionPredicate>(Empty, SE);
15453}
15454
15457 for (const auto *Op : Ops)
15458 // We do not expect that forgetting cached data for SCEVConstants will ever
15459 // open any prospects for sharpening or introduce any correctness issues,
15460 // so we don't bother storing their dependencies.
15461 if (!isa<SCEVConstant>(Op))
15462 SCEVUsers[Op].insert(User);
15463}
15464
15466 for (const SCEV *Op : Ops)
15467 // We do not expect that forgetting cached data for SCEVConstants will ever
15468 // open any prospects for sharpening or introduce any correctness issues,
15469 // so we don't bother storing their dependencies.
15470 if (!isa<SCEVConstant>(Op))
15471 SCEVUsers[Op].insert(User);
15472}
15473
15475 const SCEV *Expr = SE.getSCEV(V);
15476 return getPredicatedSCEV(Expr);
15477}
15478
15480 RewriteEntry &Entry = RewriteMap[Expr];
15481
15482 // If we already have an entry and the version matches, return it.
15483 if (Entry.second && Generation == Entry.first)
15484 return Entry.second;
15485
15486 // We found an entry but it's stale. Rewrite the stale entry
15487 // according to the current predicate.
15488 if (Entry.second)
15489 Expr = Entry.second;
15490
15491 const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, *Preds);
15492 Entry = {Generation, NewSCEV};
15493
15494 return NewSCEV;
15495}
15496
15498 if (!BackedgeCount) {
15500 BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, Preds);
15501 for (const auto *P : Preds)
15502 addPredicate(*P);
15503 }
15504 return BackedgeCount;
15505}
15506
15508 if (!SymbolicMaxBackedgeCount) {
15510 SymbolicMaxBackedgeCount =
15511 SE.getPredicatedSymbolicMaxBackedgeTakenCount(&L, Preds);
15512 for (const auto *P : Preds)
15513 addPredicate(*P);
15514 }
15515 return SymbolicMaxBackedgeCount;
15516}
15517
15519 if (!SmallConstantMaxTripCount) {
15521 SmallConstantMaxTripCount = SE.getSmallConstantMaxTripCount(&L, &Preds);
15522 for (const auto *P : Preds)
15523 addPredicate(*P);
15524 }
15525 return *SmallConstantMaxTripCount;
15526}
15527
15529 if (Preds->implies(&Pred, SE))
15530 return;
15531
15532 SmallVector<const SCEVPredicate *, 4> NewPreds(Preds->getPredicates());
15533 NewPreds.push_back(&Pred);
15534 Preds = std::make_unique<SCEVUnionPredicate>(NewPreds, SE);
15535 updateGeneration();
15536}
15537
15539 return *Preds;
15540}
15541
15542void PredicatedScalarEvolution::updateGeneration() {
15543 // If the generation number wrapped recompute everything.
15544 if (++Generation == 0) {
15545 for (auto &II : RewriteMap) {
15546 const SCEV *Rewritten = II.second.second;
15547 II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, *Preds)};
15548 }
15549 }
15550}
15551
15554 const SCEV *Expr = getSCEV(V);
15555 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15556
15557 auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE);
15558
15559 // Clear the statically implied flags.
15560 Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags);
15561 addPredicate(*SE.getWrapPredicate(AR, Flags));
15562
15563 auto II = FlagsMap.insert({V, Flags});
15564 if (!II.second)
15565 II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second);
15566}
15567
15570 const SCEV *Expr = getSCEV(V);
15571 const auto *AR = cast<SCEVAddRecExpr>(Expr);
15572
15574 Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE));
15575
15576 auto II = FlagsMap.find(V);
15577
15578 if (II != FlagsMap.end())
15579 Flags = SCEVWrapPredicate::clearFlags(Flags, II->second);
15580
15582}
15583
15585 const SCEV *Expr = this->getSCEV(V);
15587 auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);
15588
15589 if (!New)
15590 return nullptr;
15591
15592 for (const auto *P : NewPreds)
15593 addPredicate(*P);
15594
15595 RewriteMap[SE.getSCEV(V)] = {Generation, New};
15596 return New;
15597}
15598
15601 : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L),
15602 Preds(std::make_unique<SCEVUnionPredicate>(Init.Preds->getPredicates(),
15603 SE)),
15604 Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) {
15605 for (auto I : Init.FlagsMap)
15606 FlagsMap.insert(I);
15607}
15608
15610 // For each block.
15611 for (auto *BB : L.getBlocks())
15612 for (auto &I : *BB) {
15613 if (!SE.isSCEVable(I.getType()))
15614 continue;
15615
15616 auto *Expr = SE.getSCEV(&I);
15617 auto II = RewriteMap.find(Expr);
15618
15619 if (II == RewriteMap.end())
15620 continue;
15621
15622 // Don't print things that are not interesting.
15623 if (II->second.second == Expr)
15624 continue;
15625
15626 OS.indent(Depth) << "[PSE]" << I << ":\n";
15627 OS.indent(Depth + 2) << *Expr << "\n";
15628 OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
15629 }
15630}
15631
15634 BasicBlock *Header = L->getHeader();
15635 BasicBlock *Pred = L->getLoopPredecessor();
15636 LoopGuards Guards(SE);
15637 if (!Pred)
15638 return Guards;
15640 collectFromBlock(SE, Guards, Header, Pred, VisitedBlocks);
15641 return Guards;
15642}
15643
15644void ScalarEvolution::LoopGuards::collectFromPHI(
15648 unsigned Depth) {
15649 if (!SE.isSCEVable(Phi.getType()))
15650 return;
15651
15652 using MinMaxPattern = std::pair<const SCEVConstant *, SCEVTypes>;
15653 auto GetMinMaxConst = [&](unsigned IncomingIdx) -> MinMaxPattern {
15654 const BasicBlock *InBlock = Phi.getIncomingBlock(IncomingIdx);
15655 if (!VisitedBlocks.insert(InBlock).second)
15656 return {nullptr, scCouldNotCompute};
15657
15658 // Avoid analyzing unreachable blocks so that we don't get trapped
15659 // traversing cycles with ill-formed dominance or infinite cycles
15660 if (!SE.DT.isReachableFromEntry(InBlock))
15661 return {nullptr, scCouldNotCompute};
15662
15663 auto [G, Inserted] = IncomingGuards.try_emplace(InBlock, LoopGuards(SE));
15664 if (Inserted)
15665 collectFromBlock(SE, G->second, Phi.getParent(), InBlock, VisitedBlocks,
15666 Depth + 1);
15667 auto &RewriteMap = G->second.RewriteMap;
15668 if (RewriteMap.empty())
15669 return {nullptr, scCouldNotCompute};
15670 auto S = RewriteMap.find(SE.getSCEV(Phi.getIncomingValue(IncomingIdx)));
15671 if (S == RewriteMap.end())
15672 return {nullptr, scCouldNotCompute};
15673 auto *SM = dyn_cast_if_present<SCEVMinMaxExpr>(S->second);
15674 if (!SM)
15675 return {nullptr, scCouldNotCompute};
15676 if (const SCEVConstant *C0 = dyn_cast<SCEVConstant>(SM->getOperand(0)))
15677 return {C0, SM->getSCEVType()};
15678 return {nullptr, scCouldNotCompute};
15679 };
15680 auto MergeMinMaxConst = [](MinMaxPattern P1,
15681 MinMaxPattern P2) -> MinMaxPattern {
15682 auto [C1, T1] = P1;
15683 auto [C2, T2] = P2;
15684 if (!C1 || !C2 || T1 != T2)
15685 return {nullptr, scCouldNotCompute};
15686 switch (T1) {
15687 case scUMaxExpr:
15688 return {C1->getAPInt().ult(C2->getAPInt()) ? C1 : C2, T1};
15689 case scSMaxExpr:
15690 return {C1->getAPInt().slt(C2->getAPInt()) ? C1 : C2, T1};
15691 case scUMinExpr:
15692 return {C1->getAPInt().ugt(C2->getAPInt()) ? C1 : C2, T1};
15693 case scSMinExpr:
15694 return {C1->getAPInt().sgt(C2->getAPInt()) ? C1 : C2, T1};
15695 default:
15696 llvm_unreachable("Trying to merge non-MinMaxExpr SCEVs.");
15697 }
15698 };
15699 auto P = GetMinMaxConst(0);
15700 for (unsigned int In = 1; In < Phi.getNumIncomingValues(); In++) {
15701 if (!P.first)
15702 break;
15703 P = MergeMinMaxConst(P, GetMinMaxConst(In));
15704 }
15705 if (P.first) {
15706 const SCEV *LHS = SE.getSCEV(const_cast<PHINode *>(&Phi));
15707 SmallVector<SCEVUse, 2> Ops({P.first, LHS});
15708 const SCEV *RHS = SE.getMinMaxExpr(P.second, Ops);
15709 Guards.RewriteMap.insert({LHS, RHS});
15710 }
15711}
15712
15713// Return a new SCEV that modifies \p Expr to the closest number divisible by
15714// \p Divisor that is less than or equal to Expr. For now, only constant
15715// Expr is handled.
15717 const APInt &DivisorVal,
15718 ScalarEvolution &SE) {
15719 const APInt *ExprVal;
15720 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15721 DivisorVal.isNonPositive())
15722 return Expr;
15723 APInt Rem = ExprVal->urem(DivisorVal);
15724 // return the SCEV: Expr - Expr % Divisor
15725 return SE.getConstant(*ExprVal - Rem);
15726}
15727
15728// Return a new SCEV that modifies \p Expr to the closest number divides by
15729// \p Divisor and greater or equal than Expr. For now, only handle constant
15730// Expr.
15731static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15732 const APInt &DivisorVal,
15733 ScalarEvolution &SE) {
15734 const APInt *ExprVal;
15735 if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15736 DivisorVal.isNonPositive())
15737 return Expr;
15738 APInt Rem = ExprVal->urem(DivisorVal);
15739 if (Rem.isZero())
15740 return Expr;
15741 // return the SCEV: Expr + Divisor - Expr % Divisor
15742 return SE.getConstant(*ExprVal + DivisorVal - Rem);
15743}
15744
15746 ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15749 // If we have LHS == 0, check if LHS is computing a property of some unknown
15750 // SCEV %v which we can rewrite %v to express explicitly.
15752 return false;
15753 // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15754 // explicitly express that.
15755 const SCEVUnknown *URemLHS = nullptr;
15756 const SCEV *URemRHS = nullptr;
15757 if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15758 return false;
15759
15760 const SCEV *Multiple =
15761 SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15762 DivInfo[URemLHS] = Multiple;
15763 if (auto *C = dyn_cast<SCEVConstant>(URemRHS))
15764 Multiples[URemLHS] = C->getAPInt();
15765 return true;
15766}
15767
15768// Check if the condition is a divisibility guard (A % B == 0).
15769static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15770 ScalarEvolution &SE) {
15771 const SCEV *X, *Y;
15772 return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15773}
15774
15775// Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15776// recursively. This is done by aligning up/down the constant value to the
15777// Divisor.
15778static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15779 APInt Divisor,
15780 ScalarEvolution &SE) {
15781 // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15782 // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15783 // the non-constant operand and in \p LHS the constant operand.
15784 auto IsMinMaxSCEVWithNonNegativeConstant =
15785 [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15786 const SCEV *&RHS) {
15787 if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15788 if (MinMax->getNumOperands() != 2)
15789 return false;
15790 if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15791 if (C->getAPInt().isNegative())
15792 return false;
15793 SCTy = MinMax->getSCEVType();
15794 LHS = MinMax->getOperand(0);
15795 RHS = MinMax->getOperand(1);
15796 return true;
15797 }
15798 }
15799 return false;
15800 };
15801
15802 const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15803 SCEVTypes SCTy;
15804 if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15805 MinMaxRHS))
15806 return MinMaxExpr;
15807 auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15808 assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15809 auto *DivisibleExpr =
15810 IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE)
15811 : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE);
15813 applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15814 return SE.getMinMaxExpr(SCTy, Ops);
15815}
15816
15817void ScalarEvolution::LoopGuards::collectFromBlock(
15818 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
15819 const BasicBlock *Block, const BasicBlock *Pred,
15820 SmallPtrSetImpl<const BasicBlock *> &VisitedBlocks, unsigned Depth) {
15821
15823
15824 SmallVector<SCEVUse> ExprsToRewrite;
15825 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
15826 const SCEV *RHS,
15827 DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15828 const LoopGuards &DivGuards) {
15829 // WARNING: It is generally unsound to apply any wrap flags to the proposed
15830 // replacement SCEV which isn't directly implied by the structure of that
15831 // SCEV. In particular, using contextual facts to imply flags is *NOT*
15832 // legal. See the scoping rules for flags in the header to understand why.
15833
15834 // Check for a condition of the form (-C1 + X < C2). InstCombine will
15835 // create this form when combining two checks of the form (X u< C2 + C1) and
15836 // (X >=u C1).
15837 auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
15838 &ExprsToRewrite]() {
15839 const SCEVConstant *C1;
15840 const SCEVUnknown *LHSUnknown;
15841 auto *C2 = dyn_cast<SCEVConstant>(RHS);
15842 if (!match(LHS,
15843 m_scev_Add(m_SCEVConstant(C1), m_SCEVUnknown(LHSUnknown))) ||
15844 !C2)
15845 return false;
15846
15847 auto ExactRegion =
15848 ConstantRange::makeExactICmpRegion(Predicate, C2->getAPInt())
15849 .sub(C1->getAPInt());
15850
15851 // Bail out, unless we have a non-wrapping, monotonic range.
15852 if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
15853 return false;
15854 auto [I, Inserted] = RewriteMap.try_emplace(LHSUnknown);
15855 const SCEV *RewrittenLHS = Inserted ? LHSUnknown : I->second;
15856 I->second = SE.getUMaxExpr(
15857 SE.getConstant(ExactRegion.getUnsignedMin()),
15858 SE.getUMinExpr(RewrittenLHS,
15859 SE.getConstant(ExactRegion.getUnsignedMax())));
15860 ExprsToRewrite.push_back(LHSUnknown);
15861 return true;
15862 };
15863 if (MatchRangeCheckIdiom())
15864 return;
15865
15866 // Do not apply information for constants or if RHS contains an AddRec.
15868 return;
15869
15870 // If RHS is SCEVUnknown, make sure the information is applied to it.
15872 std::swap(LHS, RHS);
15874 }
15875
15876 // Puts rewrite rule \p From -> \p To into the rewrite map. Also if \p From
15877 // and \p FromRewritten are the same (i.e. there has been no rewrite
15878 // registered for \p From), then puts this value in the list of rewritten
15879 // expressions.
15880 auto AddRewrite = [&](const SCEV *From, const SCEV *FromRewritten,
15881 const SCEV *To) {
15882 if (From == FromRewritten)
15883 ExprsToRewrite.push_back(From);
15884 RewriteMap[From] = To;
15885 };
15886
15887 // Checks whether \p S has already been rewritten. In that case returns the
15888 // existing rewrite because we want to chain further rewrites onto the
15889 // already rewritten value. Otherwise returns \p S.
15890 auto GetMaybeRewritten = [&](const SCEV *S) {
15891 return RewriteMap.lookup_or(S, S);
15892 };
15893
15894 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15895 // Apply divisibility information when computing the constant multiple.
15896 const APInt &DividesBy =
15897 SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
15898
15899 // Collect rewrites for LHS and its transitive operands based on the
15900 // condition.
15901 // For min/max expressions, also apply the guard to its operands:
15902 // 'min(a, b) >= c' -> '(a >= c) and (b >= c)',
15903 // 'min(a, b) > c' -> '(a > c) and (b > c)',
15904 // 'max(a, b) <= c' -> '(a <= c) and (b <= c)',
15905 // 'max(a, b) < c' -> '(a < c) and (b < c)'.
15906
15907 // We cannot express strict predicates in SCEV, so instead we replace them
15908 // with non-strict ones against plus or minus one of RHS depending on the
15909 // predicate.
15910 const SCEV *One = SE.getOne(RHS->getType());
15911 switch (Predicate) {
15912 case CmpInst::ICMP_ULT:
15913 if (RHS->getType()->isPointerTy())
15914 return;
15915 RHS = SE.getUMaxExpr(RHS, One);
15916 [[fallthrough]];
15917 case CmpInst::ICMP_SLT: {
15918 RHS = SE.getMinusSCEV(RHS, One);
15919 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15920 break;
15921 }
15922 case CmpInst::ICMP_UGT:
15923 case CmpInst::ICMP_SGT:
15924 RHS = SE.getAddExpr(RHS, One);
15925 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15926 break;
15927 case CmpInst::ICMP_ULE:
15928 case CmpInst::ICMP_SLE:
15929 RHS = getPreviousSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15930 break;
15931 case CmpInst::ICMP_UGE:
15932 case CmpInst::ICMP_SGE:
15933 RHS = getNextSCEVDivisibleByDivisor(RHS, DividesBy, SE);
15934 break;
15935 default:
15936 break;
15937 }
15938
15939 SmallVector<SCEVUse, 16> Worklist(1, LHS);
15940 SmallPtrSet<const SCEV *, 16> Visited;
15941
15942 auto EnqueueOperands = [&Worklist](const SCEVNAryExpr *S) {
15943 append_range(Worklist, S->operands());
15944 };
15945
15946 while (!Worklist.empty()) {
15947 const SCEV *From = Worklist.pop_back_val();
15948 if (isa<SCEVConstant>(From))
15949 continue;
15950 if (!Visited.insert(From).second)
15951 continue;
15952 const SCEV *FromRewritten = GetMaybeRewritten(From);
15953 const SCEV *To = nullptr;
15954
15955 switch (Predicate) {
15956 case CmpInst::ICMP_ULT:
15957 case CmpInst::ICMP_ULE:
15958 To = SE.getUMinExpr(FromRewritten, RHS);
15959 if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
15960 EnqueueOperands(UMax);
15961 break;
15962 case CmpInst::ICMP_SLT:
15963 case CmpInst::ICMP_SLE:
15964 To = SE.getSMinExpr(FromRewritten, RHS);
15965 if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
15966 EnqueueOperands(SMax);
15967 break;
15968 case CmpInst::ICMP_UGT:
15969 case CmpInst::ICMP_UGE:
15970 To = SE.getUMaxExpr(FromRewritten, RHS);
15971 if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
15972 EnqueueOperands(UMin);
15973 break;
15974 case CmpInst::ICMP_SGT:
15975 case CmpInst::ICMP_SGE:
15976 To = SE.getSMaxExpr(FromRewritten, RHS);
15977 if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
15978 EnqueueOperands(SMin);
15979 break;
15980 case CmpInst::ICMP_EQ:
15982 To = RHS;
15983 break;
15984 case CmpInst::ICMP_NE:
15985 if (match(RHS, m_scev_Zero())) {
15986 const SCEV *OneAlignedUp =
15987 getNextSCEVDivisibleByDivisor(One, DividesBy, SE);
15988 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
15989 } else {
15990 // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15991 // but creating the subtraction eagerly is expensive. Track the
15992 // inequalities in a separate map, and materialize the rewrite lazily
15993 // when encountering a suitable subtraction while re-writing.
15994 if (LHS->getType()->isPointerTy()) {
15998 break;
15999 }
16000 const SCEVConstant *C;
16001 const SCEV *A, *B;
16004 RHS = A;
16005 LHS = B;
16006 }
16007 if (LHS > RHS)
16008 std::swap(LHS, RHS);
16009 Guards.NotEqual.insert({LHS, RHS});
16010 continue;
16011 }
16012 break;
16013 default:
16014 break;
16015 }
16016
16017 if (To)
16018 AddRewrite(From, FromRewritten, To);
16019 }
16020 };
16021
16023 // First, collect information from assumptions dominating the loop.
16024 for (auto &AssumeVH : SE.AC.assumptions()) {
16025 if (!AssumeVH)
16026 continue;
16027 auto *AssumeI = cast<CallInst>(AssumeVH);
16028 if (!SE.DT.dominates(AssumeI, Block))
16029 continue;
16030 Terms.emplace_back(AssumeI->getOperand(0), true);
16031 }
16032
16033 // Second, collect information from llvm.experimental.guards dominating the loop.
16034 auto *GuardDecl = Intrinsic::getDeclarationIfExists(
16035 SE.F.getParent(), Intrinsic::experimental_guard);
16036 if (GuardDecl)
16037 for (const auto *GU : GuardDecl->users())
16038 if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
16039 if (Guard->getFunction() == Block->getParent() &&
16040 SE.DT.dominates(Guard, Block))
16041 Terms.emplace_back(Guard->getArgOperand(0), true);
16042
16043 // Third, collect conditions from dominating branches. Starting at the loop
16044 // predecessor, climb up the predecessor chain, as long as there are
16045 // predecessors that can be found that have unique successors leading to the
16046 // original header.
16047 // TODO: share this logic with isLoopEntryGuardedByCond.
16048 unsigned NumCollectedConditions = 0;
16050 std::pair<const BasicBlock *, const BasicBlock *> Pair(Pred, Block);
16051 for (; Pair.first;
16052 Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
16053 VisitedBlocks.insert(Pair.second);
16054 const CondBrInst *LoopEntryPredicate =
16055 dyn_cast<CondBrInst>(Pair.first->getTerminator());
16056 if (!LoopEntryPredicate)
16057 continue;
16058
16059 Terms.emplace_back(LoopEntryPredicate->getCondition(),
16060 LoopEntryPredicate->getSuccessor(0) == Pair.second);
16061 NumCollectedConditions++;
16062
16063 // If we are recursively collecting guards stop after 2
16064 // conditions to limit compile-time impact for now.
16065 if (Depth > 0 && NumCollectedConditions == 2)
16066 break;
16067 }
16068 // Finally, if we stopped climbing the predecessor chain because
16069 // there wasn't a unique one to continue, try to collect conditions
16070 // for PHINodes by recursively following all of their incoming
16071 // blocks and try to merge the found conditions to build a new one
16072 // for the Phi.
16073 if (Pair.second->hasNPredecessorsOrMore(2) &&
16075 SmallDenseMap<const BasicBlock *, LoopGuards> IncomingGuards;
16076 for (auto &Phi : Pair.second->phis())
16077 collectFromPHI(SE, Guards, Phi, VisitedBlocks, IncomingGuards, Depth);
16078 }
16079
16080 // Now apply the information from the collected conditions to
16081 // Guards.RewriteMap. Conditions are processed in reverse order, so the
16082 // earliest conditions is processed first, except guards with divisibility
16083 // information, which are moved to the back. This ensures the SCEVs with the
16084 // shortest dependency chains are constructed first.
16086 GuardsToProcess;
16087 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
16088 SmallVector<Value *, 8> Worklist;
16089 SmallPtrSet<Value *, 8> Visited;
16090 Worklist.push_back(Term);
16091 while (!Worklist.empty()) {
16092 Value *Cond = Worklist.pop_back_val();
16093 if (!Visited.insert(Cond).second)
16094 continue;
16095
16096 if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
16097 auto Predicate =
16098 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
16099 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
16100 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
16101 // If LHS is a constant, apply information to the other expression.
16102 // TODO: If LHS is not a constant, check if using CompareSCEVComplexity
16103 // can improve results.
16104 if (isa<SCEVConstant>(LHS)) {
16105 std::swap(LHS, RHS);
16107 }
16108 GuardsToProcess.emplace_back(Predicate, LHS, RHS);
16109 continue;
16110 }
16111
16112 Value *L, *R;
16113 if (EnterIfTrue ? match(Cond, m_LogicalAnd(m_Value(L), m_Value(R)))
16114 : match(Cond, m_LogicalOr(m_Value(L), m_Value(R)))) {
16115 Worklist.push_back(L);
16116 Worklist.push_back(R);
16117 }
16118 }
16119 }
16120
16121 // Process divisibility guards in reverse order to populate DivGuards early.
16122 DenseMap<const SCEV *, APInt> Multiples;
16123 LoopGuards DivGuards(SE);
16124 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
16125 if (!isDivisibilityGuard(LHS, RHS, SE))
16126 continue;
16127 collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap,
16128 Multiples, SE);
16129 }
16130
16131 for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
16132 CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards);
16133
16134 // Apply divisibility information last. This ensures it is applied to the
16135 // outermost expression after other rewrites for the given value.
16136 for (const auto &[K, Divisor] : Multiples) {
16137 const SCEV *DivisorSCEV = SE.getConstant(Divisor);
16138 Guards.RewriteMap[K] =
16140 Guards.rewrite(K), Divisor, SE),
16141 DivisorSCEV),
16142 DivisorSCEV);
16143 ExprsToRewrite.push_back(K);
16144 }
16145
16146 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
16147 // the replacement expressions are contained in the ranges of the replaced
16148 // expressions.
16149 Guards.PreserveNUW = true;
16150 Guards.PreserveNSW = true;
16151 for (const SCEV *Expr : ExprsToRewrite) {
16152 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16153 Guards.PreserveNUW &=
16154 SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
16155 Guards.PreserveNSW &=
16156 SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
16157 }
16158
 16159  // Now that all rewrite information is collected, rewrite the collected
16160 // expressions with the information in the map. This applies information to
16161 // sub-expressions.
16162 if (ExprsToRewrite.size() > 1) {
16163 for (const SCEV *Expr : ExprsToRewrite) {
16164 const SCEV *RewriteTo = Guards.RewriteMap[Expr];
16165 Guards.RewriteMap.erase(Expr);
16166 Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
16167 }
16168 }
16169}
16170
16172 /// A rewriter to replace SCEV expressions in Map with the corresponding entry
16173 /// in the map. It skips AddRecExpr because we cannot guarantee that the
16174 /// replacement is loop invariant in the loop of the AddRec.
16175 class SCEVLoopGuardRewriter
16176 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
16179
16181
16182 public:
16183 SCEVLoopGuardRewriter(ScalarEvolution &SE,
16184 const ScalarEvolution::LoopGuards &Guards)
16185 : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
16186 NotEqual(Guards.NotEqual) {
16187 if (Guards.PreserveNUW)
16188 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
16189 if (Guards.PreserveNSW)
16190 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
16191 }
16192
16193 const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
16194
16195 const SCEV *visitUnknown(const SCEVUnknown *Expr) {
16196 return Map.lookup_or(Expr, Expr);
16197 }
16198
16199 const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
16200 if (const SCEV *S = Map.lookup(Expr))
16201 return S;
16202
 16203    // If we didn't find the exact ZExt expr in the map, check if there's
16204 // an entry for a smaller ZExt we can use instead.
16205 Type *Ty = Expr->getType();
16206 const SCEV *Op = Expr->getOperand(0);
16207 unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
16208 while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
16209 Bitwidth > Op->getType()->getScalarSizeInBits()) {
16210 Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
16211 auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
16212 if (const SCEV *S = Map.lookup(NarrowExt))
16213 return SE.getZeroExtendExpr(S, Ty);
16214 Bitwidth = Bitwidth / 2;
16215 }
16216
16218 Expr);
16219 }
16220
16221 const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
16222 if (const SCEV *S = Map.lookup(Expr))
16223 return S;
16225 Expr);
16226 }
16227
16228 const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
16229 if (const SCEV *S = Map.lookup(Expr))
16230 return S;
16232 }
16233
16234 const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
16235 if (const SCEV *S = Map.lookup(Expr))
16236 return S;
16238 }
16239
16240 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
16241 // Helper to check if S is a subtraction (A - B) where A != B, and if so,
16242 // return UMax(S, 1).
16243 auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
16244 SCEVUse LHS, RHS;
16245 if (MatchBinarySub(S, LHS, RHS)) {
16246 if (LHS > RHS)
16247 std::swap(LHS, RHS);
16248 if (NotEqual.contains({LHS, RHS})) {
16249 const SCEV *OneAlignedUp = getNextSCEVDivisibleByDivisor(
16250 SE.getOne(S->getType()), SE.getConstantMultiple(S), SE);
16251 return SE.getUMaxExpr(OneAlignedUp, S);
16252 }
16253 }
16254 return nullptr;
16255 };
16256
16257 // Check if Expr itself is a subtraction pattern with guard info.
16258 if (const SCEV *Rewritten = RewriteSubtraction(Expr))
16259 return Rewritten;
16260
16261 // Trip count expressions sometimes consist of adding 3 operands, i.e.
16262 // (Const + A + B). There may be guard info for A + B, and if so, apply
16263 // it.
16264 // TODO: Could more generally apply guards to Add sub-expressions.
16265 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
16266 Expr->getNumOperands() == 3) {
16267 const SCEV *Add =
16268 SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16269 if (const SCEV *Rewritten = RewriteSubtraction(Add))
16270 return SE.getAddExpr(
16271 Expr->getOperand(0), Rewritten,
16272 ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16273 if (const SCEV *S = Map.lookup(Add))
16274 return SE.getAddExpr(Expr->getOperand(0), S);
16275 }
16276 SmallVector<SCEVUse, 2> Operands;
16277 bool Changed = false;
16278 for (SCEVUse Op : Expr->operands()) {
16279 Operands.push_back(
16281 Changed |= Op != Operands.back();
16282 }
16283 // We are only replacing operands with equivalent values, so transfer the
16284 // flags from the original expression.
16285 return !Changed ? Expr
16286 : SE.getAddExpr(Operands,
16288 Expr->getNoWrapFlags(), FlagMask));
16289 }
16290
16291 const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
16292 SmallVector<SCEVUse, 2> Operands;
16293 bool Changed = false;
16294 for (SCEVUse Op : Expr->operands()) {
16295 Operands.push_back(
16297 Changed |= Op != Operands.back();
16298 }
16299 // We are only replacing operands with equivalent values, so transfer the
16300 // flags from the original expression.
16301 return !Changed ? Expr
16302 : SE.getMulExpr(Operands,
16304 Expr->getNoWrapFlags(), FlagMask));
16305 }
16306 };
16307
16308 if (RewriteMap.empty() && NotEqual.empty())
16309 return Expr;
16310
16311 SCEVLoopGuardRewriter Rewriter(SE, *this);
16312 return Rewriter.visit(Expr);
16313}
16314
16315const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
16316 return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
16317}
16318
16320 const LoopGuards &Guards) {
16321 return Guards.rewrite(Expr);
16322}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
constexpr LLT S1
Rewrite undef for PHI
This file implements a class to represent arbitrary precision integral constant values and operations...
@ PostInc
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
#define X(NUM, ENUM, NAME)
Definition ELF.h:853
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< StatepointGC > D("statepoint-example", "an example strategy for statepoint")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
This file contains the declarations for the subclasses of Constant, which represent the different fla...
SmallPtrSet< const BasicBlock *, 8 > VisitedBlocks
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
static bool isSigned(unsigned Opcode)
This file defines a hash set that can be used to remove duplication of nodes in a graph.
#define op(i)
Hexagon Common GEP
Value * getPointer(Value *Ptr)
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
This defines the Use class.
iv Induction Variable Users
Definition IVUsers.cpp:48
const AbstractManglingParser< Derived, Alloc >::OperatorInfo AbstractManglingParser< Derived, Alloc >::Ops[]
static bool isZero(Value *V, const DataLayout &DL, DominatorTree *DT, AssumptionCache *AC)
Definition Lint.cpp:539
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
#define T
#define T1
static constexpr unsigned SM(unsigned Version)
ConstantRange Range(APInt(BitWidth, Low), APInt(BitWidth, High))
uint64_t IntrinsicInst * II
#define P(N)
ppc ctr loops verify
PowerPC Reduce CR logical Operation
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition PassSupport.h:42
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition PassSupport.h:44
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition PassSupport.h:39
R600 Clause Merge
const SmallVectorImpl< MachineOperand > & Cond
static DominatorTree getDomTree(Function &F)
static bool isValid(const char C)
Returns true if C is a valid mangled character: <0-9a-zA-Z_>.
SI optimize exec mask operations pre RA
static void visit(BasicBlock &Start, std::function< bool(BasicBlock *)> op)
This file contains some templates that are useful if you are working with the STL at all.
This file provides utility classes that use RAII to save and restore values.
bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind, SCEVTypes RootKind)
static cl::opt< unsigned > MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), cl::init(8))
static cl::opt< unsigned > RangeIterThreshold("scev-range-iter-threshold", cl::Hidden, cl::desc("Threshold for switching to iteratively computing SCEV ranges"), cl::init(32))
static const Loop * isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI)
static unsigned getConstantTripCount(const SCEVConstant *ExitCount)
static int CompareValueComplexity(const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth)
Compare the two values LV and RV in terms of their "complexity" where "complexity" is a partial (and ...
static const SCEV * getNextSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static void PushLoopPHIs(const Loop *L, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push PHI nodes in the header of the given loop onto the given Worklist.
static void insertFoldCacheEntry(const ScalarEvolution::FoldID &ID, const SCEV *S, DenseMap< ScalarEvolution::FoldID, const SCEV * > &FoldCache, DenseMap< const SCEV *, SmallVector< ScalarEvolution::FoldID, 2 > > &FoldCacheUser)
static cl::opt< bool > ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction"))
static bool hasHugeExpression(ArrayRef< SCEVUse > Ops)
Returns true if Ops contains a huge SCEV (the subtree of S contains at least HugeExprThreshold nodes)...
static bool CanConstantFold(const Instruction *I)
Return true if we can constant fold an instruction of the specified type, assuming that all operands ...
static cl::opt< unsigned > AddOpsInlineThreshold("scev-addops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining addition operands into a SCEV"), cl::init(500))
static cl::opt< unsigned > MaxLoopGuardCollectionDepth("scalar-evolution-max-loop-guard-collection-depth", cl::Hidden, cl::desc("Maximum depth for recursive loop guard collection"), cl::init(1))
static cl::opt< bool > VerifyIR("scev-verify-ir", cl::Hidden, cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), cl::init(false))
static bool RangeRefPHIAllowedOperands(DominatorTree &DT, PHINode *PHI)
static const SCEV * getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static std::optional< APInt > MinOptional(std::optional< APInt > X, std::optional< APInt > Y)
Helper function to compare optional APInts: (a) if X and Y both exist, return min(X,...
static cl::opt< unsigned > MulOpsInlineThreshold("scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), cl::init(32))
static BinaryOperator * getCommonInstForPHI(PHINode *PN)
static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, ScalarEvolution &SE)
static std::optional< const SCEV * > createNodeForSelectViaUMinSeq(ScalarEvolution *SE, const SCEV *CondExpr, const SCEV *TrueExpr, const SCEV *FalseExpr)
static Constant * BuildConstantFromSCEV(const SCEV *V)
This builds up a Constant using the ConstantExpr interface.
static ConstantInt * EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, ScalarEvolution &SE)
static const SCEV * BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy)
Compute BC(It, K). The result has width W. Assume, K > 0.
static cl::opt< unsigned > MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden, cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"), cl::init(8))
static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, const SCEV *Candidate)
Is MaybeMinMaxExpr an (U|S)(Min|Max) of Candidate and some other values?
static PHINode * getConstantEvolvingPHI(Value *V, const Loop *L)
getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node in the loop that V is deri...
static const SCEV * SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, SmallVectorImpl< const SCEVPredicate * > *Predicates, ScalarEvolution &SE, const Loop *L)
Finds the minimum unsigned root of the following equation:
static cl::opt< unsigned > MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::desc("Maximum number of iterations SCEV will " "symbolically execute a constant " "derived loop"), cl::init(100))
static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow)
static void PrintSCEVWithTypeHint(raw_ostream &OS, const SCEV *S)
When printing a top-level SCEV for trip counts, it's helpful to include a type for constants which ar...
static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L)
static SCEV::NoWrapFlags StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, ArrayRef< SCEVUse > Ops, SCEV::NoWrapFlags Flags)
static bool containsConstantInAddMulChain(const SCEV *StartExpr)
Determine if any of the operands in this SCEV are a constant or if any of the add or multiply express...
static const SCEV * getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, ScalarEvolution *SE, unsigned Depth)
static bool CollectAddOperandsWithScales(SmallDenseMap< SCEVUse, APInt, 16 > &M, SmallVectorImpl< SCEVUse > &NewOps, APInt &AccumulatedConstant, ArrayRef< SCEVUse > Ops, const APInt &Scale, ScalarEvolution &SE)
Process the given Ops list, which is a list of operands to be added under the given scale,...
static const SCEV * constantFoldAndGroupOps(ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, SmallVectorImpl< SCEVUse > &Ops, FoldT Fold, IsIdentityT IsIdentity, IsAbsorberT IsAbsorber)
Performs a number of common optimizations on the passed Ops.
static cl::opt< unsigned > MaxPhiSCCAnalysisSize("scalar-evolution-max-scc-analysis-depth", cl::Hidden, cl::desc("Maximum amount of nodes to process while searching SCEVUnknown " "Phi strongly connected components"), cl::init(8))
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static void GroupByComplexity(SmallVectorImpl< SCEVUse > &Ops, LoopInfo *LI, DominatorTree &DT)
Given a list of SCEV objects, order them by their complexity, and group objects of the same complexit...
static bool collectDivisibilityInformation(ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap< const SCEV *, const SCEV * > &DivInfo, DenseMap< const SCEV *, APInt > &Multiples, ScalarEvolution &SE)
static cl::opt< unsigned > MaxSCEVOperationsImplicationDepth("scalar-evolution-max-scev-operations-implication-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV operations implication analysis"), cl::init(2))
static void PushDefUseChildren(Instruction *I, SmallVectorImpl< Instruction * > &Worklist, SmallPtrSetImpl< Instruction * > &Visited)
Push users of the given Instruction onto the given Worklist.
static std::optional< APInt > SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, const ConstantRange &Range, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n iterations.
static cl::opt< bool > UseContextForNoWrapFlagInference("scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden, cl::desc("Infer nuw/nsw flags using context where suitable"), cl::init(true))
static cl::opt< bool > EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden, cl::desc("Handle <= and >= in finite loops"), cl::init(true))
static bool getOperandsForSelectLikePHI(DominatorTree &DT, PHINode *PN, Value *&Cond, Value *&LHS, Value *&RHS)
static std::optional< std::tuple< APInt, APInt, APInt, APInt, unsigned > > GetQuadraticEquation(const SCEVAddRecExpr *AddRec)
For a given quadratic addrec, generate coefficients of the corresponding quadratic equation,...
static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
static std::optional< BinaryOp > MatchBinaryOp(Value *V, const DataLayout &DL, AssumptionCache &AC, const DominatorTree &DT, const Instruction *CxtI)
Try to map V into a BinaryOp, and return std::nullopt on failure.
static std::optional< APInt > SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE)
Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n iterations.
static std::optional< APInt > TruncIfPossible(std::optional< APInt > X, unsigned BitWidth)
Helper function to truncate an optional APInt to a given BitWidth.
static cl::opt< unsigned > MaxSCEVCompareDepth("scalar-evolution-max-scev-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive SCEV complexity comparisons"), cl::init(32))
static APInt extractConstantWithoutWrapping(ScalarEvolution &SE, const SCEVConstant *ConstantTerm, const SCEVAddExpr *WholeAddExpr)
static cl::opt< unsigned > MaxConstantEvolvingDepth("scalar-evolution-max-constant-evolving-depth", cl::Hidden, cl::desc("Maximum depth of recursive constant evolving"), cl::init(32))
static ConstantRange getRangeForAffineARHelper(APInt Step, const ConstantRange &StartRange, const APInt &MaxBECount, bool Signed)
static bool MatchBinarySub(const SCEV *S, SCEVUse &LHS, SCEVUse &RHS)
static std::optional< ConstantRange > GetRangeFromMetadata(Value *V)
Helper method to assign a range to V from metadata present in the IR.
static cl::opt< unsigned > HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden, cl::desc("Size of the expression which is considered huge"), cl::init(4096))
static Type * isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE)
Helper function to createAddRecFromPHIWithCasts.
static Constant * EvaluateExpression(Value *V, const Loop *L, DenseMap< Instruction *, Constant * > &Vals, const DataLayout &DL, const TargetLibraryInfo *TLI)
EvaluateExpression - Given an expression that passes the getConstantEvolvingPHI predicate,...
static const SCEV * getPreviousSCEVDivisibleByDivisor(const SCEV *Expr, const APInt &DivisorVal, ScalarEvolution &SE)
static const SCEV * MatchNotExpr(const SCEV *Expr)
If Expr computes ~A, return A else return nullptr.
static cl::opt< unsigned > MaxValueCompareDepth("scalar-evolution-max-value-compare-depth", cl::Hidden, cl::desc("Maximum depth of recursive value complexity comparisons"), cl::init(2))
static const SCEV * applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, APInt Divisor, ScalarEvolution &SE)
static cl::opt< bool, true > VerifySCEVOpt("verify-scev", cl::Hidden, cl::location(VerifySCEV), cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"))
static const SCEV * getSignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static cl::opt< unsigned > MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden, cl::desc("Maximum depth of recursive arithmetics"), cl::init(32))
static bool HasSameValue(const SCEV *A, const SCEV *B)
SCEV structural equivalence is usually sufficient for testing whether two expressions are equal,...
static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow)
Compute the result of "n choose k", the binomial coefficient.
static std::optional< int > CompareSCEVComplexity(const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth=0)
static bool canConstantEvolve(Instruction *I, const Loop *L)
Determine whether this instruction can constant evolve within this loop assuming its operands can all...
static PHINode * getConstantEvolvingPHIOperands(Instruction *UseInst, const Loop *L, DenseMap< Instruction *, PHINode * > &PHIMap, unsigned Depth)
getConstantEvolvingPHIOperands - Implement getConstantEvolvingPHI by recursing through each instructi...
static bool scevUnconditionallyPropagatesPoisonFromOperands(SCEVTypes Kind)
static cl::opt< bool > VerifySCEVStrict("verify-scev-strict", cl::Hidden, cl::desc("Enable stricter verification with -verify-scev is passed"))
static Constant * getOtherIncomingValue(PHINode *PN, BasicBlock *BB)
static cl::opt< bool > UseExpensiveRangeSharpening("scalar-evolution-use-expensive-range-sharpening", cl::Hidden, cl::init(false), cl::desc("Use more powerful methods of sharpening expression ranges. May " "be costly in terms of compile time"))
static const SCEV * getUnsignedOverflowLimitForStep(const SCEV *Step, ICmpInst::Predicate *Pred, ScalarEvolution *SE)
static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Is LHS Pred RHS true on the virtue of LHS or RHS being a Min or Max expression?
static bool BrPHIToSelect(DominatorTree &DT, CondBrInst *BI, PHINode *Merge, Value *&C, Value *&LHS, Value *&RHS)
This file defines the make_scope_exit function, which executes user-defined cleanup logic at scope ex...
static bool InBlock(const Value *V, const BasicBlock *BB)
Provides some synthesis utilities to produce sequences of values.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:119
static TableGen::Emitter::Opt Y("gen-skeleton-entry", EmitSkeleton, "Generate example skeleton entry")
static SymbolRef::Type getType(const Symbol *Sym)
Definition TapiFile.cpp:39
LocallyHashedType DenseMapInfo< LocallyHashedType >::Empty
static std::optional< bool > isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, const Value *ARHS, const Value *BLHS, const Value *BRHS)
Return true if "icmp Pred BLHS BRHS" is true whenever "icmp PredALHS ARHS" is true.
Virtual Register Rewriter
Value * RHS
Value * LHS
BinaryOperator * Mul
static const uint32_t IV[8]
Definition blake3_impl.h:83
SCEVCastSinkingRewriter(ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
static const SCEV * rewrite(const SCEV *Scev, ScalarEvolution &SE, Type *TargetTy, ConversionFn CreatePtrCast)
const SCEV * visitUnknown(const SCEVUnknown *Expr)
const SCEV * visitMulExpr(const SCEVMulExpr *Expr)
const SCEV * visitAddExpr(const SCEVAddExpr *Expr)
const SCEV * visit(const SCEV *S)
Class for arbitrary precision integers.
Definition APInt.h:78
LLVM_ABI APInt umul_ov(const APInt &RHS, bool &Overflow) const
Definition APInt.cpp:2023
LLVM_ABI APInt zext(unsigned width) const
Zero extend to a new width.
Definition APInt.cpp:1055
bool isMinSignedValue() const
Determine if this is the smallest signed value.
Definition APInt.h:424
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
void setHighBits(unsigned hiBits)
Set the top hiBits bits.
Definition APInt.h:1414
LLVM_ABI APInt getHiBits(unsigned numBits) const
Compute an APInt containing numBits highbits from this APInt.
Definition APInt.cpp:640
unsigned getActiveBits() const
Compute the number of active bits in the value.
Definition APInt.h:1535
LLVM_ABI APInt trunc(unsigned width) const
Truncate to new width.
Definition APInt.cpp:968
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
Definition APInt.h:207
APInt abs() const
Get the absolute value.
Definition APInt.h:1818
bool sgt(const APInt &RHS) const
Signed greater than comparison.
Definition APInt.h:1208
bool isAllOnes() const
Determine if all bits are set. This is true for zero-width values.
Definition APInt.h:372
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
Definition APInt.h:1189
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
bool isSignMask() const
Check if the APInt's value is returned by getSignMask.
Definition APInt.h:467
LLVM_ABI APInt urem(const APInt &RHS) const
Unsigned remainder operation.
Definition APInt.cpp:1709
unsigned getBitWidth() const
Return the number of bits in the APInt.
Definition APInt.h:1511
bool ult(const APInt &RHS) const
Unsigned less than comparison.
Definition APInt.h:1118
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
Definition APInt.h:210
static APInt getMinValue(unsigned numBits)
Gets minimum unsigned value of APInt for a specific bit width.
Definition APInt.h:217
bool isNegative() const
Determine sign of this APInt.
Definition APInt.h:330
bool sle(const APInt &RHS) const
Signed less or equal comparison.
Definition APInt.h:1173
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
Definition APInt.h:220
bool isNonPositive() const
Determine if this APInt Value is non-positive (<= 0).
Definition APInt.h:362
unsigned countTrailingZeros() const
Definition APInt.h:1670
bool isStrictlyPositive() const
Determine if this APInt Value is positive.
Definition APInt.h:357
unsigned logBase2() const
Definition APInt.h:1784
APInt ashr(unsigned ShiftAmt) const
Arithmetic right-shift function.
Definition APInt.h:834
LLVM_ABI APInt multiplicativeInverse() const
Definition APInt.cpp:1317
bool ule(const APInt &RHS) const
Unsigned less or equal comparison.
Definition APInt.h:1157
LLVM_ABI APInt sext(unsigned width) const
Sign extend to a new width.
Definition APInt.cpp:1028
APInt shl(unsigned shiftAmt) const
Left-shift function.
Definition APInt.h:880
bool isPowerOf2() const
Check if this APInt's value is a power of two greater than zero.
Definition APInt.h:441
static APInt getLowBitsSet(unsigned numBits, unsigned loBitsSet)
Constructs an APInt value that has the bottom loBitsSet bits set.
Definition APInt.h:307
bool isSignBitSet() const
Determine if sign bit of this APInt is set.
Definition APInt.h:342
bool slt(const APInt &RHS) const
Signed less than comparison.
Definition APInt.h:1137
static APInt getZero(unsigned numBits)
Get the '0' value for the specified bit-width.
Definition APInt.h:201
bool isIntN(unsigned N) const
Check if this APInt has an N-bits unsigned integer value.
Definition APInt.h:433
static APInt getOneBitSet(unsigned numBits, unsigned BitNo)
Return an APInt with exactly one bit set in the result.
Definition APInt.h:240
bool uge(const APInt &RHS) const
Unsigned greater or equal comparison.
Definition APInt.h:1228
This templated class represents "all analyses that operate over <a particular IR unit>" (e....
Definition Analysis.h:50
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition ArrayRef.h:40
iterator end() const
Definition ArrayRef.h:130
size_t size() const
Get the array size.
Definition ArrayRef.h:141
iterator begin() const
Definition ArrayRef.h:129
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
MutableArrayRef< WeakVH > assumptions()
Access the list of assumption handles currently tracked for this function.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:461
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
LLVM_ABI const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Instruction & front() const
Definition BasicBlock.h:484
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
LLVM_ABI unsigned getNoWrapKind() const
Returns one of OBO::NoSignedWrap or OBO::NoUnsignedWrap.
LLVM_ABI Instruction::BinaryOps getBinaryOp() const
Returns the binary operation underlying the intrinsic.
BinaryOps getOpcode() const
Definition InstrTypes.h:374
This class represents a function call, abstracting a target machine's calling convention.
virtual void deleted()
Callback for Value destruction.
void setValPtr(Value *P)
bool isFalseWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:948
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Definition InstrTypes.h:676
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_SLE
signed less or equal
Definition InstrTypes.h:706
@ ICMP_UGE
unsigned greater or equal
Definition InstrTypes.h:700
@ ICMP_UGT
unsigned greater than
Definition InstrTypes.h:699
@ ICMP_SGT
signed greater than
Definition InstrTypes.h:703
@ ICMP_ULT
unsigned less than
Definition InstrTypes.h:701
@ ICMP_NE
not equal
Definition InstrTypes.h:698
@ ICMP_SGE
signed greater or equal
Definition InstrTypes.h:704
@ ICMP_ULE
unsigned less or equal
Definition InstrTypes.h:702
bool isSigned() const
Definition InstrTypes.h:930
Predicate getSwappedPredicate() const
For example, EQ->EQ, SLE->SGE, ULT->UGT, OEQ->OEQ, ULE->UGE, OLT->OGT, etc.
Definition InstrTypes.h:827
bool isTrueWhenEqual() const
This is just a convenience.
Definition InstrTypes.h:942
Predicate getInversePredicate() const
For example, EQ -> NE, UGT -> ULE, SLT -> SGE, OEQ -> UNE, UGT -> OLE, OLT -> UGE,...
Definition InstrTypes.h:789
bool isUnsigned() const
Definition InstrTypes.h:936
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
Definition InstrTypes.h:926
An abstraction over a floating-point predicate, and a pack of an integer predicate with samesign info...
static LLVM_ABI std::optional< CmpPredicate > getMatching(CmpPredicate A, CmpPredicate B)
Compares two CmpPredicates taking samesign into account and returns the canonicalized CmpPredicate if...
LLVM_ABI CmpInst::Predicate getPreferredSignedPredicate() const
Attempts to return a signed CmpInst::Predicate from the CmpPredicate.
CmpInst::Predicate dropSameSign() const
Drops samesign information.
Conditional Branch instruction.
Value * getCondition() const
BasicBlock * getSuccessor(unsigned i) const
static LLVM_ABI Constant * getNot(Constant *C)
static Constant * getPtrAdd(Constant *Ptr, Constant *Offset, GEPNoWrapFlags NW=GEPNoWrapFlags::none(), std::optional< ConstantRange > InRange=std::nullopt, Type *OnlyIfReduced=nullptr)
Create a getelementptr i8, ptr, offset constant expression.
Definition Constants.h:1478
static LLVM_ABI Constant * getPtrToInt(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getPtrToAddr(Constant *C, Type *Ty, bool OnlyIfReduced=false)
static LLVM_ABI Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static LLVM_ABI Constant * getNeg(Constant *C, bool HasNSW=false)
static LLVM_ABI Constant * getTrunc(Constant *C, Type *Ty, bool OnlyIfReduced=false)
This is the shared class of boolean and integer constants.
Definition Constants.h:87
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
static LLVM_ABI ConstantInt * getFalse(LLVMContext &Context)
uint64_t getZExtValue() const
Return the constant as a 64-bit unsigned integer value after it has been zero extended as appropriate...
Definition Constants.h:168
const APInt & getValue() const
Return the constant as an APInt value reference.
Definition Constants.h:159
static LLVM_ABI ConstantInt * getBool(LLVMContext &Context, bool V)
This class represents a range of values.
LLVM_ABI ConstantRange add(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an addition of a value in this ran...
LLVM_ABI ConstantRange zextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
PreferredRangeType
If represented precisely, the result of some range operations may consist of multiple disjoint ranges...
LLVM_ABI bool getEquivalentICmp(CmpInst::Predicate &Pred, APInt &RHS) const
Set up Pred and RHS such that ConstantRange::makeExactICmpRegion(Pred, RHS) == *this.
const APInt & getLower() const
Return the lower value for this range.
LLVM_ABI ConstantRange urem(const ConstantRange &Other) const
Return a new range representing the possible values resulting from an unsigned remainder operation of...
LLVM_ABI bool isFullSet() const
Return true if this set contains all of the elements possible for this data-type.
LLVM_ABI bool icmp(CmpInst::Predicate Pred, const ConstantRange &Other) const
Does the predicate Pred hold between ranges this and Other?
LLVM_ABI bool isEmptySet() const
Return true if this set contains no members.
LLVM_ABI ConstantRange zeroExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
LLVM_ABI bool isSignWrappedSet() const
Return true if this set wraps around the signed domain.
LLVM_ABI APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
LLVM_ABI bool isWrappedSet() const
Return true if this set wraps around the unsigned domain.
LLVM_ABI void print(raw_ostream &OS) const
Print out the bounds to a stream.
LLVM_ABI ConstantRange truncate(uint32_t BitWidth, unsigned NoWrapKind=0) const
Return a new range in the specified integer type, which must be strictly smaller than the current typ...
LLVM_ABI ConstantRange signExtend(uint32_t BitWidth) const
Return a new range in the specified integer type, which must be strictly larger than the current type...
const APInt & getUpper() const
Return the upper value for this range.
LLVM_ABI ConstantRange unionWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the union of this range with another range.
static LLVM_ABI ConstantRange makeExactICmpRegion(CmpInst::Predicate Pred, const APInt &Other)
Produce the exact range such that all values in the returned range satisfy the given predicate with a...
LLVM_ABI bool contains(const APInt &Val) const
Return true if the specified value is in the set.
LLVM_ABI APInt getUnsignedMax() const
Return the largest unsigned value contained in the ConstantRange.
LLVM_ABI ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
LLVM_ABI APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
static ConstantRange getNonEmpty(APInt Lower, APInt Upper)
Create non-empty constant range with the given bounds.
static LLVM_ABI ConstantRange makeGuaranteedNoWrapRegion(Instruction::BinaryOps BinOp, const ConstantRange &Other, unsigned NoWrapKind)
Produce the largest range containing all X such that "X BinOp Y" is guaranteed not to wrap (overflow)...
LLVM_ABI unsigned getMinSignedBits() const
Compute the maximal number of bits needed to represent every value in this signed range.
uint32_t getBitWidth() const
Get the bit width of this ConstantRange.
LLVM_ABI ConstantRange sub(const ConstantRange &Other) const
Return a new range representing the possible values resulting from a subtraction of a value in this r...
LLVM_ABI ConstantRange sextOrTrunc(uint32_t BitWidth) const
Make this range have the bit width given by BitWidth.
static LLVM_ABI ConstantRange makeExactNoWrapRegion(Instruction::BinaryOps BinOp, const APInt &Other, unsigned NoWrapKind)
Produce the range that contains X if and only if "X BinOp Other" does not wrap.
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string in and methods for querying it.
Definition DataLayout.h:64
LLVM_ABI const StructLayout * getStructLayout(StructType *Ty) const
Returns a StructLayout object, indicating the alignment of the struct, its size, and the offsets of i...
LLVM_ABI IntegerType * getIntPtrType(LLVMContext &C, unsigned AddressSpace=0) const
Returns an integer type with size at least as big as that of a pointer in the given address space.
LLVM_ABI unsigned getIndexTypeSizeInBits(Type *Ty) const
The size in bits of the index used in GEP calculation for this type.
LLVM_ABI IntegerType * getIndexType(LLVMContext &C, unsigned AddressSpace) const
Returns the type of a GEP index in AddressSpace.
TypeSize getTypeSizeInBits(Type *Ty) const
Size examples:
Definition DataLayout.h:791
ValueT lookup(const_arg_type_t< KeyT > Val) const
Return the entry for the specified key, or a default constructed value if no such entry exists.
Definition DenseMap.h:205
iterator find(const_arg_type_t< KeyT > Val)
Definition DenseMap.h:178
std::pair< iterator, bool > try_emplace(KeyT &&Key, Ts &&...Args)
Definition DenseMap.h:254
DenseMapIterator< KeyT, ValueT, KeyInfoT, BucketT > iterator
Definition DenseMap.h:74
iterator find_as(const LookupKeyT &Val)
Alternate version of find() which allows a different, and possibly less expensive,...
Definition DenseMap.h:191
size_type count(const_arg_type_t< KeyT > Val) const
Return 1 if the specified key is in the map, 0 otherwise.
Definition DenseMap.h:174
iterator end()
Definition DenseMap.h:81
bool contains(const_arg_type_t< KeyT > Val) const
Return true if the specified key is in the map, false otherwise.
Definition DenseMap.h:169
void swap(DerivedT &RHS)
Definition DenseMap.h:368
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:239
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Legacy analysis pass which computes a DominatorTree.
Definition Dominators.h:314
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
LLVM_ABI bool isReachableFromEntry(const Use &U) const
Provide an overload for a Use.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
This class describes a reference to an interned FoldingSetNodeID, which can be a useful to store node...
Definition FoldingSet.h:171
This class is used to gather all the unique data bits of a node.
Definition FoldingSet.h:208
FunctionPass(char &pid)
Definition Pass.h:316
Represents flags for the getelementptr instruction/expression.
bool hasNoUnsignedSignedWrap() const
bool hasNoUnsignedWrap() const
static GEPNoWrapFlags none()
static LLVM_ABI Type * getTypeAtIndex(Type *Ty, Value *Idx)
Return the type of the element at the given index of an indexable type.
Module * getParent()
Get the module that this global value is contained inside of...
static bool isPrivateLinkage(LinkageTypes Linkage)
static bool isInternalLinkage(LinkageTypes Linkage)
This instruction compares its operands according to the predicate given to the constructor.
CmpPredicate getCmpPredicate() const
static bool isGE(Predicate P)
Return true if the predicate is SGE or UGE.
CmpPredicate getSwappedCmpPredicate() const
static LLVM_ABI bool compare(const APInt &LHS, const APInt &RHS, ICmpInst::Predicate Pred)
Return result of LHS Pred RHS comparison.
static bool isLT(Predicate P)
Return true if the predicate is SLT or ULT.
CmpPredicate getInverseCmpPredicate() const
Predicate getNonStrictCmpPredicate() const
For example, SGT -> SGE, SLT -> SLE, ULT -> ULE, UGT -> UGE.
static bool isGT(Predicate P)
Return true if the predicate is SGT or UGT.
Predicate getFlippedSignednessPredicate() const
For example, SLT->ULT, ULT->SLT, SLE->ULE, ULE->SLE, EQ->EQ.
static CmpPredicate getInverseCmpPredicate(CmpPredicate Pred)
static bool isEquality(Predicate P)
Return true if this predicate is either EQ or NE.
bool isRelational() const
Return true if the predicate is relational (not EQ or NE).
static bool isLE(Predicate P)
Return true if the predicate is SLE or ULE.
LLVM_ABI bool hasNoUnsignedWrap() const LLVM_READONLY
Determine whether the no unsigned wrap flag is set.
LLVM_ABI bool hasNoSignedWrap() const LLVM_READONLY
Determine whether the no signed wrap flag is set.
LLVM_ABI bool isIdenticalToWhenDefined(const Instruction *I, bool IntersectAttrs=false) const LLVM_READONLY
This is like isIdenticalTo, except that it ignores the SubclassOptionalData flags,...
Class to represent integer types.
static LLVM_ABI IntegerType * get(LLVMContext &C, unsigned NumBits)
This static method is the primary way of constructing an IntegerType.
Definition Type.cpp:354
A helper class to return the specified delimiter string after the first invocation of operator String...
An instruction for reading from memory.
Analysis pass that exposes the LoopInfo for a function.
Definition LoopInfo.h:587
bool contains(const LoopT *L) const
Return true if the specified loop is contained within in this loop.
BlockT * getHeader() const
unsigned getLoopDepth() const
Return the nesting level of this loop.
BlockT * getLoopPredecessor() const
If the given loop's header has exactly one unique predecessor outside the loop, return it.
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
unsigned getLoopDepth(const BlockT *BB) const
Return the loop nesting level of the specified block.
LoopT * getLoopFor(const BlockT *BB) const
Return the inner most loop that BB lives in.
The legacy pass manager's analysis pass to compute loop information.
Definition LoopInfo.h:612
Represents a single loop in the control flow graph.
Definition LoopInfo.h:40
bool isLoopInvariant(const Value *V) const
Return true if the specified value is loop invariant.
Definition LoopInfo.cpp:67
Metadata node.
Definition Metadata.h:1080
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
unsigned getOpcode() const
Return the opcode for this Instruction or ConstantExpr.
Definition Operator.h:43
Utility class for integer operators which may exhibit overflow - Add, Sub, Mul, and Shl.
Definition Operator.h:78
bool hasNoSignedWrap() const
Test whether this operation is known to never undergo signed overflow, aka the nsw property.
Definition Operator.h:111
bool hasNoUnsignedWrap() const
Test whether this operation is known to never undergo unsigned overflow, aka the nuw property.
Definition Operator.h:105
iterator_range< const_block_iterator > blocks() const
op_range incoming_values()
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number x.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
AnalysisType & getAnalysis() const
getAnalysis<AnalysisType>() - This function is used by subclasses to get to the analysis information ...
PointerIntPair - This class implements a pair of a pointer and small integer.
static PointerType * getUnqual(Type *ElementType)
This constructs a pointer to an object of the specified type in the default address space (address sp...
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
LLVM_ABI void addPredicate(const SCEVPredicate &Pred)
Adds a new predicate.
LLVM_ABI const SCEVPredicate & getPredicate() const
LLVM_ABI const SCEV * getPredicatedSCEV(const SCEV *Expr)
Returns the rewritten SCEV for Expr in the context of the current SCEV predicate.
LLVM_ABI bool hasNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Returns true if we've proved that V doesn't wrap by means of a SCEV predicate.
LLVM_ABI void setNoOverflow(Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags)
Proves that V doesn't overflow by adding SCEV predicate.
LLVM_ABI void print(raw_ostream &OS, unsigned Depth) const
Print the SCEV mappings done by the Predicated Scalar Evolution.
LLVM_ABI bool areAddRecsEqualWithPreds(const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const
Check if AR1 and AR2 are equal, while taking into account Equal predicates in Preds.
LLVM_ABI PredicatedScalarEvolution(ScalarEvolution &SE, Loop &L)
LLVM_ABI const SCEVAddRecExpr * getAsAddRec(Value *V)
Attempts to produce an AddRecExpr for V by adding additional SCEV predicates.
LLVM_ABI unsigned getSmallConstantMaxTripCount()
Returns the upper bound of the loop trip count as a normal unsigned value, or 0 if the trip count is ...
LLVM_ABI const SCEV * getBackedgeTakenCount()
Get the (predicated) backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSymbolicMaxBackedgeTakenCount()
Get the (predicated) symbolic max backedge count for the analyzed loop.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalysisChecker getChecker() const
Build a checker for this PreservedAnalyses and the specified analysis type.
Definition Analysis.h:275
constexpr bool isValid() const
Definition Register.h:112
This node represents an addition of some number of SCEVs.
This node represents a polynomial recurrence on the trip count of the specified loop.
LLVM_ABI const SCEV * evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const
Return the value of this chain of recurrences at the specified iteration number.
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a recurrence without clearing any previously set flags.
bool isAffine() const
Return true if this represents an expression A + B*x where A and B are loop invariant values.
bool isQuadratic() const
Return true if this represents an expression A + B*x + C*x^2 where A, B and C are loop invariant valu...
LLVM_ABI const SCEV * getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const
Return the number of iterations of this loop that produce values in the specified constant range.
LLVM_ABI const SCEVAddRecExpr * getPostIncExpr(ScalarEvolution &SE) const
Return an expression representing the value of this expression one iteration of the loop ahead.
SCEVUse getStepRecurrence(ScalarEvolution &SE) const
Constructs and returns the recurrence indicating how much this expression steps by.
This is the base class for unary cast operator classes.
LLVM_ABI SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
void setNoWrapFlags(NoWrapFlags Flags)
Set flags for a non-recurrence without clearing previously set flags.
This class represents an assumption that the expression LHS Pred RHS evaluates to true,...
SCEVComparePredicate(const FoldingSetNodeIDRef ID, const ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Implementation of the SCEVPredicate interface.
This class represents a constant integer value.
ConstantInt * getValue() const
const APInt & getAPInt() const
This is the base class for unary integral cast operator classes.
LLVM_ABI SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, SCEVUse op, Type *ty)
This node is the base class min/max selections.
static enum SCEVTypes negate(enum SCEVTypes T)
This node represents multiplication of some number of SCEVs.
This node is a base class providing common functionality for n'ary operators.
ArrayRef< SCEVUse > operands() const
NoWrapFlags getNoWrapFlags(NoWrapFlags Mask=NoWrapMask) const
SCEVUse getOperand(unsigned i) const
This class represents an assumption made using SCEV expressions which can be checked at run-time.
SCEVPredicate(const SCEVPredicate &)=default
virtual bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const =0
Returns true if this predicate implies N.
SCEVPredicateKind Kind
This class represents a cast from a pointer to a pointer-sized integer value, without capturing the p...
This class represents a cast from a pointer to a pointer-sized integer value.
This visitor recursively visits a SCEV expression and re-writes it.
const SCEV * visitSignExtendExpr(const SCEVSignExtendExpr *Expr)
const SCEV * visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr)
const SCEV * visitSMinExpr(const SCEVSMinExpr *Expr)
const SCEV * visitUMinExpr(const SCEVUMinExpr *Expr)
This class represents a signed minimum selection.
This node is the base class for sequential/in-order min/max selections.
static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty)
This class represents a sign extension of a small integer value to a larger integer value.
Visit all nodes in the expression tree using worklist traversal.
This class represents a truncation of an integer value to a smaller integer value.
This class represents a binary unsigned division operation.
This class represents an unsigned minimum selection.
This class represents a composition of other SCEV predicates, and is the class that most clients will...
void print(raw_ostream &OS, unsigned Depth) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
SCEVUnionPredicate(ArrayRef< const SCEVPredicate * > Preds, ScalarEvolution &SE)
Union predicates don't get cached so create a dummy set ID for it.
bool isAlwaysTrue() const override
Implementation of the SCEVPredicate interface.
This means that we are dealing with an entirely unknown SCEV value, and only represent it as its LLVM...
This class represents the value of vscale, as used when defining the length of a scalable vector or r...
This class represents an assumption made on an AddRec expression.
IncrementWrapFlags
Similar to SCEV::NoWrapFlags, but with slightly different semantics for FlagNUSW.
SCEVWrapPredicate(const FoldingSetNodeIDRef ID, const SCEVAddRecExpr *AR, IncrementWrapFlags Flags)
bool implies(const SCEVPredicate *N, ScalarEvolution &SE) const override
Returns true if this predicate implies N.
static SCEVWrapPredicate::IncrementWrapFlags setFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OnFlags)
void print(raw_ostream &OS, unsigned Depth=0) const override
Prints a textual representation of this predicate with an indentation of Depth.
bool isAlwaysTrue() const override
Returns true if the predicate is always true.
const SCEVAddRecExpr * getExpr() const
Implementation of the SCEVPredicate interface.
static SCEVWrapPredicate::IncrementWrapFlags clearFlags(SCEVWrapPredicate::IncrementWrapFlags Flags, SCEVWrapPredicate::IncrementWrapFlags OffFlags)
Convenient IncrementWrapFlags manipulation methods.
static SCEVWrapPredicate::IncrementWrapFlags getImpliedFlags(const SCEVAddRecExpr *AR, ScalarEvolution &SE)
Returns the set of SCEVWrapPredicate no wrap flags implied by a SCEVAddRecExpr.
IncrementWrapFlags getFlags() const
Returns the set assumed no overflow flags.
This class represents a zero extension of a small integer value to a larger integer value.
This class represents an analyzed expression in the program.
unsigned short getExpressionSize() const
SCEVNoWrapFlags NoWrapFlags
LLVM_ABI bool isOne() const
Return true if the expression is a constant one.
static constexpr auto FlagNUW
LLVM_ABI void computeAndSetCanonical(ScalarEvolution &SE)
Compute and set the canonical SCEV, by constructing a SCEV with the same operands,...
LLVM_ABI bool isZero() const
Return true if the expression is a constant zero.
const SCEV * CanonicalSCEV
Pointer to the canonical version of the SCEV, i.e.
static constexpr auto FlagAnyWrap
LLVM_ABI void dump() const
This method is used for debugging.
LLVM_ABI bool isAllOnesValue() const
Return true if the expression is a constant all-ones value.
LLVM_ABI bool isNonConstantNegative() const
Return true if the specified scev is negated, but not a constant.
static constexpr auto FlagNSW
LLVM_ABI ArrayRef< SCEVUse > operands() const
Return operands of this SCEV expression.
LLVM_ABI void print(raw_ostream &OS) const
Print out the internal representation of this scalar to the specified stream.
SCEV(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, unsigned short ExpressionSize)
SCEVTypes getSCEVType() const
static constexpr auto FlagNW
LLVM_ABI Type * getType() const
Return the LLVM type of this SCEV expression.
Analysis pass that exposes the ScalarEvolution for a function.
LLVM_ABI ScalarEvolution run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
void print(raw_ostream &OS, const Module *=nullptr) const override
print - Print out the internal state of the pass.
bool runOnFunction(Function &F) override
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
void releaseMemory() override
releaseMemory() - This member can be implemented by a pass if it wants to be able to release its memo...
void verifyAnalysis() const override
verifyAnalysis() - This member can be implemented by a analysis pass to check state of analysis infor...
static LLVM_ABI LoopGuards collect(const Loop *L, ScalarEvolution &SE)
Collect rewrite map for loop guards for loop L, together with flags indicating if NUW and NSW can be ...
LLVM_ABI const SCEV * rewrite(const SCEV *Expr) const
Try to apply the collected loop guards to Expr.
The main scalar evolution driver.
LLVM_ABI const SCEV * getUDivExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
const SCEV * getConstantMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEVConstant that is greater than or equal to (i.e.
static bool hasFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags TestFlags)
const DataLayout & getDataLayout() const
Return the DataLayout associated with the module this SCEV instance is operating on.
LLVM_ABI bool isKnownNonNegative(const SCEV *S)
Test if the given expression is known to be non-negative.
LLVM_ABI bool isKnownOnEveryIteration(CmpPredicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS)
Test if the condition described by Pred, LHS, RHS is known to be true on every iteration of the loop ...
LLVM_ABI const SCEV * getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
Return the SCEV object corresponding to -V.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterationsImpl(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
LLVM_ABI const SCEV * getUDivCeilSCEV(const SCEV *N, const SCEV *D)
Compute ceil(N / D).
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantExitCondDuringFirstIterations(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI, const SCEV *MaxIter)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L at given Context duri...
LLVM_ABI Type * getWiderType(Type *Ty1, Type *Ty2) const
LLVM_ABI const SCEV * getAbsExpr(const SCEV *Op, bool IsNSW)
LLVM_ABI bool isKnownNonPositive(const SCEV *S)
Test if the given expression is known to be non-positive.
LLVM_ABI bool isKnownNegative(const SCEV *S)
Test if the given expression is known to be negative.
LLVM_ABI const SCEV * getPredicatedConstantMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getConstantMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * removePointerBase(const SCEV *S)
Compute an expression equivalent to S - getPointerBase(S).
LLVM_ABI bool isLoopEntryGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the loop is protected by a conditional between LHS and RHS.
LLVM_ABI bool isKnownNonZero(const SCEV *S)
Test if the given expression is known to be non-zero.
LLVM_ABI const SCEV * getURemExpr(SCEVUse LHS, SCEVUse RHS)
Represents an unsigned remainder expression based on unsigned division.
LLVM_ABI const SCEV * getSCEVAtScope(const SCEV *S, const Loop *L)
Return a SCEV expression for the specified value at the specified scope in the program.
LLVM_ABI const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
LLVM_ABI const SCEV * getSMinExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags)
Update no-wrap flags of an AddRec.
LLVM_ABI const SCEV * getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS)
Promote the operands to the wider of the types using zero-extension, and then perform a umax operatio...
const SCEV * getZero(Type *Ty)
Return a SCEV for the constant 0 of a specific type.
LLVM_ABI bool willNotOverflow(Instruction::BinaryOps BinOp, bool Signed, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI=nullptr)
Determine whether the operation BinOp between LHS and RHS provably does not have a signed/unsigned overflow (Signed).
LLVM_ABI ExitLimit computeExitLimitFromCond(const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsOnlyExit, bool AllowPredicates=false)
Compute the number of times the backedge of the specified loop will execute if its exit condition wer...
LLVM_ABI const SCEV * getZeroExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI const SCEVPredicate * getEqualPredicate(const SCEV *LHS, const SCEV *RHS)
LLVM_ABI unsigned getSmallConstantTripMultiple(const Loop *L, const SCEV *ExitCount)
Returns the largest constant divisor of the trip count as a normal unsigned value,...
LLVM_ABI uint64_t getTypeSizeInBits(Type *Ty) const
Return the size in bits of the specified type, for which isSCEVable must return true.
LLVM_ABI const SCEV * getConstant(ConstantInt *V)
LLVM_ABI const SCEV * getPredicatedBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getBackedgeTakenCount, except it will add a set of SCEV predicates to Predicates that are ...
LLVM_ABI const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
ConstantRange getSignedRange(const SCEV *S)
Determine the signed range for a particular SCEV.
LLVM_ABI const SCEV * getAddRecExpr(SCEVUse Start, SCEVUse Step, const Loop *L, SCEV::NoWrapFlags Flags)
Get an add recurrence expression for the specified loop.
LLVM_ABI const SCEV * getNoopOrSignExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
bool loopHasNoAbnormalExits(const Loop *L)
Return true if the loop has no abnormal exits.
LLVM_ABI const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
LLVM_ABI ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, LoopInfo &LI)
const SCEV * getOne(Type *Ty)
Return a SCEV for the constant 1 of a specific type.
LLVM_ABI const SCEV * getTruncateOrNoop(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI const SCEV * getLosslessPtrToIntExpr(const SCEV *Op)
LLVM_ABI const SCEV * getCastExpr(SCEVTypes Kind, const SCEV *Op, Type *Ty)
LLVM_ABI const SCEV * getSequentialMinMaxExpr(SCEVTypes Kind, SmallVectorImpl< SCEVUse > &Operands)
LLVM_ABI std::optional< bool > evaluatePredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Check whether the condition described by Pred, LHS, and RHS is true or false in the given Context.
LLVM_ABI unsigned getSmallConstantMaxTripCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > *Predicates=nullptr)
Returns the upper bound of the loop trip count as a normal unsigned value.
LLVM_ABI const SCEV * getPtrToIntExpr(const SCEV *Op, Type *Ty)
LLVM_ABI bool isBackedgeTakenCountMaxOrZero(const Loop *L)
Return true if the backedge taken count is either the value returned by getConstantMaxBackedgeTakenCo...
LLVM_ABI void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
LLVM_ABI bool isLoopInvariant(const SCEV *S, const Loop *L)
Return true if the value of the given SCEV is unchanging in the specified loop.
LLVM_ABI bool isKnownPositive(const SCEV *S)
Test if the given expression is known to be positive.
LLVM_ABI bool SimplifyICmpOperands(CmpPredicate &Pred, SCEVUse &LHS, SCEVUse &RHS, unsigned Depth=0)
Simplify LHS and RHS in a comparison with predicate Pred.
APInt getUnsignedRangeMin(const SCEV *S)
Determine the min of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo)
Return an expression for offsetof on the given field with type IntTy.
LLVM_ABI LoopDisposition getLoopDisposition(const SCEV *S, const Loop *L)
Return the "disposition" of the given SCEV with respect to the given loop.
LLVM_ABI bool containsAddRecurrence(const SCEV *S)
Return true if the SCEV is a scAddRecExpr or it contains scAddRecExpr.
LLVM_ABI const SCEV * getSignExtendExprImpl(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool hasOperand(const SCEV *S, const SCEV *Op) const
Test whether the given SCEV has Op as a direct or indirect operand.
LLVM_ABI const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI bool isSCEVable(Type *Ty) const
Test if values of the given type are analyzable within the SCEV framework.
LLVM_ABI Type * getEffectiveSCEVType(Type *Ty) const
Return a type with the same bitwidth as the given type and which represents how SCEV will treat the g...
LLVM_ABI const SCEVPredicate * getComparePredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS)
LLVM_ABI bool haveSameSign(const SCEV *S1, const SCEV *S2)
Return true if we know that S1 and S2 must have the same sign.
LLVM_ABI const SCEV * getNotSCEV(const SCEV *V)
Return the SCEV object corresponding to ~V.
LLVM_ABI const SCEV * getElementCount(Type *Ty, ElementCount EC, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap)
LLVM_ABI bool instructionCouldExistWithOperands(const SCEV *A, const SCEV *B)
Return true if there exists a point in the program at which both A and B could be operands to the sam...
ConstantRange getUnsignedRange(const SCEV *S)
Determine the unsigned range for a particular SCEV.
LLVM_ABI void print(raw_ostream &OS) const
LLVM_ABI const SCEV * getPredicatedExitCount(const Loop *L, const BasicBlock *ExitingBlock, SmallVectorImpl< const SCEVPredicate * > *Predicates, ExitCountKind Kind=Exact)
Same as above except this uses the predicated backedge taken info and may require predicates.
static SCEV::NoWrapFlags clearFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OffFlags)
LLVM_ABI void forgetTopmostLoop(const Loop *L)
LLVM_ABI void forgetValue(Value *V)
This method should be called by the client when it has changed a value in a way that may affect its v...
APInt getSignedRangeMin(const SCEV *S)
Determine the min of the signed range for a particular SCEV.
LLVM_ABI const SCEV * getNoopOrAnyExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
LLVM_ABI const SCEV * getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI const SCEV * getUMaxExpr(SCEVUse LHS, SCEVUse RHS)
static SCEV::NoWrapFlags maskFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags Mask)
Convenient NoWrapFlags manipulation.
LLVM_ABI std::optional< LoopInvariantPredicate > getLoopInvariantPredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, const Instruction *CtxI=nullptr)
If the result of the predicate LHS Pred RHS is loop invariant with respect to L, return a LoopInvaria...
LLVM_ABI const SCEV * getStoreSizeOfExpr(Type *IntTy, Type *StoreTy)
Return an expression for the store size of StoreTy that is type IntTy.
LLVM_ABI const SCEVPredicate * getWrapPredicate(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags)
LLVM_ABI bool isLoopBackedgeGuardedByCond(const Loop *L, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether the backedge of the loop is protected by a conditional between LHS and RHS.
LLVM_ABI APInt getNonZeroConstantMultiple(const SCEV *S)
const SCEV * getMinusOne(Type *Ty)
Return a SCEV for the constant -1 of a specific type.
static SCEV::NoWrapFlags setFlags(SCEV::NoWrapFlags Flags, SCEV::NoWrapFlags OnFlags)
LLVM_ABI bool hasLoopInvariantBackedgeTakenCount(const Loop *L)
Return true if the specified loop has an analyzable loop-invariant backedge-taken count.
LLVM_ABI BlockDisposition getBlockDisposition(const SCEV *S, const BasicBlock *BB)
Return the "disposition" of the given SCEV with respect to the given block.
LLVM_ABI const SCEV * getNoopOrZeroExtend(const SCEV *V, Type *Ty)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool invalidate(Function &F, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &Inv)
LLVM_ABI const SCEV * getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS, bool Sequential=false)
Promote the operands to the wider of the types using zero-extension, and then perform a umin operatio...
LLVM_ABI bool loopIsFiniteByAssumption(const Loop *L)
Return true if this loop is finite by assumption.
LLVM_ABI const SCEV * getExistingSCEV(Value *V)
Return an existing SCEV for V if there is one, otherwise return nullptr.
LLVM_ABI APInt getConstantMultiple(const SCEV *S, const Instruction *CtxI=nullptr)
Returns the max constant multiple of S.
LoopDisposition
An enum describing the relationship between a SCEV and a loop.
@ LoopComputable
The SCEV varies predictably with the loop.
@ LoopVariant
The SCEV is loop-variant (unknown).
@ LoopInvariant
The SCEV is loop-invariant.
LLVM_ABI bool isKnownMultipleOf(const SCEV *S, uint64_t M, SmallVectorImpl< const SCEVPredicate * > &Assumptions)
Check that S is a multiple of M.
LLVM_ABI const SCEV * getAnyExtendExpr(const SCEV *Op, Type *Ty)
getAnyExtendExpr - Return a SCEV for the given operand extended with unspecified bits out to the give...
LLVM_ABI bool isKnownToBeAPowerOfTwo(const SCEV *S, bool OrZero=false, bool OrNegative=false)
Test if the given expression is known to be a power of 2.
LLVM_ABI std::optional< SCEV::NoWrapFlags > getStrengthenedNoWrapFlagsFromBinOp(const OverflowingBinaryOperator *OBO)
Parse NSW/NUW flags from add/sub/mul IR binary operation Op into SCEV no-wrap flags,...
LLVM_ABI void forgetLcssaPhiWithNewPredecessor(Loop *L, PHINode *V)
Forget LCSSA phi node V of loop L to which a new predecessor was added, such that it may no longer be...
LLVM_ABI bool containsUndefs(const SCEV *S) const
Return true if the SCEV expression contains an undef value.
LLVM_ABI std::optional< MonotonicPredicateType > getMonotonicPredicateType(const SCEVAddRecExpr *LHS, ICmpInst::Predicate Pred)
If, for all loop invariant X, the predicate "LHS `Pred` X" is monotonically increasing or decreasing,...
LLVM_ABI const SCEV * getCouldNotCompute()
LLVM_ABI const SCEV * getMulExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical multiply expression, or something simpler if possible.
LLVM_ABI bool isAvailableAtLoopEntry(const SCEV *S, const Loop *L)
Determine if the SCEV can be evaluated at loop's entry.
LLVM_ABI uint32_t getMinTrailingZeros(const SCEV *S, const Instruction *CtxI=nullptr)
Determine the minimum number of zero bits that S is guaranteed to end in (at every loop iteration).
BlockDisposition
An enum describing the relationship between a SCEV and a basic block.
@ DominatesBlock
The SCEV dominates the block.
@ ProperlyDominatesBlock
The SCEV properly dominates the block.
@ DoesNotDominateBlock
The SCEV does not dominate the block.
LLVM_ABI const SCEV * getExitCount(const Loop *L, const BasicBlock *ExitingBlock, ExitCountKind Kind=Exact)
Return the number of times the backedge executes before the given exit would be taken; if not exactly...
LLVM_ABI const SCEV * getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
LLVM_ABI void getPoisonGeneratingValues(SmallPtrSetImpl< const Value * > &Result, const SCEV *S)
Return the set of Values that, if poison, will definitively result in S being poison as well.
LLVM_ABI void forgetLoopDispositions()
Called when the client has changed the disposition of values in this loop.
LLVM_ABI const SCEV * getVScale(Type *Ty)
LLVM_ABI unsigned getSmallConstantTripCount(const Loop *L)
Returns the exact trip count of the loop if we can compute it, and the result is a small constant.
LLVM_ABI bool hasComputableLoopEvolution(const SCEV *S, const Loop *L)
Return true if the given SCEV changes value in a known way in the specified loop.
LLVM_ABI const SCEV * getPointerBase(const SCEV *V)
Transitively follow the chain of pointer-type operands until reaching a SCEV that does not have a sin...
LLVM_ABI void forgetAllLoops()
LLVM_ABI bool dominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV dominate the specified basic block.
APInt getUnsignedRangeMax(const SCEV *S)
Determine the max of the unsigned range for a particular SCEV.
LLVM_ABI const SCEV * getAddExpr(SmallVectorImpl< SCEVUse > &Ops, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Get a canonical add expression, or something simpler if possible.
ExitCountKind
The terms "backedge taken count" and "exit count" are used interchangeably to refer to the number of ...
@ SymbolicMaximum
An expression which provides an upper bound on the exact trip count.
@ ConstantMaximum
A constant which provides an upper bound on the exact trip count.
@ Exact
An expression exactly describing the number of times the backedge has executed when a loop is exited.
LLVM_ABI bool isKnownPredicate(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * applyLoopGuards(const SCEV *Expr, const Loop *L)
Try to apply information from loop guards for L to Expr.
LLVM_ABI const SCEV * getPtrToAddrExpr(const SCEV *Op)
LLVM_ABI const SCEVAddRecExpr * convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Preds)
Tries to convert the S expression to an AddRec expression, adding additional predicates to Preds as r...
LLVM_ABI const SCEV * getSMaxExpr(SCEVUse LHS, SCEVUse RHS)
LLVM_ABI const SCEV * getElementSize(Instruction *Inst)
Return the size of an element read or written by Inst.
LLVM_ABI const SCEV * getSizeOfExpr(Type *IntTy, TypeSize Size)
Return an expression for a TypeSize.
LLVM_ABI std::optional< bool > evaluatePredicate(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Check whether the condition described by Pred, LHS, and RHS is true or false.
LLVM_ABI const SCEV * getUnknown(Value *V)
LLVM_ABI std::optional< std::pair< const SCEV *, SmallVector< const SCEVPredicate *, 3 > > > createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI)
Checks if SymbolicPHI can be rewritten as an AddRecExpr under some Predicates.
LLVM_ABI const SCEV * getTruncateOrZeroExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool isKnownViaInduction(CmpPredicate Pred, SCEVUse LHS, SCEVUse RHS)
We'd like to check the predicate on every iteration of the most dominated loop between loops used in ...
LLVM_ABI std::optional< APInt > computeConstantDifference(const SCEV *LHS, const SCEV *RHS)
Compute LHS - RHS and returns the result as an APInt if it is a constant, and std::nullopt if it isn'...
LLVM_ABI bool properlyDominates(const SCEV *S, const BasicBlock *BB)
Return true if the elements that make up the given SCEV properly dominate the specified basic block.
LLVM_ABI const SCEV * getUDivExactExpr(SCEVUse LHS, SCEVUse RHS)
Get a canonical unsigned division expression, or something simpler if possible.
LLVM_ABI const SCEV * rewriteUsingPredicate(const SCEV *S, const Loop *L, const SCEVPredicate &A)
Re-writes the SCEV according to the Predicates in A.
LLVM_ABI std::pair< const SCEV *, const SCEV * > SplitIntoInitAndPostInc(const Loop *L, const SCEV *S)
Splits SCEV expression S into two SCEVs.
LLVM_ABI bool canReuseInstruction(const SCEV *S, Instruction *I, SmallVectorImpl< Instruction * > &DropPoisonGeneratingInsts)
Check whether it is poison-safe to represent the expression S using the instruction I.
LLVM_ABI bool isKnownPredicateAt(CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS, const Instruction *CtxI)
Test if the given expression is known to satisfy the condition described by Pred, LHS,...
LLVM_ABI const SCEV * getPredicatedSymbolicMaxBackedgeTakenCount(const Loop *L, SmallVectorImpl< const SCEVPredicate * > &Predicates)
Similar to getSymbolicMaxBackedgeTakenCount, except it will add a set of SCEV predicates to Predicate...
LLVM_ABI const SCEV * getGEPExpr(GEPOperator *GEP, ArrayRef< SCEVUse > IndexExprs)
Returns an expression for a GEP.
LLVM_ABI const SCEV * getUMinExpr(SCEVUse LHS, SCEVUse RHS, bool Sequential=false)
LLVM_ABI void registerUser(const SCEV *User, ArrayRef< const SCEV * > Ops)
Notify this ScalarEvolution that User directly uses SCEVs in Ops.
LLVM_ABI bool isBasicBlockEntryGuardedByCond(const BasicBlock *BB, CmpPredicate Pred, const SCEV *LHS, const SCEV *RHS)
Test whether entry to the basic block is protected by a conditional between LHS and RHS.
LLVM_ABI const SCEV * getTruncateOrSignExtend(const SCEV *V, Type *Ty, unsigned Depth=0)
Return a SCEV corresponding to a conversion of the input value to the specified type.
LLVM_ABI bool containsErasedValue(const SCEV *S) const
Return true if the SCEV expression contains a Value that has been optimised out and is now a nullptr.
const SCEV * getSymbolicMaxBackedgeTakenCount(const Loop *L)
When successful, this returns a SCEV that is greater than or equal to (i.e.
APInt getSignedRangeMax(const SCEV *S)
Determine the max of the signed range for a particular SCEV.
LLVM_ABI void verify() const
LLVMContext & getContext() const
Implements a dense probed hash-table based set with some number of buckets stored inline.
Definition DenseSet.h:291
size_type size() const
Definition SmallPtrSet.h:99
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
bool contains(ConstPtrType Ptr) const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
reference emplace_back(ArgTypes &&... Args)
void reserve(size_type N)
iterator erase(const_iterator CI)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
iterator insert(iterator I, T &&Elt)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Represent a constant reference to a string, i.e.
Definition StringRef.h:56
Used to lazily calculate structure layout information for a target machine, based on the DataLayout s...
Definition DataLayout.h:743
TypeSize getElementOffset(unsigned Idx) const
Definition DataLayout.h:774
TypeSize getSizeInBits() const
Definition DataLayout.h:754
Class to represent struct types.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
The instances of the Type class are immutable: once they are created, they are never changed.
Definition Type.h:46
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:313
bool isPointerTy() const
True if this is an instance of PointerType.
Definition Type.h:284
LLVM_ABI TypeSize getPrimitiveSizeInBits() const LLVM_READONLY
Return the basic size of this type if it is a primitive type.
Definition Type.cpp:201
static LLVM_ABI IntegerType * getInt1Ty(LLVMContext &C)
Definition Type.cpp:310
bool isIntOrPtrTy() const
Return true if this is an integer type or a pointer type.
Definition Type.h:272
bool isIntegerTy() const
True if this is an instance of IntegerType.
Definition Type.h:257
static LLVM_ABI IntegerType * getIntNTy(LLVMContext &C, unsigned N)
Definition Type.cpp:317
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
op_range operands()
Definition User.h:267
Use & Op()
Definition User.h:171
Value * getOperand(unsigned i) const
Definition User.h:207
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
unsigned getValueID() const
Return an ID for the concrete type of this object.
Definition Value.h:543
LLVM_ABI void printAsOperand(raw_ostream &O, bool PrintType=true, const Module *M=nullptr) const
Print the name of this Value out to the specified raw_ostream.
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
constexpr bool isScalable() const
Returns whether the quantity is scaled by a runtime quantity (vscale).
Definition TypeSize.h:168
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
constexpr char Align[]
Key for Kernel::Arg::Metadata::mAlign.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
Definition APInt.h:2277
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
Definition APInt.h:2282
const APInt & umin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be unsigned.
Definition APInt.h:2287
LLVM_ABI std::optional< APInt > SolveQuadraticEquationWrap(APInt A, APInt B, APInt C, unsigned RangeWidth)
Let q(n) = An^2 + Bn + C, and BW = bit width of the value range (e.g.
Definition APInt.cpp:2864
const APInt & umax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be unsigned.
Definition APInt.h:2292
LLVM_ABI APInt GreatestCommonDivisor(APInt A, APInt B)
Compute GCD of two unsigned APInt values.
Definition APInt.cpp:830
constexpr bool any(E Val)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
int getMinValue(MCInstrInfo const &MCII, MCInst const &MCI)
Return the minimum value of an extendable operand.
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
LLVM_ABI Function * getDeclarationIfExists(const Module *M, ID id)
Look up the Function declaration of the intrinsic id in the Module M and return it if it exists.
Predicate
Predicate - These are "(BI << 5) | BO" for various predicates.
match_combine_or< Ty... > m_CombineOr(const Ty &...Ps)
Combine pattern matchers matching any of Ps patterns.
BinaryOp_match< LHS, RHS, Instruction::AShr > m_AShr(const LHS &L, const RHS &R)
ap_match< APInt > m_APInt(const APInt *&Res)
Match a ConstantInt or splatted ConstantVector, binding the specified pointer to the contained APInt.
bool match(Val *V, const Pattern &P)
IntrinsicID_match m_Intrinsic()
Match intrinsic calls like this: m_Intrinsic<Intrinsic::fabs>(m_Value(X))
ThreeOps_match< Cond, LHS, RHS, Instruction::Select > m_Select(const Cond &C, const LHS &L, const RHS &R)
Matches SelectInst.
auto m_BasicBlock()
Match an arbitrary basic block value and ignore it.
ExtractValue_match< Ind, Val_t > m_ExtractValue(const Val_t &V)
Match a single index ExtractValue instruction.
auto m_Value()
Match an arbitrary value and ignore it.
auto m_LogicalOr()
Matches L || R where L and R are arbitrary values.
match_bind< WithOverflowInst > m_WithOverflowInst(WithOverflowInst *&I)
Match a with overflow intrinsic, capturing it if we match.
BinaryOp_match< LHS, RHS, Instruction::SDiv > m_SDiv(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::LShr > m_LShr(const LHS &L, const RHS &R)
BinaryOp_match< LHS, RHS, Instruction::Shl > m_Shl(const LHS &L, const RHS &R)
auto m_LogicalAnd()
Matches L && R where L and R are arbitrary values.
brc_match< Cond_t, match_bind< BasicBlock >, match_bind< BasicBlock > > m_Br(const Cond_t &C, BasicBlock *&T, BasicBlock *&F)
auto m_ConstantInt()
Match an arbitrary ConstantInt and ignore it.
bind_cst_ty m_scev_APInt(const APInt *&C)
Match an SCEV constant and bind it to an APInt.
cst_pred_ty< is_all_ones > m_scev_AllOnes()
Match an integer with all bits set.
SCEVUnaryExpr_match< SCEVZeroExtendExpr, Op0_t > m_scev_ZExt(const Op0_t &Op0)
is_undef_or_poison m_scev_UndefOrPoison()
Match an SCEVUnknown wrapping undef or poison.
cst_pred_ty< is_one > m_scev_One()
Match an integer 1.
specificloop_ty m_SpecificLoop(const Loop *L)
SCEVUnaryExpr_match< SCEVSignExtendExpr, Op0_t > m_scev_SExt(const Op0_t &Op0)
match_bind< const SCEVMulExpr > m_scev_Mul(const SCEVMulExpr *&V)
cst_pred_ty< is_zero > m_scev_Zero()
Match an integer 0.
SCEVUnaryExpr_match< SCEVTruncateExpr, Op0_t > m_scev_Trunc(const Op0_t &Op0)
bool match(const SCEV *S, const Pattern &P)
SCEVBinaryExpr_match< SCEVUDivExpr, Op0_t, Op1_t > m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1)
specificscev_ty m_scev_Specific(const SCEV *S)
Match if we have a specific specified SCEV.
SCEVAffineAddRec_match< Op0_t, Op1_t, match_isa< const Loop > > m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVUnknown > m_SCEVUnknown(const SCEVUnknown *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagNUW, true > m_scev_c_NUWMul(const Op0_t &Op0, const Op1_t &Op1)
match_bind< const SCEVAddExpr > m_scev_Add(const SCEVAddExpr *&V)
SCEVBinaryExpr_match< SCEVMulExpr, Op0_t, Op1_t, SCEV::FlagAnyWrap, true > m_scev_c_Mul(const Op0_t &Op0, const Op1_t &Op1)
SCEVBinaryExpr_match< SCEVSMaxExpr, Op0_t, Op1_t > m_scev_SMax(const Op0_t &Op0, const Op1_t &Op1)
SCEVURem_match< Op0_t, Op1_t > m_scev_URem(Op0_t LHS, Op1_t RHS, ScalarEvolution &SE)
Match the mathematical pattern A - (A / B) * B, where A and B can be arbitrary expressions.
@ Valid
The data is already valid.
initializer< Ty > init(const Ty &Val)
LocationClass< Ty > location(Ty &L)
@ Switch
The "resume-switch" lowering, where there are separate resume and destroy functions that are shared b...
Definition CoroShape.h:31
constexpr double e
NodeAddr< PhiNode * > Phi
Definition RDFGraph.h:390
friend class Instruction
Iterator for Instructions in a `BasicBlock.
Definition BasicBlock.h:73
This is an optimization pass for GlobalISel generic memory operations.
void visitAll(const SCEV *Root, SV &Visitor)
Use SCEVTraversal to visit all nodes in the given expression tree.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:315
@ Offset
Definition DWP.cpp:557
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
LLVM_ATTRIBUTE_ALWAYS_INLINE DynamicAPInt gcd(const DynamicAPInt &A, const DynamicAPInt &B)
void stable_sort(R &&Range)
Definition STLExtras.h:2115
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
SaveAndRestore(T &) -> SaveAndRestore< T >
Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST=nullptr, unsigned DynamicVGPRBlockSize=0)
LLVM_ABI bool canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata=true)
LLVM_ABI bool mustTriggerUB(const Instruction *I, const SmallPtrSetImpl< const Value * > &KnownPoison)
Return true if the given instruction must trigger undefined behavior when I is executed with any oper...
LLVM_ABI bool canConstantFoldCallTo(const CallBase *Call, const Function *F)
canConstantFoldCallTo - Return true if it's even possible to fold a call to the specified function.
InterleavedRange< Range > interleaved(const Range &R, StringRef Separator=", ", StringRef Prefix="", StringRef Suffix="")
Output range R as a sequence of interleaved elements.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
LLVM_ABI bool verifyFunction(const Function &F, raw_ostream *OS=nullptr)
Check a function for errors, useful for use when debugging a pass.
auto successors(const MachineBasicBlock *BB)
scope_exit(Callable) -> scope_exit< Callable >
constexpr from_range_t from_range
auto dyn_cast_if_present(const Y &Val)
dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a null (or none in the case ...
Definition Casting.h:732
bool set_is_subset(const S1Ty &S1, const S2Ty &S2)
set_is_subset(A, B) - Return true iff A is a subset of B (every element of A is also in B).
void append_range(Container &C, Range &&R)
Wrapper function to append range R to container C.
Definition STLExtras.h:2207
constexpr bool isUIntN(unsigned N, uint64_t x)
Checks if an unsigned integer fits into the given (dynamic) bit width.
Definition MathExtras.h:243
LLVM_ABI Constant * ConstantFoldCompareInstOperands(unsigned Predicate, Constant *LHS, Constant *RHS, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, const Instruction *I=nullptr)
Attempt to constant fold a compare instruction (icmp/fcmp) with the specified operands.
void * PointerTy
LLVM_ABI bool VerifySCEV
auto uninitialized_copy(R &&Src, IterTy Dst)
Definition STLExtras.h:2110
bool isa_and_nonnull(const Y &Val)
Definition Casting.h:676
LLVM_ABI ConstantRange getConstantRangeFromMetadata(const MDNode &RangeMD)
Parse out a conservative ConstantRange from !range metadata.
int countr_zero(T Val)
Count number of 0's from the least significant bit to the most stopping at the first 1.
Definition bit.h:204
LLVM_ABI Value * simplifyInstruction(Instruction *I, const SimplifyQuery &Q)
See if we can compute a simplified version of this instruction.
LLVM_ABI bool isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, const DominatorTree &DT)
Returns true if the arithmetic part of the WO 's result is used only along the paths control dependen...
DomTreeNodeBase< BasicBlock > DomTreeNode
Definition Dominators.h:94
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start, Value *&Step)
Attempt to match a simple first order recurrence cycle of the form: iv = phi Ty [Start,...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
void erase(Container &C, ValueType V)
Wrapper function to remove a value from a container:
Definition STLExtras.h:2199
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
iterator_range< pointee_iterator< WrappedIteratorT > > make_pointee_range(RangeT &&Range)
Definition iterator.h:341
auto reverse(ContainerTy &&C)
Definition STLExtras.h:407
LLVM_ABI bool isMustProgress(const Loop *L)
Return true if this loop can be assumed to make progress.
LLVM_ABI bool impliesPoison(const Value *ValAssumedPoison, const Value *V)
Return true if V is poison given that ValAssumedPoison is already poison.
LLVM_ABI bool isFinite(const Loop *L)
Return true if this loop can be assumed to run for a finite number of iterations.
LLVM_ABI void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
unsigned short computeExpressionSize(ArrayRef< SCEVUse > Args)
LLVM_ABI bool programUndefinedIfPoison(const Instruction *Inst)
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:209
bool isPointerTy(const Type *T)
Definition SPIRVUtils.h:370
FunctionAddr VTableAddr Count
Definition InstrProf.h:139
LLVM_ABI ConstantRange getVScaleRange(const Function *F, unsigned BitWidth)
Determine the possible constant range of vscale with the given bit width, based on the vscale_range f...
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
LLVM_ATTRIBUTE_VISIBILITY_DEFAULT AnalysisKey InnerAnalysisManagerProxy< AnalysisManagerT, IRUnitT, ExtraArgTs... >::Key
LLVM_ABI bool isKnownNonZero(const Value *V, const SimplifyQuery &Q, unsigned Depth=0)
Return true if the given value is known to be non-zero when defined.
LLVM_ABI bool propagatesPoison(const Use &PoisonOp)
Return true if PoisonOp's user yields poison or raises UB if its operand PoisonOp is poison.
@ UMin
Unsigned integer min implemented in terms of select(cmp()).
@ Mul
Product of integers.
@ SMax
Signed integer max implemented in terms of select(cmp()).
@ SMin
Signed integer min implemented in terms of select(cmp()).
@ Add
Sum of integers.
@ UMax
Unsigned integer max implemented in terms of select(cmp()).
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2011
DWARFExpression::Operation Op
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
Definition STLExtras.h:2087
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
ArrayRef(const T &OneElt) -> ArrayRef< T >
LLVM_ABI unsigned ComputeNumSignBits(const Value *Op, const DataLayout &DL, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true, unsigned Depth=0)
Return the number of times the sign bit of the register is replicated into the other bits.
constexpr unsigned BitWidth
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1916
LLVM_ABI bool isGuaranteedToTransferExecutionToSuccessor(const Instruction *I)
Return true if this function can prove that the instruction I will always transfer execution to one o...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
Definition STLExtras.h:2018
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
constexpr bool isIntN(unsigned N, int64_t x)
Checks if a signed integer fits into the given (dynamic) bit width.
Definition MathExtras.h:248
auto predecessors(const MachineBasicBlock *BB)
bool is_contained(R &&Range, const E &Element)
Returns true if Element is found in Range.
Definition STLExtras.h:1946
iterator_range< df_iterator< T > > depth_first(const T &G)
auto seq(T Begin, T End)
Iterate over an integral type from Begin up to - but not including - End.
Definition Sequence.h:305
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI bool isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC=nullptr, const Instruction *CtxI=nullptr, const DominatorTree *DT=nullptr, unsigned Depth=0)
Returns true if V cannot be poison, but may be undef.
LLVM_ABI Constant * ConstantFoldInstOperands(const Instruction *I, ArrayRef< Constant * > Ops, const DataLayout &DL, const TargetLibraryInfo *TLI=nullptr, bool AllowNonDeterministic=true)
ConstantFoldInstOperands - Attempt to constant fold an instruction with the specified operands.
SCEVUseT< const SCEV * > SCEVUse
bool SCEVExprContains(const SCEV *Root, PredTy Pred)
Return true if any node in Root satisfies the predicate Pred.
Implement std::hash so that hash_code can be used in STL containers.
Definition BitVector.h:874
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
Definition BitVector.h:876
#define N
#define NC
Definition regutils.h:42
A special type used by analysis passes to provide an address that identifies that particular analysis...
Definition Analysis.h:29
static KnownBits makeConstant(const APInt &C)
Create known bits from a known constant.
Definition KnownBits.h:315
bool isNonNegative() const
Returns true if this value is known to be non-negative.
Definition KnownBits.h:106
static LLVM_ABI KnownBits ashr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for ashr(LHS, RHS).
unsigned getBitWidth() const
Get the bit width of this value.
Definition KnownBits.h:44
static LLVM_ABI KnownBits lshr(const KnownBits &LHS, const KnownBits &RHS, bool ShAmtNonZero=false, bool Exact=false)
Compute known bits for lshr(LHS, RHS).
KnownBits zextOrTrunc(unsigned BitWidth) const
Return known bits for a zero extension or truncation of the value we're tracking.
Definition KnownBits.h:200
APInt getMaxValue() const
Return the maximal unsigned value possible given these KnownBits.
Definition KnownBits.h:146
APInt getMinValue() const
Return the minimal unsigned value possible given these KnownBits.
Definition KnownBits.h:130
bool isNegative() const
Returns true if this value is known to be negative.
Definition KnownBits.h:103
static LLVM_ABI KnownBits shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW=false, bool NSW=false, bool ShAmtNonZero=false)
Compute known bits for shl(LHS, RHS).
An object of this class is returned by queries that could not be answered.
static LLVM_ABI bool classof(const SCEV *S)
Methods for support type inquiry through isa, cast, and dyn_cast:
This class defines a simple visitor class that may be used for various SCEV analysis purposes.
A utility class that uses RAII to save and restore the value of a variable.
Information about the number of loop iterations for which a loop exit's branch condition evaluates to...
LLVM_ABI ExitLimit(const SCEV *E)
Construct either an exact exit limit from a constant, or an unknown one from a SCEVCouldNotCompute.
SmallVector< const SCEVPredicate *, 4 > Predicates
A vector of predicate guards for this ExitLimit.