ScalarEvolution.cpp
1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library. First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
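// As an editorial illustration (not part of the original header): the PHI of
// a canonical induction variable such as "for (i = 0; i != n; ++i)" is
// typically modeled as the polynomial recurrence {0,+,1}<%loop>, while a PHI
// whose update SCEV cannot analyze is wrapped in a SCEVUnknown and treated
// as an opaque value.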
24 //
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression. These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 // Chains of recurrences -- a method to expedite the evaluation
42 // of closed-form functions
43 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 // On computational properties of chains of recurrences
46 // Eugene V. Zima
47 //
48 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 // Robert A. van Engelen
50 //
51 // Efficient Symbolic Analysis for Optimizing Compilers
52 // Robert A. van Engelen
53 //
54 // Using the chains of recurrences algebra for data dependence testing and
55 // induction variable substitution
56 // MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
59 
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SmallPtrSet.h"
73 #include "llvm/ADT/SmallSet.h"
74 #include "llvm/ADT/SmallVector.h"
75 #include "llvm/ADT/Statistic.h"
76 #include "llvm/ADT/StringRef.h"
80 #include "llvm/Analysis/LoopInfo.h"
84 #include "llvm/Config/llvm-config.h"
85 #include "llvm/IR/Argument.h"
86 #include "llvm/IR/BasicBlock.h"
87 #include "llvm/IR/CFG.h"
88 #include "llvm/IR/Constant.h"
89 #include "llvm/IR/ConstantRange.h"
90 #include "llvm/IR/Constants.h"
91 #include "llvm/IR/DataLayout.h"
92 #include "llvm/IR/DerivedTypes.h"
93 #include "llvm/IR/Dominators.h"
94 #include "llvm/IR/Function.h"
95 #include "llvm/IR/GlobalAlias.h"
96 #include "llvm/IR/GlobalValue.h"
97 #include "llvm/IR/InstIterator.h"
98 #include "llvm/IR/InstrTypes.h"
99 #include "llvm/IR/Instruction.h"
100 #include "llvm/IR/Instructions.h"
101 #include "llvm/IR/IntrinsicInst.h"
102 #include "llvm/IR/Intrinsics.h"
103 #include "llvm/IR/LLVMContext.h"
104 #include "llvm/IR/Operator.h"
105 #include "llvm/IR/PatternMatch.h"
106 #include "llvm/IR/Type.h"
107 #include "llvm/IR/Use.h"
108 #include "llvm/IR/User.h"
109 #include "llvm/IR/Value.h"
110 #include "llvm/IR/Verifier.h"
111 #include "llvm/InitializePasses.h"
112 #include "llvm/Pass.h"
113 #include "llvm/Support/Casting.h"
115 #include "llvm/Support/Compiler.h"
116 #include "llvm/Support/Debug.h"
118 #include "llvm/Support/KnownBits.h"
121 #include <algorithm>
122 #include <cassert>
123 #include <climits>
124 #include <cstdint>
125 #include <cstdlib>
126 #include <map>
127 #include <memory>
128 #include <numeric>
129 #include <optional>
130 #include <tuple>
131 #include <utility>
132 #include <vector>
133 
134 using namespace llvm;
135 using namespace PatternMatch;
136 
137 #define DEBUG_TYPE "scalar-evolution"
138 
139 STATISTIC(NumTripCountsComputed,
140  "Number of loops with predictable loop counts");
141 STATISTIC(NumTripCountsNotComputed,
142  "Number of loops without predictable loop counts");
143 STATISTIC(NumBruteForceTripCountsComputed,
144  "Number of loops with trip counts computed by force");
145 
146 #ifdef EXPENSIVE_CHECKS
147 bool llvm::VerifySCEV = true;
148 #else
149 bool llvm::VerifySCEV = false;
150 #endif
151 
152 static cl::opt<unsigned>
153  MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
154  cl::desc("Maximum number of iterations SCEV will "
155  "symbolically execute a constant "
156  "derived loop"),
157  cl::init(100));
158 
159 static cl::opt<bool, true> VerifySCEVOpt(
160  "verify-scev", cl::Hidden, cl::location(VerifySCEV),
161  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
162 static cl::opt<bool> VerifySCEVStrict(
163  "verify-scev-strict", cl::Hidden,
164  cl::desc("Enable stricter verification when -verify-scev is passed"));
165 static cl::opt<bool>
166  VerifySCEVMap("verify-scev-maps", cl::Hidden,
167  cl::desc("Verify no dangling value in ScalarEvolution's "
168  "ExprValueMap (slow)"));
169 
170 static cl::opt<bool> VerifyIR(
171  "scev-verify-ir", cl::Hidden,
172  cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
173  cl::init(false));
174 
175 static cl::opt<unsigned> MulOpsInlineThreshold(
176  "scev-mulops-inline-threshold", cl::Hidden,
177  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
178  cl::init(32));
179 
180 static cl::opt<unsigned> AddOpsInlineThreshold(
181  "scev-addops-inline-threshold", cl::Hidden,
182  cl::desc("Threshold for inlining addition operands into a SCEV"),
183  cl::init(500));
184 
185 static cl::opt<unsigned> MaxSCEVCompareDepth(
186  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
187  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
188  cl::init(32));
189 
190 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
191  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
192  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
193  cl::init(2));
194 
195 static cl::opt<unsigned> MaxValueCompareDepth(
196  "scalar-evolution-max-value-compare-depth", cl::Hidden,
197  cl::desc("Maximum depth of recursive value complexity comparisons"),
198  cl::init(2));
199 
200 static cl::opt<unsigned>
201  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
202  cl::desc("Maximum depth of recursive arithmetics"),
203  cl::init(32));
204 
205 static cl::opt<unsigned> MaxConstantEvolvingDepth(
206  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
207  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
208 
209 static cl::opt<unsigned>
210  MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
211  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
212  cl::init(8));
213 
214 static cl::opt<unsigned>
215  MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
216  cl::desc("Max coefficients in AddRec during evolving"),
217  cl::init(8));
218 
219 static cl::opt<unsigned>
220  HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
221  cl::desc("Size of the expression which is considered huge"),
222  cl::init(4096));
223 
224 static cl::opt<unsigned> RangeIterThreshold(
225  "scev-range-iter-threshold", cl::Hidden,
226  cl::desc("Threshold for switching to iteratively computing SCEV ranges"),
227  cl::init(32));
228 
229 static cl::opt<bool>
230 ClassifyExpressions("scalar-evolution-classify-expressions",
231  cl::Hidden, cl::init(true),
232  cl::desc("When printing analysis, include information on every instruction"));
233 
234 static cl::opt<bool> UseExpensiveRangeSharpening(
235  "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
236  cl::init(false),
237  cl::desc("Use more powerful methods of sharpening expression ranges. May "
238  "be costly in terms of compile time"));
239 
240 static cl::opt<unsigned> MaxPhiSCCAnalysisDepth(
241  "scalar-evolution-max-scc-analysis-depth", cl::Hidden,
242  cl::desc("Maximum amount of nodes to process while searching SCEVUnknown "
243  "Phi strongly connected components"),
244  cl::init(8));
245 
246 static cl::opt<bool>
247  EnableFiniteLoopControl("scalar-evolution-finite-loop", cl::Hidden,
248  cl::desc("Handle <= and >= in finite loops"),
249  cl::init(true));
250 
251 static cl::opt<bool> UseContextForNoWrapFlagInference(
252  "scalar-evolution-use-context-for-no-wrap-flag-strenghening", cl::Hidden,
253  cl::desc("Infer nuw/nsw flags using context where suitable"),
254  cl::init(true));
255 
256 //===----------------------------------------------------------------------===//
257 // SCEV class definitions
258 //===----------------------------------------------------------------------===//
259 
260 //===----------------------------------------------------------------------===//
261 // Implementation of the SCEV class.
262 //
263 
264 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
265 LLVM_DUMP_METHOD void SCEV::dump() const {
266  print(dbgs());
267  dbgs() << '\n';
268 }
269 #endif
270 
271 void SCEV::print(raw_ostream &OS) const {
272  switch (getSCEVType()) {
273  case scConstant:
274  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
275  return;
276  case scPtrToInt: {
277  const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
278  const SCEV *Op = PtrToInt->getOperand();
279  OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
280  << *PtrToInt->getType() << ")";
281  return;
282  }
283  case scTruncate: {
284  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
285  const SCEV *Op = Trunc->getOperand();
286  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
287  << *Trunc->getType() << ")";
288  return;
289  }
290  case scZeroExtend: {
291  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
292  const SCEV *Op = ZExt->getOperand();
293  OS << "(zext " << *Op->getType() << " " << *Op << " to "
294  << *ZExt->getType() << ")";
295  return;
296  }
297  case scSignExtend: {
298  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
299  const SCEV *Op = SExt->getOperand();
300  OS << "(sext " << *Op->getType() << " " << *Op << " to "
301  << *SExt->getType() << ")";
302  return;
303  }
304  case scAddRecExpr: {
305  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
306  OS << "{" << *AR->getOperand(0);
307  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
308  OS << ",+," << *AR->getOperand(i);
309  OS << "}<";
310  if (AR->hasNoUnsignedWrap())
311  OS << "nuw><";
312  if (AR->hasNoSignedWrap())
313  OS << "nsw><";
314  if (AR->hasNoSelfWrap() &&
315  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
316  OS << "nw><";
317  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
318  OS << ">";
319  return;
320  }
321  case scAddExpr:
322  case scMulExpr:
323  case scUMaxExpr:
324  case scSMaxExpr:
325  case scUMinExpr:
326  case scSMinExpr:
327  case scSequentialUMinExpr: {
328  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
329  const char *OpStr = nullptr;
330  switch (NAry->getSCEVType()) {
331  case scAddExpr: OpStr = " + "; break;
332  case scMulExpr: OpStr = " * "; break;
333  case scUMaxExpr: OpStr = " umax "; break;
334  case scSMaxExpr: OpStr = " smax "; break;
335  case scUMinExpr:
336  OpStr = " umin ";
337  break;
338  case scSMinExpr:
339  OpStr = " smin ";
340  break;
341  case scSequentialUMinExpr:
342  OpStr = " umin_seq ";
343  break;
344  default:
345  llvm_unreachable("There are no other nary expression types.");
346  }
347  OS << "(";
348  ListSeparator LS(OpStr);
349  for (const SCEV *Op : NAry->operands())
350  OS << LS << *Op;
351  OS << ")";
352  switch (NAry->getSCEVType()) {
353  case scAddExpr:
354  case scMulExpr:
355  if (NAry->hasNoUnsignedWrap())
356  OS << "<nuw>";
357  if (NAry->hasNoSignedWrap())
358  OS << "<nsw>";
359  break;
360  default:
361  // Nothing to print for other nary expressions.
362  break;
363  }
364  return;
365  }
366  case scUDivExpr: {
367  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
368  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
369  return;
370  }
371  case scUnknown: {
372  const SCEVUnknown *U = cast<SCEVUnknown>(this);
373  Type *AllocTy;
374  if (U->isSizeOf(AllocTy)) {
375  OS << "sizeof(" << *AllocTy << ")";
376  return;
377  }
378  if (U->isAlignOf(AllocTy)) {
379  OS << "alignof(" << *AllocTy << ")";
380  return;
381  }
382 
383  Type *CTy;
384  Constant *FieldNo;
385  if (U->isOffsetOf(CTy, FieldNo)) {
386  OS << "offsetof(" << *CTy << ", ";
387  FieldNo->printAsOperand(OS, false);
388  OS << ")";
389  return;
390  }
391 
392  // Otherwise just print it normally.
393  U->getValue()->printAsOperand(OS, false);
394  return;
395  }
396  case scCouldNotCompute:
397  OS << "***COULDNOTCOMPUTE***";
398  return;
399  }
400  llvm_unreachable("Unknown SCEV kind!");
401 }
402 
403 Type *SCEV::getType() const {
404  switch (getSCEVType()) {
405  case scConstant:
406  return cast<SCEVConstant>(this)->getType();
407  case scPtrToInt:
408  case scTruncate:
409  case scZeroExtend:
410  case scSignExtend:
411  return cast<SCEVCastExpr>(this)->getType();
412  case scAddRecExpr:
413  return cast<SCEVAddRecExpr>(this)->getType();
414  case scMulExpr:
415  return cast<SCEVMulExpr>(this)->getType();
416  case scUMaxExpr:
417  case scSMaxExpr:
418  case scUMinExpr:
419  case scSMinExpr:
420  return cast<SCEVMinMaxExpr>(this)->getType();
421  case scSequentialUMinExpr:
422  return cast<SCEVSequentialMinMaxExpr>(this)->getType();
423  case scAddExpr:
424  return cast<SCEVAddExpr>(this)->getType();
425  case scUDivExpr:
426  return cast<SCEVUDivExpr>(this)->getType();
427  case scUnknown:
428  return cast<SCEVUnknown>(this)->getType();
429  case scCouldNotCompute:
430  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
431  }
432  llvm_unreachable("Unknown SCEV kind!");
433 }
434 
435 bool SCEV::isZero() const {
436  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
437  return SC->getValue()->isZero();
438  return false;
439 }
440 
441 bool SCEV::isOne() const {
442  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
443  return SC->getValue()->isOne();
444  return false;
445 }
446 
447 bool SCEV::isAllOnesValue() const {
448  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
449  return SC->getValue()->isMinusOne();
450  return false;
451 }
452 
453 bool SCEV::isNonConstantNegative() const {
454  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
455  if (!Mul) return false;
456 
457  // If there is a constant factor, it will be first.
458  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
459  if (!SC) return false;
460 
461  // Return true if the value is negative, this matches things like (-42 * V).
462  return SC->getAPInt().isNegative();
463 }
464 
465 SCEVCouldNotCompute::SCEVCouldNotCompute() :
466  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
467 
468 bool SCEVCouldNotCompute::classof(const SCEV *S) {
469  return S->getSCEVType() == scCouldNotCompute;
470 }
471 
472 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
473  FoldingSetNodeID ID;
474  ID.AddInteger(scConstant);
475  ID.AddPointer(V);
476  void *IP = nullptr;
477  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
478  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
479  UniqueSCEVs.InsertNode(S, IP);
480  return S;
481 }
482 
483 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
484  return getConstant(ConstantInt::get(getContext(), Val));
485 }
486 
487 const SCEV *
488 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
489  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
490  return getConstant(ConstantInt::get(ITy, V, isSigned));
491 }
492 
493 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
494  const SCEV *op, Type *ty)
495  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
496  Operands[0] = op;
497 }
498 
499 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
500  Type *ITy)
501  : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
502  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
503  "Must be a non-bit-width-changing pointer-to-integer cast!");
504 }
505 
506 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
507  SCEVTypes SCEVTy, const SCEV *op,
508  Type *ty)
509  : SCEVCastExpr(ID, SCEVTy, op, ty) {}
510 
511 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
512  Type *ty)
513  : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
514  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
515  "Cannot truncate non-integer value!");
516 }
517 
518 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
519  const SCEV *op, Type *ty)
520  : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
521  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
522  "Cannot zero extend non-integer value!");
523 }
524 
525 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
526  const SCEV *op, Type *ty)
527  : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
528  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
529  "Cannot sign extend non-integer value!");
530 }
531 
532 void SCEVUnknown::deleted() {
533  // Clear this SCEVUnknown from various maps.
534  SE->forgetMemoizedResults(this);
535 
536  // Remove this SCEVUnknown from the uniquing map.
537  SE->UniqueSCEVs.RemoveNode(this);
538 
539  // Release the value.
540  setValPtr(nullptr);
541 }
542 
543 void SCEVUnknown::allUsesReplacedWith(Value *New) {
544  // Clear this SCEVUnknown from various maps.
545  SE->forgetMemoizedResults(this);
546 
547  // Remove this SCEVUnknown from the uniquing map.
548  SE->UniqueSCEVs.RemoveNode(this);
549 
550  // Replace the value pointer in case someone is still using this SCEVUnknown.
551  setValPtr(New);
552 }
553 
554 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
555  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
556  if (VCE->getOpcode() == Instruction::PtrToInt)
557  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
558  if (CE->getOpcode() == Instruction::GetElementPtr &&
559  CE->getOperand(0)->isNullValue() &&
560  CE->getNumOperands() == 2)
561  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
562  if (CI->isOne()) {
563  AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
564  return true;
565  }
566 
567  return false;
568 }
569 
570 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
571  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
572  if (VCE->getOpcode() == Instruction::PtrToInt)
573  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
574  if (CE->getOpcode() == Instruction::GetElementPtr &&
575  CE->getOperand(0)->isNullValue()) {
576  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
577  if (StructType *STy = dyn_cast<StructType>(Ty))
578  if (!STy->isPacked() &&
579  CE->getNumOperands() == 3 &&
580  CE->getOperand(1)->isNullValue()) {
581  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
582  if (CI->isOne() &&
583  STy->getNumElements() == 2 &&
584  STy->getElementType(0)->isIntegerTy(1)) {
585  AllocTy = STy->getElementType(1);
586  return true;
587  }
588  }
589  }
590 
591  return false;
592 }
593 
594 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
595  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
596  if (VCE->getOpcode() == Instruction::PtrToInt)
597  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
598  if (CE->getOpcode() == Instruction::GetElementPtr &&
599  CE->getNumOperands() == 3 &&
600  CE->getOperand(0)->isNullValue() &&
601  CE->getOperand(1)->isNullValue()) {
602  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
603  // Ignore vector types here so that ScalarEvolutionExpander doesn't
604  // emit getelementptrs that index into vectors.
605  if (Ty->isStructTy() || Ty->isArrayTy()) {
606  CTy = Ty;
607  FieldNo = CE->getOperand(2);
608  return true;
609  }
610  }
611 
612  return false;
613 }
614 
615 //===----------------------------------------------------------------------===//
616 // SCEV Utilities
617 //===----------------------------------------------------------------------===//
618 
619 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
620 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
621 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
622 /// have been previously deemed to be "equally complex" by this routine. It is
623 /// intended to avoid exponential time complexity in cases like:
624 ///
625 /// %a = f(%x, %y)
626 /// %b = f(%a, %a)
627 /// %c = f(%b, %b)
628 ///
629 /// %d = f(%x, %y)
630 /// %e = f(%d, %d)
631 /// %f = f(%e, %e)
632 ///
633 /// CompareValueComplexity(%f, %c)
634 ///
635 /// Since we do not continue running this routine on expression trees once we
636 /// have seen unequal values, there is no need to track them in the cache.
637 static int
638 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
639  const LoopInfo *const LI, Value *LV, Value *RV,
640  unsigned Depth) {
641  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
642  return 0;
643 
644  // Order pointer values after integer values. This helps SCEVExpander form
645  // GEPs.
646  bool LIsPointer = LV->getType()->isPointerTy(),
647  RIsPointer = RV->getType()->isPointerTy();
648  if (LIsPointer != RIsPointer)
649  return (int)LIsPointer - (int)RIsPointer;
650 
651  // Compare getValueID values.
652  unsigned LID = LV->getValueID(), RID = RV->getValueID();
653  if (LID != RID)
654  return (int)LID - (int)RID;
655 
656  // Sort arguments by their position.
657  if (const auto *LA = dyn_cast<Argument>(LV)) {
658  const auto *RA = cast<Argument>(RV);
659  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
660  return (int)LArgNo - (int)RArgNo;
661  }
662 
663  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
664  const auto *RGV = cast<GlobalValue>(RV);
665 
666  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
667  auto LT = GV->getLinkage();
668  return !(GlobalValue::isPrivateLinkage(LT) ||
669  GlobalValue::isInternalLinkage(LT));
670  };
671 
672  // Use the names to distinguish the two values, but only if the
673  // names are semantically important.
674  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
675  return LGV->getName().compare(RGV->getName());
676  }
677 
678  // For instructions, compare their loop depth, and their operand count. This
679  // is pretty loose.
680  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
681  const auto *RInst = cast<Instruction>(RV);
682 
683  // Compare loop depths.
684  const BasicBlock *LParent = LInst->getParent(),
685  *RParent = RInst->getParent();
686  if (LParent != RParent) {
687  unsigned LDepth = LI->getLoopDepth(LParent),
688  RDepth = LI->getLoopDepth(RParent);
689  if (LDepth != RDepth)
690  return (int)LDepth - (int)RDepth;
691  }
692 
693  // Compare the number of operands.
694  unsigned LNumOps = LInst->getNumOperands(),
695  RNumOps = RInst->getNumOperands();
696  if (LNumOps != RNumOps)
697  return (int)LNumOps - (int)RNumOps;
698 
699  for (unsigned Idx : seq(0u, LNumOps)) {
700  int Result =
701  CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
702  RInst->getOperand(Idx), Depth + 1);
703  if (Result != 0)
704  return Result;
705  }
706  }
707 
708  EqCacheValue.unionSets(LV, RV);
709  return 0;
710 }
711 
712 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
713 // than RHS, respectively. A three-way result allows recursive comparisons to be
714 // more efficient.
715 // If the max analysis depth was reached, return std::nullopt, assuming we do
716 // not know if they are equivalent for sure.
717 static std::optional<int>
718 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
719  EquivalenceClasses<const Value *> &EqCacheValue,
720  const LoopInfo *const LI, const SCEV *LHS,
721  const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
722  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
723  if (LHS == RHS)
724  return 0;
725 
726  // Primarily, sort the SCEVs by their getSCEVType().
727  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
728  if (LType != RType)
729  return (int)LType - (int)RType;
730 
731  if (EqCacheSCEV.isEquivalent(LHS, RHS))
732  return 0;
733 
734  if (Depth > MaxSCEVCompareDepth)
735  return std::nullopt;
736 
737  // Aside from the getSCEVType() ordering, the particular ordering
738  // isn't very important except that it's beneficial to be consistent,
739  // so that (a + b) and (b + a) don't end up as different expressions.
740  switch (LType) {
741  case scUnknown: {
742  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
743  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
744 
745  int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
746  RU->getValue(), Depth + 1);
747  if (X == 0)
748  EqCacheSCEV.unionSets(LHS, RHS);
749  return X;
750  }
751 
752  case scConstant: {
753  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
754  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
755 
756  // Compare constant values.
757  const APInt &LA = LC->getAPInt();
758  const APInt &RA = RC->getAPInt();
759  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
760  if (LBitWidth != RBitWidth)
761  return (int)LBitWidth - (int)RBitWidth;
762  return LA.ult(RA) ? -1 : 1;
763  }
764 
765  case scAddRecExpr: {
766  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
767  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
768 
769  // There is always a dominance between two recs that are used by one SCEV,
770  // so we can safely sort recs by loop header dominance. We require such
771  // order in getAddExpr.
772  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
773  if (LLoop != RLoop) {
774  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
775  assert(LHead != RHead && "Two loops share the same header?");
776  if (DT.dominates(LHead, RHead))
777  return 1;
778  else
779  assert(DT.dominates(RHead, LHead) &&
780  "No dominance between recurrences used by one SCEV?");
781  return -1;
782  }
783 
784  // Addrec complexity grows with operand count.
785  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
786  if (LNumOps != RNumOps)
787  return (int)LNumOps - (int)RNumOps;
788 
789  // Lexicographically compare.
790  for (unsigned i = 0; i != LNumOps; ++i) {
791  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
792  LA->getOperand(i), RA->getOperand(i), DT,
793  Depth + 1);
794  if (X != 0)
795  return X;
796  }
797  EqCacheSCEV.unionSets(LHS, RHS);
798  return 0;
799  }
800 
801  case scAddExpr:
802  case scMulExpr:
803  case scSMaxExpr:
804  case scUMaxExpr:
805  case scSMinExpr:
806  case scUMinExpr:
807  case scSequentialUMinExpr: {
808  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
809  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
810 
811  // Lexicographically compare n-ary expressions.
812  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
813  if (LNumOps != RNumOps)
814  return (int)LNumOps - (int)RNumOps;
815 
816  for (unsigned i = 0; i != LNumOps; ++i) {
817  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
818  LC->getOperand(i), RC->getOperand(i), DT,
819  Depth + 1);
820  if (X != 0)
821  return X;
822  }
823  EqCacheSCEV.unionSets(LHS, RHS);
824  return 0;
825  }
826 
827  case scUDivExpr: {
828  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
829  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
830 
831  // Lexicographically compare udiv expressions.
832  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
833  RC->getLHS(), DT, Depth + 1);
834  if (X != 0)
835  return X;
836  X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
837  RC->getRHS(), DT, Depth + 1);
838  if (X == 0)
839  EqCacheSCEV.unionSets(LHS, RHS);
840  return X;
841  }
842 
843  case scPtrToInt:
844  case scTruncate:
845  case scZeroExtend:
846  case scSignExtend: {
847  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
848  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
849 
850  // Compare cast expressions by operand.
851  auto X =
852  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
853  RC->getOperand(), DT, Depth + 1);
854  if (X == 0)
855  EqCacheSCEV.unionSets(LHS, RHS);
856  return X;
857  }
858 
859  case scCouldNotCompute:
860  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
861  }
862  llvm_unreachable("Unknown SCEV kind!");
863 }
864 
865 /// Given a list of SCEV objects, order them by their complexity, and group
866 /// objects of the same complexity together by value. When this routine is
867 /// finished, we know that any duplicates in the vector are consecutive and that
868 /// complexity is monotonically increasing.
869 ///
870 /// Note that we take special precautions to ensure that we get deterministic
871 /// results from this routine. In other words, we don't want the results of
872 /// this to depend on where the addresses of various SCEV objects happened to
873 /// land in memory.
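///
/// As an editorial illustration: after grouping, an operand list such as
/// (%x + 4 + %y + %x) is reordered so that the constant comes first and the
/// two occurrences of %x sit next to each other, which lets getAddExpr fold
/// them into (4 + 2 * %x + %y) regardless of where the objects live in memory.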
874 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
875  LoopInfo *LI, DominatorTree &DT) {
876  if (Ops.size() < 2) return; // Noop
877 
878  EquivalenceClasses<const SCEV *> EqCacheSCEV;
879  EquivalenceClasses<const Value *> EqCacheValue;
880 
881  // Whether LHS has provably less complexity than RHS.
882  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
883  auto Complexity =
884  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
885  return Complexity && *Complexity < 0;
886  };
887  if (Ops.size() == 2) {
888  // This is the common case, which also happens to be trivially simple.
889  // Special case it.
890  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
891  if (IsLessComplex(RHS, LHS))
892  std::swap(LHS, RHS);
893  return;
894  }
895 
896  // Do the rough sort by complexity.
897  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
898  return IsLessComplex(LHS, RHS);
899  });
900 
901  // Now that we are sorted by complexity, group elements of the same
902  // complexity. Note that this is, at worst, N^2, but the vector is likely to
903  // be extremely short in practice. Note that we take this approach because we
904  // do not want to depend on the addresses of the objects we are grouping.
905  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
906  const SCEV *S = Ops[i];
907  unsigned Complexity = S->getSCEVType();
908 
909  // If there are any objects of the same complexity and same value as this
910  // one, group them.
911  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
912  if (Ops[j] == S) { // Found a duplicate.
913  // Move it to immediately after i'th element.
914  std::swap(Ops[i+1], Ops[j]);
915  ++i; // no need to rescan it.
916  if (i == e-2) return; // Done!
917  }
918  }
919  }
920 }
921 
922 /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
923 /// least HugeExprThreshold nodes).
924 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
925  return any_of(Ops, [](const SCEV *S) {
926  return S->getExpressionSize() >= HugeExprThreshold;
927  });
928 }
929 
930 //===----------------------------------------------------------------------===//
931 // Simple SCEV method implementations
932 //===----------------------------------------------------------------------===//
933 
934 /// Compute BC(It, K). The result has width W. Assume K > 0.
935 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
936  ScalarEvolution &SE,
937  Type *ResultTy) {
938  // Handle the simplest case efficiently.
939  if (K == 1)
940  return SE.getTruncateOrZeroExtend(It, ResultTy);
941 
942  // We are using the following formula for BC(It, K):
943  //
944  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
945  //
946  // Suppose, W is the bitwidth of the return value. We must be prepared for
947  // overflow. Hence, we must assure that the result of our computation is
948  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
949  // safe in modular arithmetic.
950  //
951  // However, this code doesn't use exactly that formula; the formula it uses
952  // is something like the following, where T is the number of factors of 2 in
953  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
954  // exponentiation:
955  //
956  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
957  //
958  // This formula is trivially equivalent to the previous formula. However,
959  // this formula can be implemented much more efficiently. The trick is that
960  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
961  // arithmetic. To do exact division in modular arithmetic, all we have
962  // to do is multiply by the inverse. Therefore, this step can be done at
963  // width W.
964  //
965  // The next issue is how to safely do the division by 2^T. The way this
966  // is done is by doing the multiplication step at a width of at least W + T
967  // bits. This way, the bottom W+T bits of the product are accurate. Then,
968  // when we perform the division by 2^T (which is equivalent to a right shift
969  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
970  // truncated out after the division by 2^T.
971  //
972  // In comparison to just directly using the first formula, this technique
973  // is much more efficient; using the first formula requires W * K bits,
974 // but this formula requires less than W + K bits. Also, the first formula requires
975  // a division step, whereas this formula only requires multiplies and shifts.
976  //
977  // It doesn't matter whether the subtraction step is done in the calculation
978  // width or the input iteration count's width; if the subtraction overflows,
979  // the result must be zero anyway. We prefer here to do it in the width of
980  // the induction variable because it helps a lot for certain cases; CodeGen
981  // isn't smart enough to ignore the overflow, which leads to much less
982  // efficient code if the width of the subtraction is wider than the native
983  // register width.
984  //
985  // (It's possible to not widen at all by pulling out factors of 2 before
986  // the multiplication; for example, K=2 can be calculated as
987  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
988  // extra arithmetic, so it's not an obvious win, and it gets
989  // much more complicated for K > 3.)
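  //
  // A small worked example (editorial, not from the original comment): for
  // K = 3 and W = 32, K! = 6 = 2^1 * 3, so T = 1 and K!/2^T = 3. The product
  // It*(It-1)*(It-2) is computed at W+T = 33 bits, divided by 2^1, truncated
  // back to 32 bits, and finally multiplied by the multiplicative inverse of
  // 3 modulo 2^32 to perform the exact division by 3.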
990 
991  // Protection from insane SCEVs; this bound is conservative,
992  // but it probably doesn't matter.
993  if (K > 1000)
994  return SE.getCouldNotCompute();
995 
996  unsigned W = SE.getTypeSizeInBits(ResultTy);
997 
998  // Calculate K! / 2^T and T; we divide out the factors of two before
999  // multiplying for calculating K! / 2^T to avoid overflow.
1000  // Other overflow doesn't matter because we only care about the bottom
1001  // W bits of the result.
1002  APInt OddFactorial(W, 1);
1003  unsigned T = 1;
1004  for (unsigned i = 3; i <= K; ++i) {
1005  APInt Mult(W, i);
1006  unsigned TwoFactors = Mult.countTrailingZeros();
1007  T += TwoFactors;
1008  Mult.lshrInPlace(TwoFactors);
1009  OddFactorial *= Mult;
1010  }
1011 
1012  // We need at least W + T bits for the multiplication step
1013  unsigned CalculationBits = W + T;
1014 
1015  // Calculate 2^T, at width T+W.
1016  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
1017 
1018  // Calculate the multiplicative inverse of K! / 2^T;
1019  // this multiplication factor will perform the exact division by
1020  // K! / 2^T.
1021  APInt Mod = APInt::getSignedMinValue(W+1);
1022  APInt MultiplyFactor = OddFactorial.zext(W+1);
1023  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1024  MultiplyFactor = MultiplyFactor.trunc(W);
1025 
1026  // Calculate the product, at width T+W
1027  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1028  CalculationBits);
1029  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1030  for (unsigned i = 1; i != K; ++i) {
1031  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1032  Dividend = SE.getMulExpr(Dividend,
1033  SE.getTruncateOrZeroExtend(S, CalculationTy));
1034  }
1035 
1036  // Divide by 2^T
1037  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1038 
1039  // Truncate the result, and divide by K! / 2^T.
1040 
1041  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1042  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1043 }
1044 
1045 /// Return the value of this chain of recurrences at the specified iteration
1046 /// number. We can evaluate this recurrence by multiplying each element in the
1047 /// chain by the binomial coefficient corresponding to it. In other words, we
1048 /// can evaluate {A,+,B,+,C,+,D} as:
1049 ///
1050 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1051 ///
1052 /// where BC(It, k) stands for binomial coefficient.
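///
/// As an editorial illustration: the affine recurrence {5,+,3} evaluates to
/// 5 + 3*It, and {0,+,1,+,1} evaluates to
/// 0*BC(It,0) + 1*BC(It,1) + 1*BC(It,2) = It + It*(It-1)/2.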
1053 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1054  ScalarEvolution &SE) const {
1055  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1056 }
1057 
1058 const SCEV *
1059 SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1060  const SCEV *It, ScalarEvolution &SE) {
1061  assert(Operands.size() > 0);
1062  const SCEV *Result = Operands[0];
1063  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1064  // The computation is correct in the face of overflow provided that the
1065  // multiplication is performed _after_ the evaluation of the binomial
1066  // coefficient.
1067  const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1068  if (isa<SCEVCouldNotCompute>(Coeff))
1069  return Coeff;
1070 
1071  Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1072  }
1073  return Result;
1074 }
1075 
1076 //===----------------------------------------------------------------------===//
1077 // SCEV Expression folder implementations
1078 //===----------------------------------------------------------------------===//
1079 
1080 const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1081  unsigned Depth) {
1082  assert(Depth <= 1 &&
1083  "getLosslessPtrToIntExpr() should self-recurse at most once.");
1084 
1085 // We could be called with integer-typed operands during SCEV rewrites.
1086  // Since the operand is an integer already, just perform zext/trunc/self cast.
1087  if (!Op->getType()->isPointerTy())
1088  return Op;
1089 
1090  // What would be an ID for such a SCEV cast expression?
1091  FoldingSetNodeID ID;
1092  ID.AddInteger(scPtrToInt);
1093  ID.AddPointer(Op);
1094 
1095  void *IP = nullptr;
1096 
1097  // Is there already an expression for such a cast?
1098  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1099  return S;
1100 
1101  // It isn't legal for optimizations to construct new ptrtoint expressions
1102  // for non-integral pointers.
1103  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1104  return getCouldNotCompute();
1105 
1106  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1107 
1108  // We can only trivially model ptrtoint if SCEV's effective (integer) type
1109  // is sufficiently wide to represent all possible pointer values.
1110  // We could theoretically teach SCEV to truncate wider pointers, but
1111  // that isn't implemented for now.
1112  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1113  getDataLayout().getTypeSizeInBits(IntPtrTy))
1114  return getCouldNotCompute();
1115 
1116  // If not, is this expression something we can't reduce any further?
1117  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1118  // Perform some basic constant folding. If the operand of the ptr2int cast
1119  // is a null pointer, don't create a ptr2int SCEV expression (that will be
1120  // left as-is), but produce a zero constant.
1121  // NOTE: We could handle a more general case, but lack motivational cases.
1122  if (isa<ConstantPointerNull>(U->getValue()))
1123  return getZero(IntPtrTy);
1124 
1125  // Create an explicit cast node.
1126  // We can reuse the existing insert position since if we get here,
1127  // we won't have made any changes which would invalidate it.
1128  SCEV *S = new (SCEVAllocator)
1129  SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1130  UniqueSCEVs.InsertNode(S, IP);
1131  registerUser(S, Op);
1132  return S;
1133  }
1134 
1135  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1136  "non-SCEVUnknown's.");
1137 
1138  // Otherwise, we've got some expression that is more complex than just a
1139  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1140  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1141  // only, and the expressions must otherwise be integer-typed.
1142  // So sink the cast down to the SCEVUnknown's.
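  // As an editorial sketch: a pointer-typed expression such as
  // (%base + 4 * %i) is rewritten below into ((ptrtoint %base) + 4 * %i),
  // i.e. the cast is applied only to the SCEVUnknown leaf while the
  // surrounding add and mul become integer-typed.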
1143 
1144  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1145  /// which computes a pointer-typed value, and rewrites the whole expression
1146  /// tree so that *all* the computations are done on integers, and the only
1147  /// pointer-typed operands in the expression are SCEVUnknown.
1148  class SCEVPtrToIntSinkingRewriter
1149  : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1150  using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1151 
1152  public:
1153  SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1154 
1155  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1156  SCEVPtrToIntSinkingRewriter Rewriter(SE);
1157  return Rewriter.visit(Scev);
1158  }
1159 
1160  const SCEV *visit(const SCEV *S) {
1161  Type *STy = S->getType();
1162  // If the expression is not pointer-typed, just keep it as-is.
1163  if (!STy->isPointerTy())
1164  return S;
1165  // Else, recursively sink the cast down into it.
1166  return Base::visit(S);
1167  }
1168 
1169  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1170  SmallVector<const SCEV *, 2> Operands;
1171  bool Changed = false;
1172  for (const auto *Op : Expr->operands()) {
1173  Operands.push_back(visit(Op));
1174  Changed |= Op != Operands.back();
1175  }
1176  return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1177  }
1178 
1179  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1180  SmallVector<const SCEV *, 2> Operands;
1181  bool Changed = false;
1182  for (const auto *Op : Expr->operands()) {
1183  Operands.push_back(visit(Op));
1184  Changed |= Op != Operands.back();
1185  }
1186  return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1187  }
1188 
1189  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1190  assert(Expr->getType()->isPointerTy() &&
1191  "Should only reach pointer-typed SCEVUnknown's.");
1192  return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1193  }
1194  };
1195 
1196  // And actually perform the cast sinking.
1197  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1198  assert(IntOp->getType()->isIntegerTy() &&
1199  "We must have succeeded in sinking the cast, "
1200  "and ending up with an integer-typed expression!");
1201  return IntOp;
1202 }
1203 
1204 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1205  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1206 
1207  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1208  if (isa<SCEVCouldNotCompute>(IntOp))
1209  return IntOp;
1210 
1211  return getTruncateOrZeroExtend(IntOp, Ty);
1212 }
1213 
1214 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1215  unsigned Depth) {
1216  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1217  "This is not a truncating conversion!");
1218  assert(isSCEVable(Ty) &&
1219  "This is not a conversion to a SCEVable type!");
1220  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1221  Ty = getEffectiveSCEVType(Ty);
1222 
1223  FoldingSetNodeID ID;
1224  ID.AddInteger(scTruncate);
1225  ID.AddPointer(Op);
1226  ID.AddPointer(Ty);
1227  void *IP = nullptr;
1228  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1229 
1230  // Fold if the operand is constant.
1231  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1232  return getConstant(
1233  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1234 
1235  // trunc(trunc(x)) --> trunc(x)
1236  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1237  return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1238 
1239  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1240  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1241  return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1242 
1243  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1244  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1245  return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1246 
1247  if (Depth > MaxCastDepth) {
1248  SCEV *S =
1249  new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1250  UniqueSCEVs.InsertNode(S, IP);
1251  registerUser(S, Op);
1252  return S;
1253  }
1254 
1255  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1256  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1257  // if after transforming we have at most one truncate, not counting truncates
1258  // that replace other casts.
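  // As an editorial example: trunc((zext i32 %a to i64) + 16) to i32 becomes
  // (%a + 16), since the only truncate produced merely replaces an existing
  // cast; by contrast, trunc(%x * %y) with two unknown operands would need
  // two new truncates, so it is kept as a single truncate of the product.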
1259  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1260  auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1261  SmallVector<const SCEV *, 4> Operands;
1262  unsigned numTruncs = 0;
1263  for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1264  ++i) {
1265  const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1266  if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1267  isa<SCEVTruncateExpr>(S))
1268  numTruncs++;
1269  Operands.push_back(S);
1270  }
1271  if (numTruncs < 2) {
1272  if (isa<SCEVAddExpr>(Op))
1273  return getAddExpr(Operands);
1274  else if (isa<SCEVMulExpr>(Op))
1275  return getMulExpr(Operands);
1276  else
1277  llvm_unreachable("Unexpected SCEV type for Op.");
1278  }
1279  // Although we checked in the beginning that ID is not in the cache, it is
1280  // possible that during recursion and different modification ID was inserted
1281  // into the cache. So if we find it, just return it.
1282  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1283  return S;
1284  }
1285 
1286  // If the input value is a chrec scev, truncate the chrec's operands.
1287  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1288  SmallVector<const SCEV *, 4> Operands;
1289  for (const SCEV *Op : AddRec->operands())
1290  Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1291  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1292  }
1293 
1294  // Return zero if truncating to known zeros.
1295  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1296  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1297  return getZero(Ty);
1298 
1299  // The cast wasn't folded; create an explicit cast node. We can reuse
1300  // the existing insert position since if we get here, we won't have
1301  // made any changes which would invalidate it.
1302  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1303  Op, Ty);
1304  UniqueSCEVs.InsertNode(S, IP);
1305  registerUser(S, Op);
1306  return S;
1307 }
1308 
1309 // Get the limit of a recurrence such that incrementing by Step cannot cause
1310 // signed overflow as long as the value of the recurrence within the
1311 // loop does not exceed this limit before incrementing.
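// As an editorial example: for an i8 recurrence with a constant Step of 4,
// the limit below is SINT_MAX - 4 = 123 with predicate SLT; any value known
// to be signed-less-than 123 can be incremented by 4 without signed overflow.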
1312 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1313  ICmpInst::Predicate *Pred,
1314  ScalarEvolution *SE) {
1315  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1316  if (SE->isKnownPositive(Step)) {
1317  *Pred = ICmpInst::ICMP_SLT;
1318  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1319  SE->getSignedRangeMax(Step));
1320  }
1321  if (SE->isKnownNegative(Step)) {
1322  *Pred = ICmpInst::ICMP_SGT;
1323  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1324  SE->getSignedRangeMin(Step));
1325  }
1326  return nullptr;
1327 }
1328 
1329 // Get the limit of a recurrence such that incrementing by Step cannot cause
1330 // unsigned overflow as long as the value of the recurrence within the loop does
1331 // not exceed this limit before incrementing.
1332 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1333  ICmpInst::Predicate *Pred,
1334  ScalarEvolution *SE) {
1335  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1336  *Pred = ICmpInst::ICMP_ULT;
1337 
1338  return SE->getConstant(APInt::getMinValue(BitWidth) -
1339  SE->getUnsignedRangeMax(Step));
1340 }
1341 
1342 namespace {
1343 
1344 struct ExtendOpTraitsBase {
1345  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1346  unsigned);
1347 };
1348 
1349 // Used to make code generic over signed and unsigned overflow.
1350 template <typename ExtendOp> struct ExtendOpTraits {
1351  // Members present:
1352  //
1353  // static const SCEV::NoWrapFlags WrapType;
1354  //
1355  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1356  //
1357  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1358  // ICmpInst::Predicate *Pred,
1359  // ScalarEvolution *SE);
1360 };
1361 
1362 template <>
1363 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1364  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1365 
1366  static const GetExtendExprTy GetExtendExpr;
1367 
1368  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1369  ICmpInst::Predicate *Pred,
1370  ScalarEvolution *SE) {
1371  return getSignedOverflowLimitForStep(Step, Pred, SE);
1372  }
1373 };
1374 
1375 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1376  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1377 
1378 template <>
1379 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1380  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1381 
1382  static const GetExtendExprTy GetExtendExpr;
1383 
1384  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1385  ICmpInst::Predicate *Pred,
1386  ScalarEvolution *SE) {
1387  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1388  }
1389 };
1390 
1391 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1392  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1393 
1394 } // end anonymous namespace
1395 
1396 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1397 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1398 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1399 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1400 // Start),+,Step} => {(Step + sext/zext(Start)),+,Step}. As a result, the
1401 // expression "Step + sext/zext(PreIncAR)" is congruent with
1402 // "sext/zext(PostIncAR)"
1403 template <typename ExtendOpTy>
1404 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1405  ScalarEvolution *SE, unsigned Depth) {
1406  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1407  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1408 
1409  const Loop *L = AR->getLoop();
1410  const SCEV *Start = AR->getStart();
1411  const SCEV *Step = AR->getStepRecurrence(*SE);
1412 
1413  // Check for a simple looking step prior to loop entry.
1414  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1415  if (!SA)
1416  return nullptr;
1417 
1418  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1419  // subtraction is expensive. For this purpose, perform a quick and dirty
1420  // difference, by checking for Step in the operand list.
1421  SmallVector<const SCEV *, 4> DiffOps;
1422  for (const SCEV *Op : SA->operands())
1423  if (Op != Step)
1424  DiffOps.push_back(Op);
1425 
1426  if (DiffOps.size() == SA->getNumOperands())
1427  return nullptr;
1428 
1429  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1430  // `Step`:
1431 
1432  // 1. NSW/NUW flags on the step increment.
1433  auto PreStartFlags =
1434  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1435  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1436  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1437  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1438 
1439  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1440  // "S+X does not sign/unsign-overflow".
1441  //
1442 
1443  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1444  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1445  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1446  return PreStart;
1447 
1448  // 2. Direct overflow check on the step operation's expression.
1449  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1450  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1451  const SCEV *OperandExtendedStart =
1452  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1453  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1454  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1455  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1456  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1457  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1458  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1459  SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1460  }
1461  return PreStart;
1462  }
1463 
1464  // 3. Loop precondition.
1465  ICmpInst::Predicate Pred;
1466  const SCEV *OverflowLimit =
1467  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1468 
1469  if (OverflowLimit &&
1470  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1471  return PreStart;
1472 
1473  return nullptr;
1474 }
1475 
1476 // Get the normalized zero or sign extended expression for this AddRec's Start.
1477 template <typename ExtendOpTy>
1478 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1479  ScalarEvolution *SE,
1480  unsigned Depth) {
1481  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1482 
1483  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1484  if (!PreStart)
1485  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1486 
1487  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1488  Depth),
1489  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1490 }
1491 
1492 // Try to prove away overflow by looking at "nearby" add recurrences. A
1493 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1494 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1495 //
1496 // Formally:
1497 //
1498 // {S,+,X} == {S-T,+,X} + T
1499 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1500 //
1501 // If ({S-T,+,X} + T) does not overflow ... (1)
1502 //
1503 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1504 //
1505 // If {S-T,+,X} does not overflow ... (2)
1506 //
1507 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1508 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1509 //
1510 // If (S-T)+T does not overflow ... (3)
1511 //
1512 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1513 // == {Ext(S),+,Ext(X)} == LHS
1514 //
1515 // Thus, if (1), (2) and (3) are true for some T, then
1516 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1517 //
1518 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1519 // does not overflow" restricted to the 0th iteration. Therefore we only need
1520 // to check for (1) and (2).
1521 //
1522 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1523 // is `Delta` (defined below).
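//
// Concretely (an editorial instantiation of the rule above): take Start = 1,
// Step = 4 and T = Delta = 1. If the pre-existing addrec {0,+,4} is already
// known <nuw> (condition (2)) and {0,+,4} can be proven u< -1, so that adding
// the constant 1 cannot overflow (condition (1)), then
// zext({1,+,4}) == {zext(1),+,zext(4)}.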
1524 template <typename ExtendOpTy>
1525 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1526  const SCEV *Step,
1527  const Loop *L) {
1528  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1529 
1530  // We restrict `Start` to a constant to prevent SCEV from spending too much
1531  // time here. It is correct (but more expensive) to continue with a
1532  // non-constant `Start` and do a general SCEV subtraction to compute
1533  // `PreStart` below.
1534  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1535  if (!StartC)
1536  return false;
1537 
1538  APInt StartAI = StartC->getAPInt();
1539 
1540  for (unsigned Delta : {-2, -1, 1, 2}) {
1541  const SCEV *PreStart = getConstant(StartAI - Delta);
1542 
1543  FoldingSetNodeID ID;
1544  ID.AddInteger(scAddRecExpr);
1545  ID.AddPointer(PreStart);
1546  ID.AddPointer(Step);
1547  ID.AddPointer(L);
1548  void *IP = nullptr;
1549  const auto *PreAR =
1550  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1551 
1552  // Give up if we don't already have the add recurrence we need because
1553  // actually constructing an add recurrence is relatively expensive.
1554  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1555  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1556  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1557  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1558  DeltaS, &Pred, this);
1559  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1560  return true;
1561  }
1562  }
1563 
1564  return false;
1565 }
1566 
1567 // Finds an integer D for an expression (C + x + y + ...) such that the top
1568 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1569 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1570 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1571 // the (C + x + y + ...) expression is \p WholeAddExpr.
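// As an editorial example: with C = 19 (0b10011) and (x + y + ...) known to
// have at least 4 trailing zero bits, D is chosen as 19 mod 16 = 3; the
// remainder 16 + x + y + ... is then a multiple of 16, so adding D = 3 back
// only fills known-zero low bits and the top-level addition cannot wrap.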
1572 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1573  const SCEVConstant *ConstantTerm,
1574  const SCEVAddExpr *WholeAddExpr) {
1575  const APInt &C = ConstantTerm->getAPInt();
1576  const unsigned BitWidth = C.getBitWidth();
1577  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1578  uint32_t TZ = BitWidth;
1579  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1580  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1581  if (TZ) {
1582  // Set D to be as many least significant bits of C as possible while still
1583  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1584  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1585  }
1586  return APInt(BitWidth, 0);
1587 }
1588 
1589 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1590 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1591 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1592 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
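// For instance, for {5,+,4} the step has two trailing zero bits, so D = 1 and
// the recurrence splits as 1 + {4,+,4}, every value of which is a multiple of 4.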
1593 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1594  const APInt &ConstantStart,
1595  const SCEV *Step) {
1596  const unsigned BitWidth = ConstantStart.getBitWidth();
1597  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1598  if (TZ)
1599  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1600  : ConstantStart;
1601  return APInt(BitWidth, 0);
1602 }
1603 
1604 const SCEV *
1605 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1606  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1607  "This is not an extending conversion!");
1608  assert(isSCEVable(Ty) &&
1609  "This is not a conversion to a SCEVable type!");
1610  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1611  Ty = getEffectiveSCEVType(Ty);
1612 
1613  // Fold if the operand is constant.
1614  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1615  return getConstant(
1616  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1617 
1618  // zext(zext(x)) --> zext(x)
1619  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1620  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1621 
1622  // Before doing any expensive analysis, check to see if we've already
1623  // computed a SCEV for this Op and Ty.
1624  FoldingSetNodeID ID;
1625  ID.AddInteger(scZeroExtend);
1626  ID.AddPointer(Op);
1627  ID.AddPointer(Ty);
1628  void *IP = nullptr;
1629  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1630  if (Depth > MaxCastDepth) {
1631  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1632  Op, Ty);
1633  UniqueSCEVs.InsertNode(S, IP);
1634  registerUser(S, Op);
1635  return S;
1636  }
1637 
1638  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1639  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1640  // It's possible the bits taken off by the truncate were all zero bits. If
1641  // so, we should be able to simplify this further.
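// For instance, if X is an i32 known to lie in [0, 255], then
// zext(trunc X to i8) to i64 loses no bits and simplifies to zext(X) to i64.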
1642  const SCEV *X = ST->getOperand();
1643  ConstantRange CR = getUnsignedRange(X);
1644  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1645  unsigned NewBits = getTypeSizeInBits(Ty);
1646  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1647  CR.zextOrTrunc(NewBits)))
1648  return getTruncateOrZeroExtend(X, Ty, Depth);
1649  }
1650 
1651  // If the input value is a chrec scev, and we can prove that the value
1652  // did not overflow the old, smaller, value, we can zero extend all of the
1653  // operands (often constants). This allows analysis of something like
1654  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
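// Here X is the i8 recurrence {0,+,1}; since it never exceeds 99 it cannot
// wrap, so zext of it folds to the same recurrence {0,+,1} in i32 rather than
// an opaque zext.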
1655  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1656  if (AR->isAffine()) {
1657  const SCEV *Start = AR->getStart();
1658  const SCEV *Step = AR->getStepRecurrence(*this);
1659  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1660  const Loop *L = AR->getLoop();
1661 
1662  if (!AR->hasNoUnsignedWrap()) {
1663  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1664  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1665  }
1666 
1667  // If we have special knowledge that this addrec won't overflow,
1668  // we don't need to do any further analysis.
1669  if (AR->hasNoUnsignedWrap()) {
1670  Start =
1671  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1672  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1673  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1674  }
1675 
1676  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1677  // Note that this serves two purposes: It filters out loops that are
1678  // simply not analyzable, and it covers the case where this code is
1679  // being called from within backedge-taken count analysis, such that
1680  // attempting to ask for the backedge-taken count would likely result
1681  // in infinite recursion. In the latter case, the analysis code will
1682  // cope with a conservative value, and it will take care to purge
1683  // that value once it has finished.
1684  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1685  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1686  // Manually compute the final value for AR, checking for overflow.
1687 
1688  // Check whether the backedge-taken count can be losslessly cast to
1689  // the addrec's type. The count is always unsigned.
1690  const SCEV *CastedMaxBECount =
1691  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1692  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1693  CastedMaxBECount, MaxBECount->getType(), Depth);
1694  if (MaxBECount == RecastedMaxBECount) {
1695  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1696  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1697  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1698  SCEV::FlagAnyWrap, Depth + 1);
1699  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1700  SCEV::FlagAnyWrap,
1701  Depth + 1),
1702  WideTy, Depth + 1);
1703  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1704  const SCEV *WideMaxBECount =
1705  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1706  const SCEV *OperandExtendedAdd =
1707  getAddExpr(WideStart,
1708  getMulExpr(WideMaxBECount,
1709  getZeroExtendExpr(Step, WideTy, Depth + 1),
1710  SCEV::FlagAnyWrap, Depth + 1),
1711  SCEV::FlagAnyWrap, Depth + 1);
1712  if (ZAdd == OperandExtendedAdd) {
1713  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1714  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1715  // Return the expression with the addrec on the outside.
1716  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1717  Depth + 1);
1718  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1719  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1720  }
1721  // Similar to above, only this time treat the step value as signed.
1722  // This covers loops that count down.
1723  OperandExtendedAdd =
1724  getAddExpr(WideStart,
1725  getMulExpr(WideMaxBECount,
1726  getSignExtendExpr(Step, WideTy, Depth + 1),
1727  SCEV::FlagAnyWrap, Depth + 1),
1728  SCEV::FlagAnyWrap, Depth + 1);
1729  if (ZAdd == OperandExtendedAdd) {
1730  // Cache knowledge of AR NW, which is propagated to this AddRec.
1731  // Negative step causes unsigned wrap, but it still can't self-wrap.
1732  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1733  // Return the expression with the addrec on the outside.
1734  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1735  Depth + 1);
1736  Step = getSignExtendExpr(Step, Ty, Depth + 1);
1737  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1738  }
1739  }
1740  }
1741 
1742  // Normally, in the cases we can prove no-overflow via a
1743  // backedge guarding condition, we can also compute a backedge
1744  // taken count for the loop. The exceptions are assumptions and
1745  // guards present in the loop -- SCEV is not great at exploiting
1746  // these to compute max backedge taken counts, but can still use
1747  // these to prove lack of overflow. Use this fact to avoid
1748  // doing extra work that may not pay off.
1749  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1750  !AC.assumptions().empty()) {
1751 
1752  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1753  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1754  if (AR->hasNoUnsignedWrap()) {
1755  // Same as nuw case above - duplicated here to avoid a compile time
1756  // issue. It's not clear that the order of checks does matter, but
1757  // it's one of two possible causes for a change which was
1758  // reverted. Be conservative for the moment.
1759  Start =
1760  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1761  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1762  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1763  }
1764 
1765  // For a negative step, we can extend the operands iff doing so only
1766  // traverses values in the range zext([0,UINT_MAX]).
1767  if (isKnownNegative(Step)) {
1768  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1769  getSignedRangeMin(Step));
1770  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1771  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1772  // Cache knowledge of AR NW, which is propagated to this
1773  // AddRec. Negative step causes unsigned wrap, but it
1774  // still can't self-wrap.
1775  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1776  // Return the expression with the addrec on the outside.
1777  Start = getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1778  Depth + 1);
1779  Step = getSignExtendExpr(Step, Ty, Depth + 1);
1780  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1781  }
1782  }
1783  }
1784 
1785  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1786  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1787  // where D maximizes the number of trailing zeros of (C - D + Step * n)
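// For instance, zext({5,+,4}) becomes zext(1) + zext({4,+,4}), with D = 1
// chosen so that the residual recurrence stays a multiple of 4.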
1788  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1789  const APInt &C = SC->getAPInt();
1790  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1791  if (D != 0) {
1792  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1793  const SCEV *SResidual =
1794  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1795  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1796  return getAddExpr(SZExtD, SZExtR,
1797  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1798  Depth + 1);
1799  }
1800  }
1801 
1802  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1803  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1804  Start =
1805  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1);
1806  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
1807  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
1808  }
1809  }
1810 
1811  // zext(A % B) --> zext(A) % zext(B)
1812  {
1813  const SCEV *LHS;
1814  const SCEV *RHS;
1815  if (matchURem(Op, LHS, RHS))
1816  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1817  getZeroExtendExpr(RHS, Ty, Depth + 1));
1818  }
1819 
1820  // zext(A / B) --> zext(A) / zext(B).
1821  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1822  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1823  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1824 
1825  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1826  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1827  if (SA->hasNoUnsignedWrap()) {
1828  // If the addition does not unsign overflow then we can, by definition,
1829  // commute the zero extension with the addition operation.
1830  SmallVector<const SCEV *, 4> Ops;
1831  for (const auto *Op : SA->operands())
1832  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1833  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1834  }
1835 
1836  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1837  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1838  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1839  //
1840  // Address arithmetic often contains expressions like
1841  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1842  // This transformation is useful while proving that such expressions are
1843  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
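// In that example D = 1, so zext(5 + 4*X) becomes 1 + zext(4 + 4*X), making
// its distance from zext(4 + 4*X) an explicit constant.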
1844  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1845  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1846  if (D != 0) {
1847  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1848  const SCEV *SResidual =
1849  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1850  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1851  return getAddExpr(SZExtD, SZExtR,
1852  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1853  Depth + 1);
1854  }
1855  }
1856  }
1857 
1858  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1859  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1860  if (SM->hasNoUnsignedWrap()) {
1861  // If the multiply does not unsign overflow then we can, by definition,
1862  // commute the zero extension with the multiply operation.
1863  SmallVector<const SCEV *, 4> Ops;
1864  for (const auto *Op : SM->operands())
1865  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1866  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1867  }
1868 
1869  // zext(2^K * (trunc X to iN)) to iM ->
1870  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1871  //
1872  // Proof:
1873  //
1874  // zext(2^K * (trunc X to iN)) to iM
1875  // = zext((trunc X to iN) << K) to iM
1876  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1877  // (because shl removes the top K bits)
1878  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1879  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1880  //
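// For instance, with K = 2, N = 8 and M = 32:
// zext(4 * (trunc X to i8)) to i32 == 4 * (zext(trunc X to i6) to i32).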
1881  if (SM->getNumOperands() == 2)
1882  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1883  if (MulLHS->getAPInt().isPowerOf2())
1884  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1885  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1886  MulLHS->getAPInt().logBase2();
1887  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1888  return getMulExpr(
1889  getZeroExtendExpr(MulLHS, Ty),
1890  getZeroExtendExpr(
1891  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1892  SCEV::FlagNUW, Depth + 1);
1893  }
1894  }
1895 
1896  // The cast wasn't folded; create an explicit cast node.
1897  // Recompute the insert position, as it may have been invalidated.
1898  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1899  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1900  Op, Ty);
1901  UniqueSCEVs.InsertNode(S, IP);
1902  registerUser(S, Op);
1903  return S;
1904 }
1905 
1906 const SCEV *
1907 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1908  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1909  "This is not an extending conversion!");
1910  assert(isSCEVable(Ty) &&
1911  "This is not a conversion to a SCEVable type!");
1912  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1913  Ty = getEffectiveSCEVType(Ty);
1914 
1915  // Fold if the operand is constant.
1916  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1917  return getConstant(
1918  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1919 
1920  // sext(sext(x)) --> sext(x)
1921  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1922  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1923 
1924  // sext(zext(x)) --> zext(x)
1925  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1926  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1927 
1928  // Before doing any expensive analysis, check to see if we've already
1929  // computed a SCEV for this Op and Ty.
1930  FoldingSetNodeID ID;
1931  ID.AddInteger(scSignExtend);
1932  ID.AddPointer(Op);
1933  ID.AddPointer(Ty);
1934  void *IP = nullptr;
1935  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1936  // Limit recursion depth.
1937  if (Depth > MaxCastDepth) {
1938  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1939  Op, Ty);
1940  UniqueSCEVs.InsertNode(S, IP);
1941  registerUser(S, Op);
1942  return S;
1943  }
1944 
1945  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1946  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1947  // It's possible the bits taken off by the truncate were all sign bits. If
1948  // so, we should be able to simplify this further.
1949  const SCEV *X = ST->getOperand();
1950  ConstantRange CR = getSignedRange(X);
1951  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1952  unsigned NewBits = getTypeSizeInBits(Ty);
1953  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1954  CR.sextOrTrunc(NewBits)))
1955  return getTruncateOrSignExtend(X, Ty, Depth);
1956  }
1957 
1958  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1959  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1960  if (SA->hasNoSignedWrap()) {
1961  // If the addition does not sign overflow then we can, by definition,
1962  // commute the sign extension with the addition operation.
1963  SmallVector<const SCEV *, 4> Ops;
1964  for (const auto *Op : SA->operands())
1965  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1966  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1967  }
1968 
1969  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1970  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1971  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1972  //
1973  // For instance, this will bring two seemingly different expressions:
1974  // 1 + sext(5 + 20 * %x + 24 * %y) and
1975  // sext(6 + 20 * %x + 24 * %y)
1976  // to the same form:
1977  // 2 + sext(4 + 20 * %x + 24 * %y)
1978  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1979  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1980  if (D != 0) {
1981  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1982  const SCEV *SResidual =
1983  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1984  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1985  return getAddExpr(SSExtD, SSExtR,
1986  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1987  Depth + 1);
1988  }
1989  }
1990  }
1991  // If the input value is a chrec scev, and we can prove that the value
1992  // did not overflow the old, smaller, value, we can sign extend all of the
1993  // operands (often constants). This allows analysis of something like
1994  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1995  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1996  if (AR->isAffine()) {
1997  const SCEV *Start = AR->getStart();
1998  const SCEV *Step = AR->getStepRecurrence(*this);
1999  unsigned BitWidth = getTypeSizeInBits(AR->getType());
2000  const Loop *L = AR->getLoop();
2001 
2002  if (!AR->hasNoSignedWrap()) {
2003  auto NewFlags = proveNoWrapViaConstantRanges(AR);
2004  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2005  }
2006 
2007  // If we have special knowledge that this addrec won't overflow,
2008  // we don't need to do any further analysis.
2009  if (AR->hasNoSignedWrap()) {
2010  Start =
2011  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2012  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2013  return getAddRecExpr(Start, Step, L, SCEV::FlagNSW);
2014  }
2015 
2016  // Check whether the backedge-taken count is SCEVCouldNotCompute.
2017  // Note that this serves two purposes: It filters out loops that are
2018  // simply not analyzable, and it covers the case where this code is
2019  // being called from within backedge-taken count analysis, such that
2020  // attempting to ask for the backedge-taken count would likely result
2021  // in infinite recursion. In the latter case, the analysis code will
2022  // cope with a conservative value, and it will take care to purge
2023  // that value once it has finished.
2024  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2025  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2026  // Manually compute the final value for AR, checking for
2027  // overflow.
2028 
2029  // Check whether the backedge-taken count can be losslessly cast to
2030  // the addrec's type. The count is always unsigned.
2031  const SCEV *CastedMaxBECount =
2032  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2033  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2034  CastedMaxBECount, MaxBECount->getType(), Depth);
2035  if (MaxBECount == RecastedMaxBECount) {
2036  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2037  // Check whether Start+Step*MaxBECount has no signed overflow.
2038  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2039  SCEV::FlagAnyWrap, Depth + 1);
2040  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2041  SCEV::FlagAnyWrap,
2042  Depth + 1),
2043  WideTy, Depth + 1);
2044  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2045  const SCEV *WideMaxBECount =
2046  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2047  const SCEV *OperandExtendedAdd =
2048  getAddExpr(WideStart,
2049  getMulExpr(WideMaxBECount,
2050  getSignExtendExpr(Step, WideTy, Depth + 1),
2051  SCEV::FlagAnyWrap, Depth + 1),
2052  SCEV::FlagAnyWrap, Depth + 1);
2053  if (SAdd == OperandExtendedAdd) {
2054  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2055  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2056  // Return the expression with the addrec on the outside.
2057  Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2058  Depth + 1);
2059  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2060  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2061  }
2062  // Similar to above, only this time treat the step value as unsigned.
2063  // This covers loops that count up with an unsigned step.
2064  OperandExtendedAdd =
2065  getAddExpr(WideStart,
2066  getMulExpr(WideMaxBECount,
2067  getZeroExtendExpr(Step, WideTy, Depth + 1),
2068  SCEV::FlagAnyWrap, Depth + 1),
2069  SCEV::FlagAnyWrap, Depth + 1);
2070  if (SAdd == OperandExtendedAdd) {
2071  // If AR wraps around then
2072  //
2073  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2074  // => SAdd != OperandExtendedAdd
2075  //
2076  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2077  // (SAdd == OperandExtendedAdd => AR is NW)
2078 
2079  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2080 
2081  // Return the expression with the addrec on the outside.
2082  Start = getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2083  Depth + 1);
2084  Step = getZeroExtendExpr(Step, Ty, Depth + 1);
2085  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2086  }
2087  }
2088  }
2089 
2090  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2091  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2092  if (AR->hasNoSignedWrap()) {
2093  // Same as nsw case above - duplicated here to avoid a compile time
2094  // issue. It's not clear that the order of checks does matter, but
2095  // it's one of two possible causes for a change which was
2096  // reverted. Be conservative for the moment.
2097  Start =
2098  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2099  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2100  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2101  }
2102 
2103  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2104  // if D + (C - D + Step * n) could be proven to not signed wrap
2105  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2106  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2107  const APInt &C = SC->getAPInt();
2108  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2109  if (D != 0) {
2110  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2111  const SCEV *SResidual =
2112  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2113  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2114  return getAddExpr(SSExtD, SSExtR,
2115  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2116  Depth + 1);
2117  }
2118  }
2119 
2120  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2121  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2122  Start =
2123  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1);
2124  Step = getSignExtendExpr(Step, Ty, Depth + 1);
2125  return getAddRecExpr(Start, Step, L, AR->getNoWrapFlags());
2126  }
2127  }
2128 
2129  // If the input value is provably non-negative and we could not simplify
2130  // away the sext, build a zext instead.
2131  if (isKnownNonNegative(Op))
2132  return getZeroExtendExpr(Op, Ty, Depth + 1);
2133 
2134  // The cast wasn't folded; create an explicit cast node.
2135  // Recompute the insert position, as it may have been invalidated.
2136  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2137  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2138  Op, Ty);
2139  UniqueSCEVs.InsertNode(S, IP);
2140  registerUser(S, { Op });
2141  return S;
2142 }
2143 
2144 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2145  Type *Ty) {
2146  switch (Kind) {
2147  case scTruncate:
2148  return getTruncateExpr(Op, Ty);
2149  case scZeroExtend:
2150  return getZeroExtendExpr(Op, Ty);
2151  case scSignExtend:
2152  return getSignExtendExpr(Op, Ty);
2153  case scPtrToInt:
2154  return getPtrToIntExpr(Op, Ty);
2155  default:
2156  llvm_unreachable("Not a SCEV cast expression!");
2157  }
2158 }
2159 
2160 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2161 /// unspecified bits out to the given type.
2162 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2163  Type *Ty) {
2164  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2165  "This is not an extending conversion!");
2166  assert(isSCEVable(Ty) &&
2167  "This is not a conversion to a SCEVable type!");
2168  Ty = getEffectiveSCEVType(Ty);
2169 
2170  // Sign-extend negative constants.
2171  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2172  if (SC->getAPInt().isNegative())
2173  return getSignExtendExpr(Op, Ty);
2174 
2175  // Peel off a truncate cast.
2176  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2177  const SCEV *NewOp = T->getOperand();
2178  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2179  return getAnyExtendExpr(NewOp, Ty);
2180  return getTruncateOrNoop(NewOp, Ty);
2181  }
2182 
2183  // Next try a zext cast. If the cast is folded, use it.
2184  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2185  if (!isa<SCEVZeroExtendExpr>(ZExt))
2186  return ZExt;
2187 
2188  // Next try a sext cast. If the cast is folded, use it.
2189  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2190  if (!isa<SCEVSignExtendExpr>(SExt))
2191  return SExt;
2192 
2193  // Force the cast to be folded into the operands of an addrec.
2194  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2195  SmallVector<const SCEV *, 4> Ops;
2196  for (const SCEV *Op : AR->operands())
2197  Ops.push_back(getAnyExtendExpr(Op, Ty));
2198  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2199  }
2200 
2201  // If the expression is obviously signed, use the sext cast value.
2202  if (isa<SCEVSMaxExpr>(Op))
2203  return SExt;
2204 
2205  // Absent any other information, use the zext cast value.
2206  return ZExt;
2207 }
2208 
2209 /// Process the given Ops list, which is a list of operands to be added under
2210 /// the given scale, and update the given map. This is a helper function for
2211 /// getAddExpr. As an example of what it does, given a sequence of operands
2212 /// that would form an add expression like this:
2213 ///
2214 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2215 ///
2216 /// where A and B are constants, update the map with these values:
2217 ///
2218 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2219 ///
2220 /// and add 13 + A*B*29 to AccumulatedConstant.
2221 /// This will allow getAddExpr to produce this:
2222 ///
2223 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2224 ///
2225 /// This form often exposes folding opportunities that are hidden in
2226 /// the original operand list.
2227 ///
2228 /// Return true iff it appears that any interesting folding opportunities
2229 /// may be exposed. This helps getAddExpr short-circuit extra work in
2230 /// the common case where no interesting opportunities are present, and
2231 /// is also used as a check to avoid infinite recursion.
2232 static bool
2233 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2234  SmallVectorImpl<const SCEV *> &NewOps,
2235  APInt &AccumulatedConstant,
2236  const SCEV *const *Ops, size_t NumOperands,
2237  const APInt &Scale,
2238  ScalarEvolution &SE) {
2239  bool Interesting = false;
2240 
2241  // Iterate over the add operands. They are sorted, with constants first.
2242  unsigned i = 0;
2243  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2244  ++i;
2245  // Pull a buried constant out to the outside.
2246  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2247  Interesting = true;
2248  AccumulatedConstant += Scale * C->getAPInt();
2249  }
2250 
2251  // Next comes everything else. We're especially interested in multiplies
2252  // here, but they're in the middle, so just visit the rest with one loop.
2253  for (; i != NumOperands; ++i) {
2254  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2255  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2256  APInt NewScale =
2257  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2258  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2259  // A multiplication of a constant with another add; recurse.
2260  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2261  Interesting |=
2262  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2263  Add->op_begin(), Add->getNumOperands(),
2264  NewScale, SE);
2265  } else {
2266  // A multiplication of a constant with some other value. Update
2267  // the map.
2268  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2269  const SCEV *Key = SE.getMulExpr(MulOps);
2270  auto Pair = M.insert({Key, NewScale});
2271  if (Pair.second) {
2272  NewOps.push_back(Pair.first->first);
2273  } else {
2274  Pair.first->second += NewScale;
2275  // The map already had an entry for this value, which may indicate
2276  // a folding opportunity.
2277  Interesting = true;
2278  }
2279  }
2280  } else {
2281  // An ordinary operand. Update the map.
2282  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2283  M.insert({Ops[i], Scale});
2284  if (Pair.second) {
2285  NewOps.push_back(Pair.first->first);
2286  } else {
2287  Pair.first->second += Scale;
2288  // The map already had an entry for this value, which may indicate
2289  // a folding opportunity.
2290  Interesting = true;
2291  }
2292  }
2293  }
2294 
2295  return Interesting;
2296 }
2297 
2298 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2299  const SCEV *LHS, const SCEV *RHS,
2300  const Instruction *CtxI) {
2301  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2302  SCEV::NoWrapFlags, unsigned);
2303  switch (BinOp) {
2304  default:
2305  llvm_unreachable("Unsupported binary op");
2306  case Instruction::Add:
2307  Operation = &ScalarEvolution::getAddExpr;
2308  break;
2309  case Instruction::Sub:
2310  Operation = &ScalarEvolution::getMinusSCEV;
2311  break;
2312  case Instruction::Mul:
2313  Operation = &ScalarEvolution::getMulExpr;
2314  break;
2315  }
2316 
2317  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2318  Signed ? &ScalarEvolution::getSignExtendExpr
2319  : &ScalarEvolution::getZeroExtendExpr;
2320 
2321  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
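// For instance, for an i8 add with Signed == false: if zext(LHS + RHS) to i16
// equals zext(LHS) + zext(RHS) computed in i16, the narrow add cannot have
// wrapped unsigned.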
2322  auto *NarrowTy = cast<IntegerType>(LHS->getType());
2323  auto *WideTy =
2324  IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2325 
2326  const SCEV *A = (this->*Extension)(
2327  (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2328  const SCEV *LHSB = (this->*Extension)(LHS, WideTy, 0);
2329  const SCEV *RHSB = (this->*Extension)(RHS, WideTy, 0);
2330  const SCEV *B = (this->*Operation)(LHSB, RHSB, SCEV::FlagAnyWrap, 0);
2331  if (A == B)
2332  return true;
2333  // Can we use context to prove the fact we need?
2334  if (!CtxI)
2335  return false;
2336  // We can prove that add(x, constant) doesn't wrap if isKnownPredicateAt can
2337  // guarantee that x <= max_int - constant at the given context.
2338  // TODO: Support other operations.
2339  if (BinOp != Instruction::Add)
2340  return false;
2341  auto *RHSC = dyn_cast<SCEVConstant>(RHS);
2342  // TODO: Lift this limitation.
2343  if (!RHSC)
2344  return false;
2345  APInt C = RHSC->getAPInt();
2346  // TODO: Also lift this limitation.
2347  if (Signed && C.isNegative())
2348  return false;
2349  unsigned NumBits = C.getBitWidth();
2350  APInt Max =
2351  Signed ? APInt::getSignedMaxValue(NumBits) : APInt::getMaxValue(NumBits);
2352  APInt Limit = Max - C;
2353  ICmpInst::Predicate Pred = Signed ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
2354  return isKnownPredicateAt(Pred, LHS, getConstant(Limit), CtxI);
2355 }
2356 
2357 std::optional<SCEV::NoWrapFlags>
2358 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2359  const OverflowingBinaryOperator *OBO) {
2360  // It cannot be done any better.
2361  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2362  return std::nullopt;
2363 
2364  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2365 
2366  if (OBO->hasNoUnsignedWrap())
2367  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2368  if (OBO->hasNoSignedWrap())
2369  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2370 
2371  bool Deduced = false;
2372 
2373  if (OBO->getOpcode() != Instruction::Add &&
2374  OBO->getOpcode() != Instruction::Sub &&
2375  OBO->getOpcode() != Instruction::Mul)
2376  return std::nullopt;
2377 
2378  const SCEV *LHS = getSCEV(OBO->getOperand(0));
2379  const SCEV *RHS = getSCEV(OBO->getOperand(1));
2380 
2381  const Instruction *CtxI =
2382  UseContextForNoWrapFlagInference ? dyn_cast<Instruction>(OBO) : nullptr;
2383  if (!OBO->hasNoUnsignedWrap() &&
2384  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2385  /* Signed */ false, LHS, RHS, CtxI)) {
2386  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2387  Deduced = true;
2388  }
2389 
2390  if (!OBO->hasNoSignedWrap() &&
2391  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2392  /* Signed */ true, LHS, RHS, CtxI)) {
2393  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2394  Deduced = true;
2395  }
2396 
2397  if (Deduced)
2398  return Flags;
2399  return std::nullopt;
2400 }
2401 
2402 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2403 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2404 // can't-overflow flags for the operation if possible.
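// For example, an add known <nsw> whose operands are all known non-negative
// can also be marked <nuw>, and {0,+,step}<nw> with a non-negative step is
// <nuw> as well.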
2405 static SCEV::NoWrapFlags
2406 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2407  const ArrayRef<const SCEV *> Ops,
2408  SCEV::NoWrapFlags Flags) {
2409  using namespace std::placeholders;
2410 
2411  using OBO = OverflowingBinaryOperator;
2412 
2413  bool CanAnalyze =
2414  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2415  (void)CanAnalyze;
2416  assert(CanAnalyze && "don't call from other places!");
2417 
2418  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2419  SCEV::NoWrapFlags SignOrUnsignWrap =
2420  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2421 
2422  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2423  auto IsKnownNonNegative = [&](const SCEV *S) {
2424  return SE->isKnownNonNegative(S);
2425  };
2426 
2427  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2428  Flags =
2429  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2430 
2431  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2432 
2433  if (SignOrUnsignWrap != SignOrUnsignMask &&
2434  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2435  isa<SCEVConstant>(Ops[0])) {
2436 
2437  auto Opcode = [&] {
2438  switch (Type) {
2439  case scAddExpr:
2440  return Instruction::Add;
2441  case scMulExpr:
2442  return Instruction::Mul;
2443  default:
2444  llvm_unreachable("Unexpected SCEV op.");
2445  }
2446  }();
2447 
2448  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2449 
2450  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2451  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2452  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2453  Opcode, C, OBO::NoSignedWrap);
2454  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2455  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2456  }
2457 
2458  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2459  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2460  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2461  Opcode, C, OBO::NoUnsignedWrap);
2462  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2463  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2464  }
2465  }
2466 
2467  // <0,+,nonnegative><nw> is also nuw
2468  // TODO: Add corresponding nsw case
2469  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2470  !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2471  Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2472  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2473 
2474  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
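// ((X /u Y) * Y) rounds X down to a multiple of Y, so it never exceeds X and
// therefore cannot overflow unsigned.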
2475  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2476  Ops.size() == 2) {
2477  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2478  if (UDiv->getOperand(1) == Ops[1])
2479  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2480  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2481  if (UDiv->getOperand(1) == Ops[0])
2482  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2483  }
2484 
2485  return Flags;
2486 }
2487 
2488 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2489  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2490 }
2491 
2492 /// Get a canonical add expression, or something simpler if possible.
2493 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2494  SCEV::NoWrapFlags OrigFlags,
2495  unsigned Depth) {
2496  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2497  "only nuw or nsw allowed");
2498  assert(!Ops.empty() && "Cannot get empty add!");
2499  if (Ops.size() == 1) return Ops[0];
2500 #ifndef NDEBUG
2501  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2502  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2503  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2504  "SCEVAddExpr operand types don't match!");
2505  unsigned NumPtrs = count_if(
2506  Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2507  assert(NumPtrs <= 1 && "add has at most one pointer operand");
2508 #endif
2509 
2510  // Sort by complexity, this groups all similar expression types together.
2511  GroupByComplexity(Ops, &LI, DT);
2512 
2513  // If there are any constants, fold them together.
2514  unsigned Idx = 0;
2515  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2516  ++Idx;
2517  assert(Idx < Ops.size());
2518  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2519  // We found two constants, fold them together!
2520  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2521  if (Ops.size() == 2) return Ops[0];
2522  Ops.erase(Ops.begin()+1); // Erase the folded element
2523  LHSC = cast<SCEVConstant>(Ops[0]);
2524  }
2525 
2526  // If we are left with a constant zero being added, strip it off.
2527  if (LHSC->getValue()->isZero()) {
2528  Ops.erase(Ops.begin());
2529  --Idx;
2530  }
2531 
2532  if (Ops.size() == 1) return Ops[0];
2533  }
2534 
2535  // Delay expensive flag strengthening until necessary.
2536  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2537  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2538  };
2539 
2540  // Limit recursion calls depth.
2541  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2542  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2543 
2544  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2545  // Don't strengthen flags if we have no new information.
2546  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2547  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2548  Add->setNoWrapFlags(ComputeFlags(Ops));
2549  return S;
2550  }
2551 
2552  // Okay, check to see if the same value occurs in the operand list more than
2553  // once. If so, merge them together into a multiply expression. Since we
2554  // sorted the list, these values are required to be adjacent.
2555  Type *Ty = Ops[0]->getType();
2556  bool FoundMatch = false;
2557  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2558  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2559  // Scan ahead to count how many equal operands there are.
2560  unsigned Count = 2;
2561  while (i+Count != e && Ops[i+Count] == Ops[i])
2562  ++Count;
2563  // Merge the values into a multiply.
2564  const SCEV *Scale = getConstant(Ty, Count);
2565  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2566  if (Ops.size() == Count)
2567  return Mul;
2568  Ops[i] = Mul;
2569  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2570  --i; e -= Count - 1;
2571  FoundMatch = true;
2572  }
2573  if (FoundMatch)
2574  return getAddExpr(Ops, OrigFlags, Depth + 1);
2575 
2576  // Check for truncates. If all the operands are truncated from the same
2577  // type, see if factoring out the truncate would permit the result to be
2578  // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2579  // if the contents of the resulting outer trunc fold to something simple.
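// For instance, 3*trunc(x to i8) + 5*trunc(y to i8) can be evaluated as
// roughly trunc(3*x + 5*y) to i8, and that form is used when the wider sum
// folds to a constant or a single unknown.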
2580  auto FindTruncSrcType = [&]() -> Type * {
2581  // We're ultimately looking to fold an addrec of truncs and muls of only
2582  // constants and truncs, so if we find any other types of SCEV
2583  // as operands of the addrec then we bail and return nullptr here.
2584  // Otherwise, we return the type of the operand of a trunc that we find.
2585  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2586  return T->getOperand()->getType();
2587  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2588  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2589  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2590  return T->getOperand()->getType();
2591  }
2592  return nullptr;
2593  };
2594  if (auto *SrcType = FindTruncSrcType()) {
2595  SmallVector<const SCEV *, 8> LargeOps;
2596  bool Ok = true;
2597  // Check all the operands to see if they can be represented in the
2598  // source type of the truncate.
2599  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2600  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2601  if (T->getOperand()->getType() != SrcType) {
2602  Ok = false;
2603  break;
2604  }
2605  LargeOps.push_back(T->getOperand());
2606  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2607  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2608  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2609  SmallVector<const SCEV *, 8> LargeMulOps;
2610  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2611  if (const SCEVTruncateExpr *T =
2612  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2613  if (T->getOperand()->getType() != SrcType) {
2614  Ok = false;
2615  break;
2616  }
2617  LargeMulOps.push_back(T->getOperand());
2618  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2619  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2620  } else {
2621  Ok = false;
2622  break;
2623  }
2624  }
2625  if (Ok)
2626  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2627  } else {
2628  Ok = false;
2629  break;
2630  }
2631  }
2632  if (Ok) {
2633  // Evaluate the expression in the larger type.
2634  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2635  // If it folds to something simple, use it. Otherwise, don't.
2636  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2637  return getTruncateExpr(Fold, Ty);
2638  }
2639  }
2640 
2641  if (Ops.size() == 2) {
2642  // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2643  // C2 can be folded in a way that allows retaining wrapping flags of (X +
2644  // C1).
2645  const SCEV *A = Ops[0];
2646  const SCEV *B = Ops[1];
2647  auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2648  auto *C = dyn_cast<SCEVConstant>(A);
2649  if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2650  auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2651  auto C2 = C->getAPInt();
2652  SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2653 
2654  APInt ConstAdd = C1 + C2;
2655  auto AddFlags = AddExpr->getNoWrapFlags();
2656  // Adding a smaller constant is NUW if the original AddExpr was NUW.
2657  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2658  ConstAdd.ule(C1)) {
2659  PreservedFlags =
2660  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2661  }
2662 
2663  // Adding a constant with the same sign and small magnitude is NSW, if the
2664  // original AddExpr was NSW.
2665  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2666  C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2667  ConstAdd.abs().ule(C1.abs())) {
2668  PreservedFlags =
2669  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2670  }
2671 
2672  if (PreservedFlags != SCEV::FlagAnyWrap) {
2673  SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2674  NewOps[0] = getConstant(ConstAdd);
2675  return getAddExpr(NewOps, PreservedFlags);
2676  }
2677  }
2678  }
2679 
2680  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
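// This holds because X - (X %u Y) is exactly Y * (X /u Y).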
2681  if (Ops.size() == 2) {
2682  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2683  if (Mul && Mul->getNumOperands() == 2 &&
2684  Mul->getOperand(0)->isAllOnesValue()) {
2685  const SCEV *X;
2686  const SCEV *Y;
2687  if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2688  return getMulExpr(Y, getUDivExpr(X, Y));
2689  }
2690  }
2691  }
2692 
2693  // Skip past any other cast SCEVs.
2694  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2695  ++Idx;
2696 
2697  // If there are add operands they would be next.
2698  if (Idx < Ops.size()) {
2699  bool DeletedAdd = false;
2700  // If the original flags and all inlined SCEVAddExprs are NUW, use the
2701  // common NUW flag for expression after inlining. Other flags cannot be
2702  // preserved, because they may depend on the original order of operations.
2703  SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2704  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2705  if (Ops.size() > AddOpsInlineThreshold ||
2706  Add->getNumOperands() > AddOpsInlineThreshold)
2707  break;
2708  // If we have an add, expand the add operands onto the end of the operands
2709  // list.
2710  Ops.erase(Ops.begin()+Idx);
2711  Ops.append(Add->op_begin(), Add->op_end());
2712  DeletedAdd = true;
2713  CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2714  }
2715 
2716  // If we deleted at least one add, we added operands to the end of the list,
2717  // and they are not necessarily sorted. Recurse to resort and resimplify
2718  // any operands we just acquired.
2719  if (DeletedAdd)
2720  return getAddExpr(Ops, CommonFlags, Depth + 1);
2721  }
2722 
2723  // Skip over the add expression until we get to a multiply.
2724  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2725  ++Idx;
2726 
2727  // Check to see if there are any folding opportunities present with
2728  // operands multiplied by constant values.
2729  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2730  uint64_t BitWidth = getTypeSizeInBits(Ty);
2731  DenseMap<const SCEV *, APInt> M;
2732  SmallVector<const SCEV *, 8> NewOps;
2733  APInt AccumulatedConstant(BitWidth, 0);
2734  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2735  Ops.data(), Ops.size(),
2736  APInt(BitWidth, 1), *this)) {
2737  struct APIntCompare {
2738  bool operator()(const APInt &LHS, const APInt &RHS) const {
2739  return LHS.ult(RHS);
2740  }
2741  };
2742 
2743  // Some interesting folding opportunity is present, so it's worthwhile to
2744  // re-generate the operands list. Group the operands by constant scale,
2745  // to avoid multiplying by the same constant scale multiple times.
2746  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2747  for (const SCEV *NewOp : NewOps)
2748  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2749  // Re-generate the operands list.
2750  Ops.clear();
2751  if (AccumulatedConstant != 0)
2752  Ops.push_back(getConstant(AccumulatedConstant));
2753  for (auto &MulOp : MulOpLists) {
2754  if (MulOp.first == 1) {
2755  Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2756  } else if (MulOp.first != 0) {
2757  Ops.push_back(getMulExpr(
2758  getConstant(MulOp.first),
2759  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2760  SCEV::FlagAnyWrap, Depth + 1));
2761  }
2762  }
2763  if (Ops.empty())
2764  return getZero(Ty);
2765  if (Ops.size() == 1)
2766  return Ops[0];
2767  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2768  }
2769  }
2770 
2771  // If we are adding something to a multiply expression, make sure the
2772  // something is not already an operand of the multiply. If so, merge it into
2773  // the multiply.
2774  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2775  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2776  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2777  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2778  if (isa<SCEVConstant>(MulOpSCEV))
2779  continue;
2780  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2781  if (MulOpSCEV == Ops[AddOp]) {
2782  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2783  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2784  if (Mul->getNumOperands() != 2) {
2785  // If the multiply has more than two operands, we must get the
2786  // Y*Z term.
2787  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2788  Mul->op_begin()+MulOp);
2789  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2790  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2791  }
2792  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2793  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2794  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2795  SCEV::FlagAnyWrap, Depth + 1);
2796  if (Ops.size() == 2) return OuterMul;
2797  if (AddOp < Idx) {
2798  Ops.erase(Ops.begin()+AddOp);
2799  Ops.erase(Ops.begin()+Idx-1);
2800  } else {
2801  Ops.erase(Ops.begin()+Idx);
2802  Ops.erase(Ops.begin()+AddOp-1);
2803  }
2804  Ops.push_back(OuterMul);
2805  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2806  }
2807 
2808  // Check this multiply against other multiplies being added together.
2809  for (unsigned OtherMulIdx = Idx+1;
2810  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2811  ++OtherMulIdx) {
2812  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2813  // If MulOp occurs in OtherMul, we can fold the two multiplies
2814  // together.
2815  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2816  OMulOp != e; ++OMulOp)
2817  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2818  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2819  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2820  if (Mul->getNumOperands() != 2) {
2821  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2822  Mul->op_begin()+MulOp);
2823  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2824  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2825  }
2826  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2827  if (OtherMul->getNumOperands() != 2) {
2828  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2829  OtherMul->op_begin()+OMulOp);
2830  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2831  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2832  }
2833  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2834  const SCEV *InnerMulSum =
2835  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2836  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2837  SCEV::FlagAnyWrap, Depth + 1);
2838  if (Ops.size() == 2) return OuterMul;
2839  Ops.erase(Ops.begin()+Idx);
2840  Ops.erase(Ops.begin()+OtherMulIdx-1);
2841  Ops.push_back(OuterMul);
2842  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2843  }
2844  }
2845  }
2846  }
2847 
2848  // If there are any add recurrences in the operands list, see if any other
2849  // added values are loop invariant. If so, we can fold them into the
2850  // recurrence.
2851  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2852  ++Idx;
2853 
2854  // Scan over all recurrences, trying to fold loop invariants into them.
2855  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2856  // Scan all of the other operands to this add and add them to the vector if
2857  // they are loop invariant w.r.t. the recurrence.
2858  SmallVector<const SCEV *, 8> LIOps;
2859  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2860  const Loop *AddRecLoop = AddRec->getLoop();
2861  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2862  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2863  LIOps.push_back(Ops[i]);
2864  Ops.erase(Ops.begin()+i);
2865  --i; --e;
2866  }
2867 
2868  // If we found some loop invariants, fold them into the recurrence.
2869  if (!LIOps.empty()) {
2870  // Compute nowrap flags for the addition of the loop-invariant ops and
2871  // the addrec. Temporarily push it as an operand for that purpose. These
2872  // flags are valid in the scope of the addrec only.
2873  LIOps.push_back(AddRec);
2874  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2875  LIOps.pop_back();
2876 
2877  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
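// For instance, adding a loop-invariant n to {1,+,2}<%L> yields {n+1,+,2}<%L>.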
2878  LIOps.push_back(AddRec->getStart());
2879 
2880  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2881 
2882  // It is not in general safe to propagate flags valid on an add within
2883  // the addrec scope to one outside it. We must prove that the inner
2884  // scope is guaranteed to execute if the outer one does to be able to
2885  // safely propagate. We know the program is undefined if poison is
2886  // produced on the inner scoped addrec. We also know that *for this use*
2887  // the outer scoped add can't overflow (because of the flags we just
2888  // computed for the inner scoped add) without the program being undefined.
2889  // Proving that entry to the outer scope necessitates entry to the inner
2890  // scope, thus proves the program undefined if the flags would be violated
2891  // in the outer scope.
2892  SCEV::NoWrapFlags AddFlags = Flags;
2893  if (AddFlags != SCEV::FlagAnyWrap) {
2894  auto *DefI = getDefiningScopeBound(LIOps);
2895  auto *ReachI = &*AddRecLoop->getHeader()->begin();
2896  if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2897  AddFlags = SCEV::FlagAnyWrap;
2898  }
2899  AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2900 
2901  // Build the new addrec. Propagate the NUW and NSW flags if both the
2902  // outer add and the inner addrec are guaranteed to have no overflow.
2903  // Always propagate NW.
2904  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2905  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2906 
2907  // If all of the other operands were loop invariant, we are done.
2908  if (Ops.size() == 1) return NewRec;
2909 
2910  // Otherwise, add the folded AddRec by the non-invariant parts.
2911  for (unsigned i = 0;; ++i)
2912  if (Ops[i] == AddRec) {
2913  Ops[i] = NewRec;
2914  break;
2915  }
2916  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2917  }
2918 
2919  // Okay, if there weren't any loop invariants to be folded, check to see if
2920  // there are multiple AddRec's with the same loop induction variable being
2921  // added together. If so, we can fold them.
2922  for (unsigned OtherIdx = Idx+1;
2923  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2924  ++OtherIdx) {
2925  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2926  // so that the 1st found AddRecExpr is dominated by all others.
2927  assert(DT.dominates(
2928  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2929  AddRec->getLoop()->getHeader()) &&
2930  "AddRecExprs are not sorted in reverse dominance order?");
2931  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2932  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2933  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2934  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2935  ++OtherIdx) {
2936  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2937  if (OtherAddRec->getLoop() == AddRecLoop) {
2938  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2939  i != e; ++i) {
2940  if (i >= AddRecOps.size()) {
2941  AddRecOps.append(OtherAddRec->op_begin()+i,
2942  OtherAddRec->op_end());
2943  break;
2944  }
2945  SmallVector<const SCEV *, 2> TwoOps = {
2946  AddRecOps[i], OtherAddRec->getOperand(i)};
2947  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2948  }
2949  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2950  }
2951  }
2952  // Step size has changed, so we cannot guarantee no self-wraparound.
2953  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2954  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2955  }
2956  }
2957 
2958  // Otherwise couldn't fold anything into this recurrence. Move onto the
2959  // next one.
2960  }
2961 
2962  // Okay, it looks like we really DO need an add expr. Check to see if we
2963  // already have one, otherwise create a new one.
2964  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2965 }
2966 
2967 const SCEV *
2968 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2969  SCEV::NoWrapFlags Flags) {
2970  FoldingSetNodeID ID;
2971  ID.AddInteger(scAddExpr);
2972  for (const SCEV *Op : Ops)
2973  ID.AddPointer(Op);
2974  void *IP = nullptr;
2975  SCEVAddExpr *S =
2976  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2977  if (!S) {
2978  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2979  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2980  S = new (SCEVAllocator)
2981  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2982  UniqueSCEVs.InsertNode(S, IP);
2983  registerUser(S, Ops);
2984  }
2985  S->setNoWrapFlags(Flags);
2986  return S;
2987 }
2988 
2989 const SCEV *
2990 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2991  const Loop *L, SCEV::NoWrapFlags Flags) {
2992  FoldingSetNodeID ID;
2993  ID.AddInteger(scAddRecExpr);
2994  for (const SCEV *Op : Ops)
2995  ID.AddPointer(Op);
2996  ID.AddPointer(L);
2997  void *IP = nullptr;
2998  SCEVAddRecExpr *S =
2999  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3000  if (!S) {
3001  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3002  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3003  S = new (SCEVAllocator)
3004  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
3005  UniqueSCEVs.InsertNode(S, IP);
3006  LoopUsers[L].push_back(S);
3007  registerUser(S, Ops);
3008  }
3009  setNoWrapFlags(S, Flags);
3010  return S;
3011 }
3012 
3013 const SCEV *
3014 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
3015  SCEV::NoWrapFlags Flags) {
3016  FoldingSetNodeID ID;
3017  ID.AddInteger(scMulExpr);
3018  for (const SCEV *Op : Ops)
3019  ID.AddPointer(Op);
3020  void *IP = nullptr;
3021  SCEVMulExpr *S =
3022  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
3023  if (!S) {
3024  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3025  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3026  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
3027  O, Ops.size());
3028  UniqueSCEVs.InsertNode(S, IP);
3029  registerUser(S, Ops);
3030  }
3031  S->setNoWrapFlags(Flags);
3032  return S;
3033 }
3034 
3035 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
3036  uint64_t k = i*j;
3037  if (j > 1 && k / j != i) Overflow = true;
3038  return k;
3039 }
3040 
3041 /// Compute the result of "n choose k", the binomial coefficient. If an
3042 /// intermediate computation overflows, Overflow will be set and the return will
3043 /// be garbage. Overflow is not cleared on absence of overflow.
3044 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
3045  // We use the multiplicative formula:
3046  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
3047  // At each iteration, we take the n-th term of the numerator and divide by the
3048  // (k-n)th term of the denominator. This division will always produce an
3049  // integral result, and helps reduce the chance of overflow in the
3050  // intermediate computations. However, we can still overflow even when the
3051  // final result would fit.
3052 
3053  if (n == 0 || n == k) return 1;
3054  if (k > n) return 0;
3055 
3056  if (k > n/2)
3057  k = n-k;
3058 
3059  uint64_t r = 1;
3060  for (uint64_t i = 1; i <= k; ++i) {
3061  r = umul_ov(r, n-(i-1), Overflow);
3062  r /= i;
3063  }
3064  return r;
3065 }
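// For illustration (not in the original source): Choose(6, 2) leaves k at 2
// (since k <= n/2), then runs r = (1*6)/1 = 6 and r = (6*5)/2 = 15, matching
// C(6,2) = 15 with no overflow in the intermediate 64-bit products.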
3066 
3067 /// Determine if any of the operands in this SCEV are a constant or if
3068 /// any of the add or multiply expressions in this SCEV contain a constant.
3069 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3070  struct FindConstantInAddMulChain {
3071  bool FoundConstant = false;
3072 
3073  bool follow(const SCEV *S) {
3074  FoundConstant |= isa<SCEVConstant>(S);
3075  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3076  }
3077 
3078  bool isDone() const {
3079  return FoundConstant;
3080  }
3081  };
3082 
3083  FindConstantInAddMulChain F;
3084  SCEVTraversal<FindConstantInAddMulChain> ST(F);
3085  ST.visitAll(StartExpr);
3086  return F.FoundConstant;
3087 }
3088 
3089 /// Get a canonical multiply expression, or something simpler if possible.
3090 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3091  SCEV::NoWrapFlags OrigFlags,
3092  unsigned Depth) {
3093  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3094  "only nuw or nsw allowed");
3095  assert(!Ops.empty() && "Cannot get empty mul!");
3096  if (Ops.size() == 1) return Ops[0];
3097 #ifndef NDEBUG
3098  Type *ETy = Ops[0]->getType();
3099  assert(!ETy->isPointerTy());
3100  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3101  assert(Ops[i]->getType() == ETy &&
3102  "SCEVMulExpr operand types don't match!");
3103 #endif
3104 
3105  // Sort by complexity, this groups all similar expression types together.
3106  GroupByComplexity(Ops, &LI, DT);
3107 
3108  // If there are any constants, fold them together.
3109  unsigned Idx = 0;
3110  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3111  ++Idx;
3112  assert(Idx < Ops.size());
3113  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3114  // We found two constants, fold them together!
3115  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3116  if (Ops.size() == 2) return Ops[0];
3117  Ops.erase(Ops.begin()+1); // Erase the folded element
3118  LHSC = cast<SCEVConstant>(Ops[0]);
3119  }
3120 
3121  // If we have a multiply of zero, it will always be zero.
3122  if (LHSC->getValue()->isZero())
3123  return LHSC;
3124 
3125  // If we are left with a constant one being multiplied, strip it off.
3126  if (LHSC->getValue()->isOne()) {
3127  Ops.erase(Ops.begin());
3128  --Idx;
3129  }
3130 
3131  if (Ops.size() == 1)
3132  return Ops[0];
3133  }
3134 
3135  // Delay expensive flag strengthening until necessary.
3136  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3137  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3138  };
3139 
3140  // Limit recursion calls depth.
3141  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3142  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3143 
3144  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3145  // Don't strengthen flags if we have no new information.
3146  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3147  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3148  Mul->setNoWrapFlags(ComputeFlags(Ops));
3149  return S;
3150  }
3151 
3152  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3153  if (Ops.size() == 2) {
3154  // C1*(C2+V) -> C1*C2 + C1*V
3155  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3156  // If any of Add's ops are Adds or Muls with a constant, apply this
3157  // transformation as well.
3158  //
3159  // TODO: There are some cases where this transformation is not
3160  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3161  // this transformation should be narrowed down.
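 // For instance (illustrative only), 2*(3 + %x) is rewritten here to
 // 6 + 2*%x, since the inner add contains a constant operand.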
3162  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add)) {
3163  const SCEV *LHS = getMulExpr(LHSC, Add->getOperand(0),
3164  SCEV::FlagAnyWrap, Depth + 1);
3165  const SCEV *RHS = getMulExpr(LHSC, Add->getOperand(1),
3166  SCEV::FlagAnyWrap, Depth + 1);
3167  return getAddExpr(LHS, RHS, SCEV::FlagAnyWrap, Depth + 1);
3168  }
3169 
3170  if (Ops[0]->isAllOnesValue()) {
3171  // If we have a mul by -1 of an add, try distributing the -1 among the
3172  // add operands.
3173  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3174  SmallVector<const SCEV *, 4> NewOps;
3175  bool AnyFolded = false;
3176  for (const SCEV *AddOp : Add->operands()) {
3177  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3178  Depth + 1);
3179  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3180  NewOps.push_back(Mul);
3181  }
3182  if (AnyFolded)
3183  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3184  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3185  // Negation preserves a recurrence's no self-wrap property.
3186  SmallVector<const SCEV *, 4> Operands;
3187  for (const SCEV *AddRecOp : AddRec->operands())
3188  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3189  Depth + 1));
3190 
3191  return getAddRecExpr(Operands, AddRec->getLoop(),
3192  AddRec->getNoWrapFlags(SCEV::FlagNW));
3193  }
3194  }
3195  }
3196  }
3197 
3198  // Skip over the add expressions until we get to a multiply.
3199  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3200  ++Idx;
3201 
3202  // If there are mul operands inline them all into this expression.
3203  if (Idx < Ops.size()) {
3204  bool DeletedMul = false;
3205  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3206  if (Ops.size() > MulOpsInlineThreshold)
3207  break;
3208  // If we have a mul, expand the mul operands onto the end of the
3209  // operands list.
3210  Ops.erase(Ops.begin()+Idx);
3211  Ops.append(Mul->op_begin(), Mul->op_end());
3212  DeletedMul = true;
3213  }
3214 
3215  // If we deleted at least one mul, we added operands to the end of the
3216  // list, and they are not necessarily sorted. Recurse to resort and
3217  // resimplify any operands we just acquired.
3218  if (DeletedMul)
3219  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3220  }
3221 
3222  // If there are any add recurrences in the operands list, see if any other
3223  // added values are loop invariant. If so, we can fold them into the
3224  // recurrence.
3225  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3226  ++Idx;
3227 
3228  // Scan over all recurrences, trying to fold loop invariants into them.
3229  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3230  // Scan all of the other operands to this mul and add them to the vector
3231  // if they are loop invariant w.r.t. the recurrence.
3232  SmallVector<const SCEV *, 8> LIOps;
3233  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3234  const Loop *AddRecLoop = AddRec->getLoop();
3235  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3236  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3237  LIOps.push_back(Ops[i]);
3238  Ops.erase(Ops.begin()+i);
3239  --i; --e;
3240  }
3241 
3242  // If we found some loop invariants, fold them into the recurrence.
3243  if (!LIOps.empty()) {
3244  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3245  SmallVector<const SCEV *, 4> NewOps;
3246  NewOps.reserve(AddRec->getNumOperands());
3247  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3248  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3249  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3250  SCEV::FlagAnyWrap, Depth + 1));
3251 
3252  // Build the new addrec. Propagate the NUW and NSW flags if both the
3253  // outer mul and the inner addrec are guaranteed to have no overflow.
3254  //
3255  // The no-self-wrap flag cannot be guaranteed after changing the step size,
3256  // but it will be inferred if either NUW or NSW is true.
3257  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3258  const SCEV *NewRec = getAddRecExpr(
3259  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3260 
3261  // If all of the other operands were loop invariant, we are done.
3262  if (Ops.size() == 1) return NewRec;
3263 
3264  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3265  for (unsigned i = 0;; ++i)
3266  if (Ops[i] == AddRec) {
3267  Ops[i] = NewRec;
3268  break;
3269  }
3270  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3271  }
3272 
3273  // Okay, if there weren't any loop invariants to be folded, check to see
3274  // if there are multiple AddRec's with the same loop induction variable
3275  // being multiplied together. If so, we can fold them.
3276 
3277  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3278  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3279  // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3280  // ]]],+,...up to x=2n}.
3281  // Note that the arguments to choose() are always integers with values
3282  // known at compile time, never SCEV objects.
3283  //
3284  // The implementation avoids pointless extra computations when the two
3285  // addrec's are of different length (mathematically, it's equivalent to
3286  // an infinite stream of zeros on the right).
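 // A concrete instance of the formula above (added for illustration):
 // {a,+,b}<L> * {c,+,d}<L> folds to {a*c,+,a*d+b*c+b*d,+,2*b*d}<L>, because
 // (a+b*n)*(c+d*n) = a*c + (a*d+b*c)*n + b*d*n^2, and n^2 as a chain of
 // recurrences is {0,+,1,+,2}.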
3287  bool OpsModified = false;
3288  for (unsigned OtherIdx = Idx+1;
3289  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3290  ++OtherIdx) {
3291  const SCEVAddRecExpr *OtherAddRec =
3292  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3293  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3294  continue;
3295 
3296  // Limit max number of arguments to avoid creation of unreasonably big
3297  // SCEVAddRecs with very complex operands.
3298  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3299  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3300  continue;
3301 
3302  bool Overflow = false;
3303  Type *Ty = AddRec->getType();
3304  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3305  SmallVector<const SCEV*, 7> AddRecOps;
3306  for (int x = 0, xe = AddRec->getNumOperands() +
3307  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3308  SmallVector<const SCEV *, 7> SumOps;
3309  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3310  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3311  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3312  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3313  z < ze && !Overflow; ++z) {
3314  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3315  uint64_t Coeff;
3316  if (LargerThan64Bits)
3317  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3318  else
3319  Coeff = Coeff1*Coeff2;
3320  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3321  const SCEV *Term1 = AddRec->getOperand(y-z);
3322  const SCEV *Term2 = OtherAddRec->getOperand(z);
3323  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3324  SCEV::FlagAnyWrap, Depth + 1));
3325  }
3326  }
3327  if (SumOps.empty())
3328  SumOps.push_back(getZero(Ty));
3329  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3330  }
3331  if (!Overflow) {
3332  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3333  SCEV::FlagAnyWrap);
3334  if (Ops.size() == 2) return NewAddRec;
3335  Ops[Idx] = NewAddRec;
3336  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3337  OpsModified = true;
3338  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3339  if (!AddRec)
3340  break;
3341  }
3342  }
3343  if (OpsModified)
3344  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3345 
3346  // Otherwise couldn't fold anything into this recurrence. Move onto the
3347  // next one.
3348  }
3349 
3350  // Okay, it looks like we really DO need a mul expr. Check to see if we
3351  // already have one, otherwise create a new one.
3352  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3353 }
3354 
3355 /// Represents an unsigned remainder expression based on unsigned division.
3356 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3357  const SCEV *RHS) {
3358  assert(getEffectiveSCEVType(LHS->getType()) ==
3359  getEffectiveSCEVType(RHS->getType()) &&
3360  "SCEVURemExpr operand types don't match!");
3361 
3362  // Short-circuit easy cases
3363  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3364  // If constant is one, the result is trivial
3365  if (RHSC->getValue()->isOne())
3366  return getZero(LHS->getType()); // X urem 1 --> 0
3367 
3368  // If constant is a power of two, fold into a zext(trunc(LHS)).
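 // For example, %x urem 8 becomes a zext back to the original width of
 // trunc(%x) to i3: only the low three bits of %x survive.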
3369  if (RHSC->getAPInt().isPowerOf2()) {
3370  Type *FullTy = LHS->getType();
3371  Type *TruncTy =
3372  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3373  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3374  }
3375  }
3376 
3377  // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3378  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3379  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3380  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3381 }
3382 
3383 /// Get a canonical unsigned division expression, or something simpler if
3384 /// possible.
3385 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3386  const SCEV *RHS) {
3387  assert(!LHS->getType()->isPointerTy() &&
3388  "SCEVUDivExpr operand can't be pointer!");
3389  assert(LHS->getType() == RHS->getType() &&
3390  "SCEVUDivExpr operand types don't match!");
3391 
3392  FoldingSetNodeID ID;
3393  ID.AddInteger(scUDivExpr);
3394  ID.AddPointer(LHS);
3395  ID.AddPointer(RHS);
3396  void *IP = nullptr;
3397  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3398  return S;
3399 
3400  // 0 udiv Y == 0
3401  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3402  if (LHSC->getValue()->isZero())
3403  return LHS;
3404 
3405  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3406  if (RHSC->getValue()->isOne())
3407  return LHS; // X udiv 1 --> x
3408  // If the denominator is zero, the result of the udiv is undefined. Don't
3409  // try to analyze it, because the resolution chosen here may differ from
3410  // the resolution chosen in other parts of the compiler.
3411  if (!RHSC->getValue()->isZero()) {
3412  // Determine if the division can be folded into the operands of
3413  // the LHS expression.
3414  // TODO: Generalize this to non-constants by using known-bits information.
3415  Type *Ty = LHS->getType();
3416  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3417  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3418  // For non-power-of-two values, effectively round the value up to the
3419  // nearest power of two.
3420  if (!RHSC->getAPInt().isPowerOf2())
3421  ++MaxShiftAmt;
3422  IntegerType *ExtTy =
3423  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3424  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3425  if (const SCEVConstant *Step =
3426  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3427  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3428  const APInt &StepInt = Step->getAPInt();
3429  const APInt &DivInt = RHSC->getAPInt();
3430  if (!StepInt.urem(DivInt) &&
3431  getZeroExtendExpr(AR, ExtTy) ==
3432  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3433  getZeroExtendExpr(Step, ExtTy),
3434  AR->getLoop(), SCEV::FlagAnyWrap)) {
3435  SmallVector<const SCEV *, 4> Operands;
3436  for (const SCEV *Op : AR->operands())
3437  Operands.push_back(getUDivExpr(Op, RHS));
3438  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3439  }
3440  /// Get a canonical UDivExpr for a recurrence.
3441  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3442  // We can currently only fold X%N if X is constant.
3443  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3444  if (StartC && !DivInt.urem(StepInt) &&
3445  getZeroExtendExpr(AR, ExtTy) ==
3446  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3447  getZeroExtendExpr(Step, ExtTy),
3448  AR->getLoop(), SCEV::FlagAnyWrap)) {
3449  const APInt &StartInt = StartC->getAPInt();
3450  const APInt &StartRem = StartInt.urem(StepInt);
3451  if (StartRem != 0) {
3452  const SCEV *NewLHS =
3453  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3454  AR->getLoop(), SCEV::FlagNW);
3455  if (LHS != NewLHS) {
3456  LHS = NewLHS;
3457 
3458  // Reset the ID to include the new LHS, and check if it is
3459  // already cached.
3460  ID.clear();
3461  ID.AddInteger(scUDivExpr);
3462  ID.AddPointer(LHS);
3463  ID.AddPointer(RHS);
3464  IP = nullptr;
3465  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3466  return S;
3467  }
3468  }
3469  }
3470  }
3471  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3472  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3473  SmallVector<const SCEV *, 4> Operands;
3474  for (const SCEV *Op : M->operands())
3475  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3476  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3477  // Find an operand that's safely divisible.
3478  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3479  const SCEV *Op = M->getOperand(i);
3480  const SCEV *Div = getUDivExpr(Op, RHSC);
3481  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3482  Operands = SmallVector<const SCEV *, 4>(M->operands());
3483  Operands[i] = Div;
3484  return getMulExpr(Operands);
3485  }
3486  }
3487  }
3488 
3489  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3490  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3491  if (auto *DivisorConstant =
3492  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3493  bool Overflow = false;
3494  APInt NewRHS =
3495  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3496  if (Overflow) {
3497  return getConstant(RHSC->getType(), 0, false);
3498  }
3499  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3500  }
3501  }
3502 
3503  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3504  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3505  SmallVector<const SCEV *, 4> Operands;
3506  for (const SCEV *Op : A->operands())
3507  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3508  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3509  Operands.clear();
3510  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3511  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3512  if (isa<SCEVUDivExpr>(Op) ||
3513  getMulExpr(Op, RHS) != A->getOperand(i))
3514  break;
3515  Operands.push_back(Op);
3516  }
3517  if (Operands.size() == A->getNumOperands())
3518  return getAddExpr(Operands);
3519  }
3520  }
3521 
3522  // Fold if both operands are constant.
3523  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3524  return getConstant(LHSC->getAPInt().udiv(RHSC->getAPInt()));
3525  }
3526  }
3527 
3528  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3529  // changes). Make sure we get a new one.
3530  IP = nullptr;
3531  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3532  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3533  LHS, RHS);
3534  UniqueSCEVs.InsertNode(S, IP);
3535  registerUser(S, {LHS, RHS});
3536  return S;
3537 }
3538 
3539 APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3540  APInt A = C1->getAPInt().abs();
3541  APInt B = C2->getAPInt().abs();
3542  uint32_t ABW = A.getBitWidth();
3543  uint32_t BBW = B.getBitWidth();
3544 
3545  if (ABW > BBW)
3546  B = B.zext(ABW);
3547  else if (ABW < BBW)
3548  A = A.zext(BBW);
3549 
3550  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
3551 }
3552 
3553 /// Get a canonical unsigned division expression, or something simpler if
3554 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3555 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3556 /// it's not exact because the udiv may be clearing bits.
3557 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3558  const SCEV *RHS) {
3559  // TODO: we could try to find factors in all sorts of things, but for now we
3560  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3561  // end of this file for inspiration.
3562 
3563  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3564  if (!Mul || !Mul->hasNoUnsignedWrap())
3565  return getUDivExpr(LHS, RHS);
3566 
3567  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3568  // If the mulexpr multiplies by a constant, then that constant must be the
3569  // first element of the mulexpr.
3570  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3571  if (LHSCst == RHSCst) {
3572  SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3573  return getMulExpr(Operands);
3574  }
3575 
3576  // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3577  // that there's a factor provided by one of the other terms. We need to
3578  // check.
3579  APInt Factor = gcd(LHSCst, RHSCst);
3580  if (!Factor.isIntN(1)) {
3581  LHSCst =
3582  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3583  RHSCst =
3584  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3585  SmallVector<const SCEV *, 2> Operands;
3586  Operands.push_back(LHSCst);
3587  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3588  LHS = getMulExpr(Operands);
3589  RHS = RHSCst;
3590  Mul = dyn_cast<SCEVMulExpr>(LHS);
3591  if (!Mul)
3592  return getUDivExactExpr(LHS, RHS);
3593  }
3594  }
3595  }
3596 
3597  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3598  if (Mul->getOperand(i) == RHS) {
3599  SmallVector<const SCEV *, 2> Operands;
3600  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3601  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3602  return getMulExpr(Operands);
3603  }
3604  }
3605 
3606  return getUDivExpr(LHS, RHS);
3607 }
3608 
3609 /// Get an add recurrence expression for the specified loop. Simplify the
3610 /// expression as much as possible.
3611 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3612  const Loop *L,
3613  SCEV::NoWrapFlags Flags) {
3614  SmallVector<const SCEV *, 4> Operands;
3615  Operands.push_back(Start);
3616  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3617  if (StepChrec->getLoop() == L) {
3618  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3619  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3620  }
3621 
3622  Operands.push_back(Step);
3623  return getAddRecExpr(Operands, L, Flags);
3624 }
3625 
3626 /// Get an add recurrence expression for the specified loop. Simplify the
3627 /// expression as much as possible.
3628 const SCEV *
3629 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3630  const Loop *L, SCEV::NoWrapFlags Flags) {
3631  if (Operands.size() == 1) return Operands[0];
3632 #ifndef NDEBUG
3633  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3634  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3636  "SCEVAddRecExpr operand types don't match!");
3637  assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3638  }
3639  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3641  "SCEVAddRecExpr operand is not loop-invariant!");
3642 #endif
3643 
3644  if (Operands.back()->isZero()) {
3645  Operands.pop_back();
3646  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3647  }
3648 
3649  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3650  // use that information to infer NUW and NSW flags. However, computing a
3651  // BE count requires calling getAddRecExpr, so we may not yet have a
3652  // meaningful BE count at this point (and if we don't, we'd be stuck
3653  // with a SCEVCouldNotCompute as the cached BE count).
3654 
3655  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3656 
3657  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3658  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3659  const Loop *NestedLoop = NestedAR->getLoop();
3660  if (L->contains(NestedLoop)
3661  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3662  : (!NestedLoop->contains(L) &&
3663  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3664  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3665  Operands[0] = NestedAR->getStart();
3666  // AddRecs require their operands be loop-invariant with respect to their
3667  // loops. Don't perform this transformation if it would break this
3668  // requirement.
3669  bool AllInvariant = all_of(
3670  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3671 
3672  if (AllInvariant) {
3673  // Create a recurrence for the outer loop with the same step size.
3674  //
3675  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3676  // inner recurrence has the same property.
3677  SCEV::NoWrapFlags OuterFlags =
3678  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3679 
3680  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3681  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3682  return isLoopInvariant(Op, NestedLoop);
3683  });
3684 
3685  if (AllInvariant) {
3686  // Ok, both add recurrences are valid after the transformation.
3687  //
3688  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3689  // the outer recurrence has the same property.
3690  SCEV::NoWrapFlags InnerFlags =
3691  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3692  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3693  }
3694  }
3695  // Reset Operands to its original state.
3696  Operands[0] = NestedAR;
3697  }
3698  }
3699 
3700  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3701  // already have one, otherwise create a new one.
3702  return getOrCreateAddRecExpr(Operands, L, Flags);
3703 }
3704 
3705 const SCEV *
3706 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3707  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3708  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3709  // getSCEV(Base)->getType() has the same address space as Base->getType()
3710  // because SCEV::getType() preserves the address space.
3711  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3712  const bool AssumeInBoundsFlags = [&]() {
3713  if (!GEP->isInBounds())
3714  return false;
3715 
3716  // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3717  // but to do that, we have to ensure that said flag is valid in the entire
3718  // defined scope of the SCEV.
3719  auto *GEPI = dyn_cast<Instruction>(GEP);
3720  // TODO: non-instructions have global scope. We might be able to prove
3721  // some global scope cases
3722  return GEPI && isSCEVExprNeverPoison(GEPI);
3723  }();
3724 
3725  SCEV::NoWrapFlags OffsetWrap =
3726  AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3727 
3728  Type *CurTy = GEP->getType();
3729  bool FirstIter = true;
3730  SmallVector<const SCEV *, 4> Offsets;
3731  for (const SCEV *IndexExpr : IndexExprs) {
3732  // Compute the (potentially symbolic) offset in bytes for this index.
3733  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3734  // For a struct, add the member offset.
3735  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3736  unsigned FieldNo = Index->getZExtValue();
3737  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3738  Offsets.push_back(FieldOffset);
3739 
3740  // Update CurTy to the type of the field at Index.
3741  CurTy = STy->getTypeAtIndex(Index);
3742  } else {
3743  // Update CurTy to its element type.
3744  if (FirstIter) {
3745  assert(isa<PointerType>(CurTy) &&
3746  "The first index of a GEP indexes a pointer");
3747  CurTy = GEP->getSourceElementType();
3748  FirstIter = false;
3749  } else {
3750  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3751  }
3752  // For an array, add the element offset, explicitly scaled.
3753  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3754  // Getelementptr indices are signed.
3755  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3756 
3757  // Multiply the index by the element size to compute the element offset.
3758  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3759  Offsets.push_back(LocalOffset);
3760  }
3761  }
3762 
3763  // Handle degenerate case of GEP without offsets.
3764  if (Offsets.empty())
3765  return BaseExpr;
3766 
3767  // Add the offsets together, assuming nsw if inbounds.
3768  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3769  // Add the base address and the offset. We cannot use the nsw flag, as the
3770  // base address is unsigned. However, if we know that the offset is
3771  // non-negative, we can use nuw.
3772  SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3773  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3774  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3775  assert(BaseExpr->getType() == GEPExpr->getType() &&
3776  "GEP should not change type mid-flight.");
3777  return GEPExpr;
3778 }
3779 
3780 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3781  ArrayRef<const SCEV *> Ops) {
3782  FoldingSetNodeID ID;
3783  ID.AddInteger(SCEVType);
3784  for (const SCEV *Op : Ops)
3785  ID.AddPointer(Op);
3786  void *IP = nullptr;
3787  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3788 }
3789 
3790 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3791  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3792  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3793 }
3794 
3795 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3796  SmallVectorImpl<const SCEV *> &Ops) {
3797  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3798  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3799  if (Ops.size() == 1) return Ops[0];
3800 #ifndef NDEBUG
3801  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3802  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3803  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3804  "Operand types don't match!");
3805  assert(Ops[0]->getType()->isPointerTy() ==
3806  Ops[i]->getType()->isPointerTy() &&
3807  "min/max should be consistently pointerish");
3808  }
3809 #endif
3810 
3811  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3812  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3813 
3814  // Sort by complexity, this groups all similar expression types together.
3815  GroupByComplexity(Ops, &LI, DT);
3816 
3817  // Check if we have created the same expression before.
3818  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3819  return S;
3820  }
3821 
3822  // If there are any constants, fold them together.
3823  unsigned Idx = 0;
3824  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3825  ++Idx;
3826  assert(Idx < Ops.size());
3827  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3828  if (Kind == scSMaxExpr)
3829  return APIntOps::smax(LHS, RHS);
3830  else if (Kind == scSMinExpr)
3831  return APIntOps::smin(LHS, RHS);
3832  else if (Kind == scUMaxExpr)
3833  return APIntOps::umax(LHS, RHS);
3834  else if (Kind == scUMinExpr)
3835  return APIntOps::umin(LHS, RHS);
3836  llvm_unreachable("Unknown SCEV min/max opcode");
3837  };
3838 
3839  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3840  // We found two constants, fold them together!
3841  ConstantInt *Fold = ConstantInt::get(
3842  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3843  Ops[0] = getConstant(Fold);
3844  Ops.erase(Ops.begin()+1); // Erase the folded element
3845  if (Ops.size() == 1) return Ops[0];
3846  LHSC = cast<SCEVConstant>(Ops[0]);
3847  }
3848 
3849  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3850  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3851 
3852  if (IsMax ? IsMinV : IsMaxV) {
3853  // If we are left with a constant minimum(/maximum)-int, strip it off.
3854  Ops.erase(Ops.begin());
3855  --Idx;
3856  } else if (IsMax ? IsMaxV : IsMinV) {
3857  // If we have a max(/min) with a constant maximum(/minimum)-int,
3858  // it will always be the extremum.
3859  return LHSC;
3860  }
3861 
3862  if (Ops.size() == 1) return Ops[0];
3863  }
3864 
3865  // Find the first operation of the same kind
3866  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3867  ++Idx;
3868 
3869  // Check to see if one of the operands is of the same kind. If so, expand its
3870  // operands onto our operand list, and recurse to simplify.
3871  if (Idx < Ops.size()) {
3872  bool DeletedAny = false;
3873  while (Ops[Idx]->getSCEVType() == Kind) {
3874  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3875  Ops.erase(Ops.begin()+Idx);
3876  Ops.append(SMME->op_begin(), SMME->op_end());
3877  DeletedAny = true;
3878  }
3879 
3880  if (DeletedAny)
3881  return getMinMaxExpr(Kind, Ops);
3882  }
3883 
3884  // Okay, check to see if the same value occurs in the operand list twice. If
3885  // so, delete one. Since we sorted the list, these values are required to
3886  // be adjacent.
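 // For example (illustrative only), smax(%x, %x, %y) reduces to smax(%x, %y),
 // and umax(%x, %y) reduces to %x when %x uge %y can be proven.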
3887  llvm::CmpInst::Predicate GEPred =
3888  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3889  llvm::CmpInst::Predicate LEPred =
3890  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3891  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3892  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3893  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3894  if (Ops[i] == Ops[i + 1] ||
3895  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3896  // X op Y op Y --> X op Y
3897  // X op Y --> X, if we know X, Y are ordered appropriately
3898  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3899  --i;
3900  --e;
3901  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3902  Ops[i + 1])) {
3903  // X op Y --> Y, if we know X, Y are ordered appropriately
3904  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3905  --i;
3906  --e;
3907  }
3908  }
3909 
3910  if (Ops.size() == 1) return Ops[0];
3911 
3912  assert(!Ops.empty() && "Reduced smax down to nothing!");
3913 
3914  // Okay, it looks like we really DO need an expr. Check to see if we
3915  // already have one, otherwise create a new one.
3916  FoldingSetNodeID ID;
3917  ID.AddInteger(Kind);
3918  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3919  ID.AddPointer(Ops[i]);
3920  void *IP = nullptr;
3921  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3922  if (ExistingSCEV)
3923  return ExistingSCEV;
3924  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3925  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3926  SCEV *S = new (SCEVAllocator)
3927  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3928 
3929  UniqueSCEVs.InsertNode(S, IP);
3930  registerUser(S, Ops);
3931  return S;
3932 }
3933 
3934 namespace {
3935 
3936 class SCEVSequentialMinMaxDeduplicatingVisitor final
3937  : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3938  Optional<const SCEV *>> {
3939  using RetVal = Optional<const SCEV *>;
3940  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3941 
3942  ScalarEvolution &SE;
3943  const SCEVTypes RootKind; // Must be a sequential min/max expression.
3944  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3945  SmallPtrSet<const SCEV *, 16> SeenOps;
3946 
3947  bool canRecurseInto(SCEVTypes Kind) const {
3948  // We can only recurse into the SCEV expression of the same effective type
3949  // as the type of our root SCEV expression.
3950  return RootKind == Kind || NonSequentialRootKind == Kind;
3951  };
3952 
3953  RetVal visitAnyMinMaxExpr(const SCEV *S) {
3954  assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3955  "Only for min/max expressions.");
3956  SCEVTypes Kind = S->getSCEVType();
3957 
3958  if (!canRecurseInto(Kind))
3959  return S;
3960 
3961  auto *NAry = cast<SCEVNAryExpr>(S);
3962  SmallVector<const SCEV *> NewOps;
3963  bool Changed =
3964  visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
3965 
3966  if (!Changed)
3967  return S;
3968  if (NewOps.empty())
3969  return std::nullopt;
3970 
3971  return isa<SCEVSequentialMinMaxExpr>(S)
3972  ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3973  : SE.getMinMaxExpr(Kind, NewOps);
3974  }
3975 
3976  RetVal visit(const SCEV *S) {
3977  // Has the whole operand been seen already?
3978  if (!SeenOps.insert(S).second)
3979  return std::nullopt;
3980  return Base::visit(S);
3981  }
3982 
3983 public:
3984  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3985  SCEVTypes RootKind)
3986  : SE(SE), RootKind(RootKind),
3987  NonSequentialRootKind(
3988  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
3989  RootKind)) {}
3990 
3991  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3992  SmallVectorImpl<const SCEV *> &NewOps) {
3993  bool Changed = false;
3994  SmallVector<const SCEV *> Ops;
3995  Ops.reserve(OrigOps.size());
3996 
3997  for (const SCEV *Op : OrigOps) {
3998  RetVal NewOp = visit(Op);
3999  if (NewOp != Op)
4000  Changed = true;
4001  if (NewOp)
4002  Ops.emplace_back(*NewOp);
4003  }
4004 
4005  if (Changed)
4006  NewOps = std::move(Ops);
4007  return Changed;
4008  }
4009 
4010  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
4011 
4012  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
4013 
4014  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
4015 
4016  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
4017 
4018  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
4019 
4020  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
4021 
4022  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
4023 
4024  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
4025 
4026  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
4027 
4028  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
4029  return visitAnyMinMaxExpr(Expr);
4030  }
4031 
4032  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
4033  return visitAnyMinMaxExpr(Expr);
4034  }
4035 
4036  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
4037  return visitAnyMinMaxExpr(Expr);
4038  }
4039 
4040  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
4041  return visitAnyMinMaxExpr(Expr);
4042  }
4043 
4044  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
4045  return visitAnyMinMaxExpr(Expr);
4046  }
4047 
4048  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
4049 
4050  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4051 };
4052 
4053 } // namespace
4054 
4055 /// Return true if V is poison given that AssumedPoison is already poison.
4056 static bool impliesPoison(const SCEV *AssumedPoison, const SCEV *S) {
4057  // The only way poison may be introduced in a SCEV expression is from a
4058  // poison SCEVUnknown (ConstantExprs are also represented as SCEVUnknown,
4059  // not SCEVConstant). Notably, nowrap flags in SCEV nodes can *not*
4060  // introduce poison -- they encode guaranteed, non-speculated knowledge.
4061  //
4062  // Additionally, all SCEV nodes propagate poison from inputs to outputs,
4063  // with the notable exception of umin_seq, where only poison from the first
4064  // operand is (unconditionally) propagated.
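 // As an illustrative case (not from the original comment):
 // impliesPoison(%x + 1, %x umin_seq %y) holds, because if (%x + 1) is poison
 // then %x itself is poison, and poison in the first operand of umin_seq is
 // unconditionally propagated to the result.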
4065  struct SCEVPoisonCollector {
4066  bool LookThroughSeq;
4067  SmallPtrSet<const SCEV *, 4> MaybePoison;
4068  SCEVPoisonCollector(bool LookThroughSeq) : LookThroughSeq(LookThroughSeq) {}
4069 
4070  bool follow(const SCEV *S) {
4071  // TODO: We can always follow the first operand, but the SCEVTraversal
4072  // API doesn't support this.
4073  if (!LookThroughSeq && isa<SCEVSequentialMinMaxExpr>(S))
4074  return false;
4075 
4076  if (auto *SU = dyn_cast<SCEVUnknown>(S)) {
4077  if (!isGuaranteedNotToBePoison(SU->getValue()))
4078  MaybePoison.insert(S);
4079  }
4080  return true;
4081  }
4082  bool isDone() const { return false; }
4083  };
4084 
4085  // First collect all SCEVs that might result in AssumedPoison to be poison.
4086  // We need to look through umin_seq here, because we want to find all SCEVs
4087  // that *might* result in poison, not only those that are *required* to.
4088  SCEVPoisonCollector PC1(/* LookThroughSeq */ true);
4089  visitAll(AssumedPoison, PC1);
4090 
4091  // AssumedPoison is never poison. As the assumption is false, the implication
4092  // is true. Don't bother walking the other SCEV in this case.
4093  if (PC1.MaybePoison.empty())
4094  return true;
4095 
4096  // Collect all SCEVs in S that, if poison, *will* result in S being poison
4097  // as well. We cannot look through umin_seq here, as its argument only *may*
4098  // make the result poison.
4099  SCEVPoisonCollector PC2(/* LookThroughSeq */ false);
4100  visitAll(S, PC2);
4101 
4102  // Make sure that no matter which SCEV in PC1.MaybePoison is actually poison,
4103  // it will also make S poison by being part of PC2.MaybePoison.
4104  return all_of(PC1.MaybePoison,
4105  [&](const SCEV *S) { return PC2.MaybePoison.contains(S); });
4106 }
4107 
4108 const SCEV *
4109 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4110  SmallVectorImpl<const SCEV *> &Ops) {
4111  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4112  "Not a SCEVSequentialMinMaxExpr!");
4113  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4114  if (Ops.size() == 1)
4115  return Ops[0];
4116 #ifndef NDEBUG
4117  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4118  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4119  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4120  "Operand types don't match!");
4121  assert(Ops[0]->getType()->isPointerTy() ==
4122  Ops[i]->getType()->isPointerTy() &&
4123  "min/max should be consistently pointerish");
4124  }
4125 #endif
4126 
4127  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4128  // so we can *NOT* do any kind of sorting of the expressions!
4129 
4130  // Check if we have created the same expression before.
4131  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4132  return S;
4133 
4134  // FIXME: there are *some* simplifications that we can do here.
4135 
4136  // Keep only the first instance of an operand.
4137  {
4138  SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4139  bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4140  if (Changed)
4141  return getSequentialMinMaxExpr(Kind, Ops);
4142  }
4143 
4144  // Check to see if one of the operands is of the same kind. If so, expand its
4145  // operands onto our operand list, and recurse to simplify.
4146  {
4147  unsigned Idx = 0;
4148  bool DeletedAny = false;
4149  while (Idx < Ops.size()) {
4150  if (Ops[Idx]->getSCEVType() != Kind) {
4151  ++Idx;
4152  continue;
4153  }
4154  const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4155  Ops.erase(Ops.begin() + Idx);
4156  Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
4157  DeletedAny = true;
4158  }
4159 
4160  if (DeletedAny)
4161  return getSequentialMinMaxExpr(Kind, Ops);
4162  }
4163 
4164  const SCEV *SaturationPoint;
4165  ICmpInst::Predicate Pred;
4166  switch (Kind) {
4167  case scSequentialUMinExpr:
4168  SaturationPoint = getZero(Ops[0]->getType());
4169  Pred = ICmpInst::ICMP_ULE;
4170  break;
4171  default:
4172  llvm_unreachable("Not a sequential min/max type.");
4173  }
4174 
4175  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4176  // We can replace %x umin_seq %y with %x umin %y if either:
4177  // * %y being poison implies %x is also poison.
4178  // * %x cannot be the saturating value (e.g. zero for umin).
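 // For example, if %x is known non-zero, %x umin_seq %y can be rewritten as
 // %x umin %y: the two only differ when %x is the saturating value zero,
 // which would otherwise shield poison in %y.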
4179  if (::impliesPoison(Ops[i], Ops[i - 1]) ||
4180  isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
4181  SaturationPoint)) {
4182  SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
4183  Ops[i - 1] = getMinMaxExpr(
4184  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4185  SeqOps);
4186  Ops.erase(Ops.begin() + i);
4187  return getSequentialMinMaxExpr(Kind, Ops);
4188  }
4189  // Fold %x umin_seq %y to %x if %x ule %y.
4190  // TODO: We might be able to prove the predicate for a later operand.
4191  if (isKnownViaNonRecursiveReasoning(Pred, Ops[i - 1], Ops[i])) {
4192  Ops.erase(Ops.begin() + i);
4193  return getSequentialMinMaxExpr(Kind, Ops);
4194  }
4195  }
4196 
4197  // Okay, it looks like we really DO need an expr. Check to see if we
4198  // already have one, otherwise create a new one.
4199  FoldingSetNodeID ID;
4200  ID.AddInteger(Kind);
4201  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4202  ID.AddPointer(Ops[i]);
4203  void *IP = nullptr;
4204  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4205  if (ExistingSCEV)
4206  return ExistingSCEV;
4207 
4208  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4209  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4210  SCEV *S = new (SCEVAllocator)
4211  SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4212 
4213  UniqueSCEVs.InsertNode(S, IP);
4214  registerUser(S, Ops);
4215  return S;
4216 }
4217 
4218 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4219  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4220  return getSMaxExpr(Ops);
4221 }
4222 
4223 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4224  return getMinMaxExpr(scSMaxExpr, Ops);
4225 }
4226 
4227 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4228  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4229  return getUMaxExpr(Ops);
4230 }
4231 
4232 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4233  return getMinMaxExpr(scUMaxExpr, Ops);
4234 }
4235 
4236 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4237  const SCEV *RHS) {
4238  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4239  return getSMinExpr(Ops);
4240 }
4241 
4242 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4243  return getMinMaxExpr(scSMinExpr, Ops);
4244 }
4245 
4246 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4247  bool Sequential) {
4248  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4249  return getUMinExpr(Ops, Sequential);
4250 }
4251 
4252 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4253  bool Sequential) {
4254  return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4255  : getMinMaxExpr(scUMinExpr, Ops);
4256 }
4257 
4258 const SCEV *
4259 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4260  ScalableVectorType *ScalableTy) {
4261  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4262  Constant *One = ConstantInt::get(IntTy, 1);
4263  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
4264  // Note that the expression we created is the final expression; we don't
4265  // want to simplify it any further. Also, if we call a normal getSCEV(),
4266  // we'll end up in an endless recursion. So just create an SCEVUnknown.
4267  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4268 }
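// A rough illustration (assuming IntTy is i64): for <vscale x 4 x i32> the
// expression built above is ptrtoint(gep <vscale x 4 x i32>, ptr null, i64 1),
// i.e. an opaque SCEVUnknown standing for vscale * 16 bytes.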
4269 
4270 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4271  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4272  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4273  // We can bypass creating a target-independent constant expression and then
4274  // folding it back into a ConstantInt. This is just a compile-time
4275  // optimization.
4276  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4277 }
4278 
4279 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4280  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4281  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4282  // We can bypass creating a target-independent constant expression and then
4283  // folding it back into a ConstantInt. This is just a compile-time
4284  // optimization.
4285  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4286 }
4287 
4288 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4289  StructType *STy,
4290  unsigned FieldNo) {
4291  // We can bypass creating a target-independent constant expression and then
4292  // folding it back into a ConstantInt. This is just a compile-time
4293  // optimization.
4294  return getConstant(
4295  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
4296 }
4297 
4298 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4299  // Don't attempt to do anything other than create a SCEVUnknown object
4300  // here. createSCEV only calls getUnknown after checking for all other
4301  // interesting possibilities, and any other code that calls getUnknown
4302  // is doing so in order to hide a value from SCEV canonicalization.
4303 
4304  FoldingSetNodeID ID;
4305  ID.AddInteger(scUnknown);
4306  ID.AddPointer(V);
4307  void *IP = nullptr;
4308  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4309  assert(cast<SCEVUnknown>(S)->getValue() == V &&
4310  "Stale SCEVUnknown in uniquing map!");
4311  return S;
4312  }
4313  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4314  FirstUnknown);
4315  FirstUnknown = cast<SCEVUnknown>(S);
4316  UniqueSCEVs.InsertNode(S, IP);
4317  return S;
4318 }
4319 
4320 //===----------------------------------------------------------------------===//
4321 // Basic SCEV Analysis and PHI Idiom Recognition Code
4322 //
4323 
4324 /// Test if values of the given type are analyzable within the SCEV
4325 /// framework. This primarily includes integer types, and it can optionally
4326 /// include pointer types if the ScalarEvolution class has access to
4327 /// target-specific information.
4328 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4329  // Integers and pointers are always SCEVable.
4330  return Ty->isIntOrPtrTy();
4331 }
4332 
4333 /// Return the size in bits of the specified type, for which isSCEVable must
4334 /// return true.
4335 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4336  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4337  if (Ty->isPointerTy())
4338  return getDataLayout().getIndexTypeSizeInBits(Ty);
4339  return getDataLayout().getTypeSizeInBits(Ty);
4340 }
4341 
4342 /// Return a type with the same bitwidth as the given type and which represents
4343 /// how SCEV will treat the given type, for which isSCEVable must return
4344 /// true. For pointer types, this is the pointer index sized integer type.
4345 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4346  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4347 
4348  if (Ty->isIntegerTy())
4349  return Ty;
4350 
4351  // The only other supported type is pointer.
4352  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4353  return getDataLayout().getIndexType(Ty);
4354 }
4355 
4356 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4357  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4358 }
4359 
4360 bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4361  const SCEV *B) {
4362  /// For a valid use point to exist, the defining scope of one operand
4363  /// must dominate the other.
4364  bool PreciseA, PreciseB;
4365  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4366  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4367  if (!PreciseA || !PreciseB)
4368  // Can't tell.
4369  return false;
4370  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4371  DT.dominates(ScopeB, ScopeA);
4372 }
4373 
4374 
4375 const SCEV *ScalarEvolution::getCouldNotCompute() {
4376  return CouldNotCompute.get();
4377 }
4378 
4379 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4380  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4381  auto *SU = dyn_cast<SCEVUnknown>(S);
4382  return SU && SU->getValue() == nullptr;
4383  });
4384 
4385  return !ContainsNulls;
4386 }
4387 
4388 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4389  HasRecMapType::iterator I = HasRecMap.find(S);
4390  if (I != HasRecMap.end())
4391  return I->second;
4392 
4393  bool FoundAddRec =
4394  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4395  HasRecMap.insert({S, FoundAddRec});
4396  return FoundAddRec;
4397 }
4398 
4399 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4400 /// by the value and offset from any ValueOffsetPair in the set.
4401 ArrayRef<Value *> ScalarEvolution::getSCEVValues(const SCEV *S) {
4402  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4403  if (SI == ExprValueMap.end())
4404  return std::nullopt;
4405 #ifndef NDEBUG
4406  if (VerifySCEVMap) {
4407  // Check there is no dangling Value in the set returned.
4408  for (Value *V : SI->second)
4409  assert(ValueExprMap.count(V));
4410  }
4411 #endif
4412  return SI->second.getArrayRef();
4413 }
4414 
4415 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4416 /// cannot be used separately. eraseValueFromMap should be used to remove
4417 /// V from ValueExprMap and ExprValueMap at the same time.
4418 void ScalarEvolution::eraseValueFromMap(Value *V) {
4419  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4420  if (I != ValueExprMap.end()) {
4421  auto EVIt = ExprValueMap.find(I->second);
4422  bool Removed = EVIt->second.remove(V);
4423  (void) Removed;
4424  assert(Removed && "Value not in ExprValueMap?");
4425  ValueExprMap.erase(I);
4426  }
4427 }
4428 
4429 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4430  // A recursive query may have already computed the SCEV. It should be
4431  // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4432  // inferred nowrap flags.
4433  auto It = ValueExprMap.find_as(V);
4434  if (It == ValueExprMap.end()) {
4435  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4436  ExprValueMap[S].insert(V);
4437  }
4438 }
4439 
4440 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4441 /// create a new one.
4442 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4443  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4444 
4445  if (const SCEV *S = getExistingSCEV(V))
4446  return S;
4447  return createSCEVIter(V);
4448 }
4449 
4450 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4451  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4452 
4453  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4454  if (I != ValueExprMap.end()) {
4455  const SCEV *S = I->second;
4456  assert(checkValidity(S) &&
4457  "existing SCEV has not been properly invalidated");
4458  return S;
4459  }
4460  return nullptr;
4461 }
4462 
4463 /// Return a SCEV corresponding to -V = -1*V
4464 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4465  SCEV::NoWrapFlags Flags) {
4466  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4467  return getConstant(
4468  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4469 
4470  Type *Ty = V->getType();
4471  Ty = getEffectiveSCEVType(Ty);
4472  return getMulExpr(V, getMinusOne(Ty), Flags);
4473 }
4474 
4475 /// If Expr computes ~A, return A else return nullptr
4476 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4477  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4478  if (!Add || Add->getNumOperands() != 2 ||
4479  !Add->getOperand(0)->isAllOnesValue())
4480  return nullptr;
4481 
4482  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4483  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4484  !AddRHS->getOperand(0)->isAllOnesValue())
4485  return nullptr;
4486 
4487  return AddRHS->getOperand(1);
4488 }
4489 
4490 /// Return a SCEV corresponding to ~V = -1-V
4491 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4492  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4493 
4494  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4495  return getConstant(
4496  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4497 
4498  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4499  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4500  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4501  SmallVector<const SCEV *, 2> MatchedOperands;
4502  for (const SCEV *Operand : MME->operands()) {
4503  const SCEV *Matched = MatchNotExpr(Operand);
4504  if (!Matched)
4505  return (const SCEV *)nullptr;
4506  MatchedOperands.push_back(Matched);
4507  }
4508  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4509  MatchedOperands);
4510  };
4511  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4512  return Replaced;
4513  }
4514 
4515  Type *Ty = V->getType();
4516  Ty = getEffectiveSCEVType(Ty);
4517  return getMinusSCEV(getMinusOne(Ty), V);
4518 }
4519 
4520 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4521  assert(P->getType()->isPointerTy());
4522 
4523  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4524  // The base of an AddRec is the first operand.
4525  SmallVector<const SCEV *> Ops{AddRec->operands()};
4526  Ops[0] = removePointerBase(Ops[0]);
4527  // Don't try to transfer nowrap flags for now. We could in some cases
4528  // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4529  return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4530  }
4531  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4532  // The base of an Add is the pointer operand.
4533  SmallVector<const SCEV *> Ops{Add->operands()};
4534  const SCEV **PtrOp = nullptr;
4535  for (const SCEV *&AddOp : Ops) {
4536  if (AddOp->getType()->isPointerTy()) {
4537  assert(!PtrOp && "Cannot have multiple pointer ops");
4538  PtrOp = &AddOp;
4539  }
4540  }
4541  *PtrOp = removePointerBase(*PtrOp);
4542  // Don't try to transfer nowrap flags for now. We could in some cases
4543  // (for example, if the pointer operand of the Add is a SCEVUnknown).
4544  return getAddExpr(Ops);
4545  }
4546  // Any other expression must be a pointer base.
4547  return getZero(P->getType());
4548 }
4549 
4550 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4551  SCEV::NoWrapFlags Flags,
4552  unsigned Depth) {
4553  // Fast path: X - X --> 0.
4554  if (LHS == RHS)
4555  return getZero(LHS->getType());
4556 
4557  // If we subtract two pointers with different pointer bases, bail.
4558  // Eventually, we're going to add an assertion to getMulExpr that we
4559  // can't multiply by a pointer.
4560  if (RHS->getType()->isPointerTy()) {
4561  if (!LHS->getType()->isPointerTy() ||
4562  getPointerBase(LHS) != getPointerBase(RHS))
4563  return getCouldNotCompute();
4564  LHS = removePointerBase(LHS);
4565  RHS = removePointerBase(RHS);
4566  }
4567 
4568  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4569  // makes it so that we cannot make much use of NUW.
4570  auto AddFlags = SCEV::FlagAnyWrap;
4571  const bool RHSIsNotMinSigned =
4573  if (hasFlags(Flags, SCEV::FlagNSW)) {
4574  // Let M be the minimum representable signed value. Then (-1)*RHS
4575  // signed-wraps if and only if RHS is M. That can happen even for
4576  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4577  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4578  // (-1)*RHS, we need to prove that RHS != M.
4579  //
4580  // If LHS is non-negative and we know that LHS - RHS does not
4581  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4582  // either by proving that RHS > M or that LHS >= 0.
4583  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4584  AddFlags = SCEV::FlagNSW;
4585  }
4586  }
4587 
4588  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4589  // RHS is NSW and LHS >= 0.
4590  //
4591  // The difficulty here is that the NSW flag may have been proven
4592  // relative to a loop that is to be found in a recurrence in LHS and
4593  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4594  // larger scope than intended.
4595  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4596 
4597  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4598 }
4599 
4600 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4601  unsigned Depth) {
4602  Type *SrcTy = V->getType();
4603  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4604  "Cannot truncate or zero extend with non-integer arguments!");
4605  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4606  return V; // No conversion
4607  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4608  return getTruncateExpr(V, Ty, Depth);
4609  return getZeroExtendExpr(V, Ty, Depth);
4610 }
4611 
4612 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4613  unsigned Depth) {
4614  Type *SrcTy = V->getType();
4615  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4616  "Cannot truncate or zero extend with non-integer arguments!");
4617  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4618  return V; // No conversion
4619  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4620  return getTruncateExpr(V, Ty, Depth);
4621  return getSignExtendExpr(V, Ty, Depth);
4622 }
4623 
4624 const SCEV *
4625 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4626  Type *SrcTy = V->getType();
4627  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4628  "Cannot noop or zero extend with non-integer arguments!");
4629  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4630  "getNo