1 //===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the scalar evolution analysis
10 // engine, which is used primarily to analyze expressions involving induction
11 // variables in loops.
12 //
13 // There are several aspects to this library. First is the representation of
14 // scalar expressions, which are represented as subclasses of the SCEV class.
15 // These classes are used to represent certain types of subexpressions that we
16 // can handle. We only create one SCEV of a particular shape, so
17 // pointer-comparisons for equality are legal.
18 //
19 // One important aspect of the SCEV objects is that they are never cyclic, even
20 // if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
21 // the PHI node is one of the idioms that we can represent (e.g., a polynomial
22 // recurrence) then we represent it directly as a recurrence node, otherwise we
23 // represent it as a SCEVUnknown node.
24 //
25 // In addition to being able to represent expressions of various types, we also
26 // have folders that are used to build the *canonical* representation for a
27 // particular expression. These folders are capable of using a variety of
28 // rewrite rules to simplify the expressions.
29 //
30 // Once the folders are defined, we can implement the more interesting
31 // higher-level code, such as the code that recognizes PHI nodes of various
32 // types, computes the execution count of a loop, etc.
33 //
34 // TODO: We should use these routines and value representations to implement
35 // dependence analysis!
36 //
37 //===----------------------------------------------------------------------===//
38 //
39 // There are several good references for the techniques used in this analysis.
40 //
41 // Chains of recurrences -- a method to expedite the evaluation
42 // of closed-form functions
43 // Olaf Bachmann, Paul S. Wang, Eugene V. Zima
44 //
45 // On computational properties of chains of recurrences
46 // Eugene V. Zima
47 //
48 // Symbolic Evaluation of Chains of Recurrences for Loop Optimization
49 // Robert A. van Engelen
50 //
51 // Efficient Symbolic Analysis for Optimizing Compilers
52 // Robert A. van Engelen
53 //
54 // Using the chains of recurrences algebra for data dependence testing and
55 // induction variable substitution
56 // MS Thesis, Johnie Birch
57 //
58 //===----------------------------------------------------------------------===//
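//
// For illustration, a hypothetical loop and the kind of add recurrence this
// analysis typically builds for its induction variable (textual form as
// produced by SCEV::print below; the exact no-wrap flags depend on what can
// be proven for the loop):
//
//   for (int i = 0; i != n; ++i)      // i       ~  {0,+,1}<nuw><nsw><%loop>
//     sum += 4 * i;                   // 4 * i   ~  {0,+,4}<nuw><nsw><%loop>
//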
59 
60 #include "llvm/Analysis/ScalarEvolution.h"
61 #include "llvm/ADT/APInt.h"
62 #include "llvm/ADT/ArrayRef.h"
63 #include "llvm/ADT/DenseMap.h"
64 #include "llvm/ADT/DepthFirstIterator.h"
65 #include "llvm/ADT/EquivalenceClasses.h"
66 #include "llvm/ADT/FoldingSet.h"
67 #include "llvm/ADT/None.h"
68 #include "llvm/ADT/Optional.h"
69 #include "llvm/ADT/STLExtras.h"
70 #include "llvm/ADT/ScopeExit.h"
71 #include "llvm/ADT/Sequence.h"
72 #include "llvm/ADT/SetVector.h"
73 #include "llvm/ADT/SmallPtrSet.h"
74 #include "llvm/ADT/SmallSet.h"
75 #include "llvm/ADT/SmallVector.h"
76 #include "llvm/ADT/Statistic.h"
77 #include "llvm/ADT/StringRef.h"
81 #include "llvm/Analysis/LoopInfo.h"
86 #include "llvm/Config/llvm-config.h"
87 #include "llvm/IR/Argument.h"
88 #include "llvm/IR/BasicBlock.h"
89 #include "llvm/IR/CFG.h"
90 #include "llvm/IR/Constant.h"
91 #include "llvm/IR/ConstantRange.h"
92 #include "llvm/IR/Constants.h"
93 #include "llvm/IR/DataLayout.h"
94 #include "llvm/IR/DerivedTypes.h"
95 #include "llvm/IR/Dominators.h"
96 #include "llvm/IR/Function.h"
97 #include "llvm/IR/GlobalAlias.h"
98 #include "llvm/IR/GlobalValue.h"
99 #include "llvm/IR/GlobalVariable.h"
100 #include "llvm/IR/InstIterator.h"
101 #include "llvm/IR/InstrTypes.h"
102 #include "llvm/IR/Instruction.h"
103 #include "llvm/IR/Instructions.h"
104 #include "llvm/IR/IntrinsicInst.h"
105 #include "llvm/IR/Intrinsics.h"
106 #include "llvm/IR/LLVMContext.h"
107 #include "llvm/IR/Metadata.h"
108 #include "llvm/IR/Operator.h"
109 #include "llvm/IR/PatternMatch.h"
110 #include "llvm/IR/Type.h"
111 #include "llvm/IR/Use.h"
112 #include "llvm/IR/User.h"
113 #include "llvm/IR/Value.h"
114 #include "llvm/IR/Verifier.h"
115 #include "llvm/InitializePasses.h"
116 #include "llvm/Pass.h"
117 #include "llvm/Support/Casting.h"
118 #include "llvm/Support/CommandLine.h"
119 #include "llvm/Support/Compiler.h"
120 #include "llvm/Support/Debug.h"
121 #include "llvm/Support/ErrorHandling.h"
122 #include "llvm/Support/KnownBits.h"
125 #include <algorithm>
126 #include <cassert>
127 #include <climits>
128 #include <cstddef>
129 #include <cstdint>
130 #include <cstdlib>
131 #include <map>
132 #include <memory>
133 #include <tuple>
134 #include <utility>
135 #include <vector>
136 
137 using namespace llvm;
138 using namespace PatternMatch;
139 
140 #define DEBUG_TYPE "scalar-evolution"
141 
142 STATISTIC(NumTripCountsComputed,
143  "Number of loops with predictable loop counts");
144 STATISTIC(NumTripCountsNotComputed,
145  "Number of loops without predictable loop counts");
146 STATISTIC(NumBruteForceTripCountsComputed,
147  "Number of loops with trip counts computed by force");
148 
149 static cl::opt<unsigned>
150 MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
151  cl::ZeroOrMore,
152  cl::desc("Maximum number of iterations SCEV will "
153  "symbolically execute a constant "
154  "derived loop"),
155  cl::init(100));
156 
157 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
158 static cl::opt<bool> VerifySCEV(
159  "verify-scev", cl::Hidden,
160  cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
161 static cl::opt<bool> VerifySCEVStrict(
162  "verify-scev-strict", cl::Hidden,
163  cl::desc("Enable stricter verification when -verify-scev is passed"));
164 static cl::opt<bool>
165  VerifySCEVMap("verify-scev-maps", cl::Hidden,
166  cl::desc("Verify no dangling value in ScalarEvolution's "
167  "ExprValueMap (slow)"));
168 
169 static cl::opt<bool> VerifyIR(
170  "scev-verify-ir", cl::Hidden,
171  cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
172  cl::init(false));
173 
174 static cl::opt<unsigned> MulOpsInlineThreshold(
175  "scev-mulops-inline-threshold", cl::Hidden,
176  cl::desc("Threshold for inlining multiplication operands into a SCEV"),
177  cl::init(32));
178 
179 static cl::opt<unsigned> AddOpsInlineThreshold(
180  "scev-addops-inline-threshold", cl::Hidden,
181  cl::desc("Threshold for inlining addition operands into a SCEV"),
182  cl::init(500));
183 
184 static cl::opt<unsigned> MaxSCEVCompareDepth(
185  "scalar-evolution-max-scev-compare-depth", cl::Hidden,
186  cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
187  cl::init(32));
188 
189 static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
190  "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
191  cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
192  cl::init(2));
193 
194 static cl::opt<unsigned> MaxValueCompareDepth(
195  "scalar-evolution-max-value-compare-depth", cl::Hidden,
196  cl::desc("Maximum depth of recursive value complexity comparisons"),
197  cl::init(2));
198 
199 static cl::opt<unsigned>
200  MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
201  cl::desc("Maximum depth of recursive arithmetics"),
202  cl::init(32));
203 
204 static cl::opt<unsigned> MaxConstantEvolvingDepth(
205  "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
206  cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));
207 
208 static cl::opt<unsigned>
209  MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
210  cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
211  cl::init(8));
212 
213 static cl::opt<unsigned>
214  MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
215  cl::desc("Max coefficients in AddRec during evolving"),
216  cl::init(8));
217 
218 static cl::opt<unsigned>
219  HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
220  cl::desc("Size of the expression which is considered huge"),
221  cl::init(4096));
222 
223 static cl::opt<bool>
224 ClassifyExpressions("scalar-evolution-classify-expressions",
225  cl::Hidden, cl::init(true),
226  cl::desc("When printing analysis, include information on every instruction"));
227 
228 static cl::opt<bool> UseExpensiveRangeSharpening(
229  "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
230  cl::init(false),
231  cl::desc("Use more powerful methods of sharpening expression ranges. May "
232  "be costly in terms of compile time"));
233 
234 //===----------------------------------------------------------------------===//
235 // SCEV class definitions
236 //===----------------------------------------------------------------------===//
237 
238 //===----------------------------------------------------------------------===//
239 // Implementation of the SCEV class.
240 //
241 
242 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
243 LLVM_DUMP_METHOD void SCEV::dump() const {
244  print(dbgs());
245  dbgs() << '\n';
246 }
247 #endif
248 
249 void SCEV::print(raw_ostream &OS) const {
250  switch (getSCEVType()) {
251  case scConstant:
252  cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
253  return;
254  case scPtrToInt: {
255  const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
256  const SCEV *Op = PtrToInt->getOperand();
257  OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
258  << *PtrToInt->getType() << ")";
259  return;
260  }
261  case scTruncate: {
262  const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
263  const SCEV *Op = Trunc->getOperand();
264  OS << "(trunc " << *Op->getType() << " " << *Op << " to "
265  << *Trunc->getType() << ")";
266  return;
267  }
268  case scZeroExtend: {
269  const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
270  const SCEV *Op = ZExt->getOperand();
271  OS << "(zext " << *Op->getType() << " " << *Op << " to "
272  << *ZExt->getType() << ")";
273  return;
274  }
275  case scSignExtend: {
276  const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
277  const SCEV *Op = SExt->getOperand();
278  OS << "(sext " << *Op->getType() << " " << *Op << " to "
279  << *SExt->getType() << ")";
280  return;
281  }
282  case scAddRecExpr: {
283  const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
284  OS << "{" << *AR->getOperand(0);
285  for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
286  OS << ",+," << *AR->getOperand(i);
287  OS << "}<";
288  if (AR->hasNoUnsignedWrap())
289  OS << "nuw><";
290  if (AR->hasNoSignedWrap())
291  OS << "nsw><";
292  if (AR->hasNoSelfWrap() &&
293  !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
294  OS << "nw><";
295  AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
296  OS << ">";
297  return;
298  }
299  case scAddExpr:
300  case scMulExpr:
301  case scUMaxExpr:
302  case scSMaxExpr:
303  case scUMinExpr:
304  case scSMinExpr:
305  case scSequentialUMinExpr: {
306  const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
307  const char *OpStr = nullptr;
308  switch (NAry->getSCEVType()) {
309  case scAddExpr: OpStr = " + "; break;
310  case scMulExpr: OpStr = " * "; break;
311  case scUMaxExpr: OpStr = " umax "; break;
312  case scSMaxExpr: OpStr = " smax "; break;
313  case scUMinExpr:
314  OpStr = " umin ";
315  break;
316  case scSMinExpr:
317  OpStr = " smin ";
318  break;
319  case scSequentialUMinExpr:
320  OpStr = " umin_seq ";
321  break;
322  default:
323  llvm_unreachable("There are no other nary expression types.");
324  }
325  OS << "(";
326  ListSeparator LS(OpStr);
327  for (const SCEV *Op : NAry->operands())
328  OS << LS << *Op;
329  OS << ")";
330  switch (NAry->getSCEVType()) {
331  case scAddExpr:
332  case scMulExpr:
333  if (NAry->hasNoUnsignedWrap())
334  OS << "<nuw>";
335  if (NAry->hasNoSignedWrap())
336  OS << "<nsw>";
337  break;
338  default:
339  // Nothing to print for other nary expressions.
340  break;
341  }
342  return;
343  }
344  case scUDivExpr: {
345  const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
346  OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
347  return;
348  }
349  case scUnknown: {
350  const SCEVUnknown *U = cast<SCEVUnknown>(this);
351  Type *AllocTy;
352  if (U->isSizeOf(AllocTy)) {
353  OS << "sizeof(" << *AllocTy << ")";
354  return;
355  }
356  if (U->isAlignOf(AllocTy)) {
357  OS << "alignof(" << *AllocTy << ")";
358  return;
359  }
360 
361  Type *CTy;
362  Constant *FieldNo;
363  if (U->isOffsetOf(CTy, FieldNo)) {
364  OS << "offsetof(" << *CTy << ", ";
365  FieldNo->printAsOperand(OS, false);
366  OS << ")";
367  return;
368  }
369 
370  // Otherwise just print it normally.
371  U->getValue()->printAsOperand(OS, false);
372  return;
373  }
374  case scCouldNotCompute:
375  OS << "***COULDNOTCOMPUTE***";
376  return;
377  }
378  llvm_unreachable("Unknown SCEV kind!");
379 }
380 
381 Type *SCEV::getType() const {
382  switch (getSCEVType()) {
383  case scConstant:
384  return cast<SCEVConstant>(this)->getType();
385  case scPtrToInt:
386  case scTruncate:
387  case scZeroExtend:
388  case scSignExtend:
389  return cast<SCEVCastExpr>(this)->getType();
390  case scAddRecExpr:
391  return cast<SCEVAddRecExpr>(this)->getType();
392  case scMulExpr:
393  return cast<SCEVMulExpr>(this)->getType();
394  case scUMaxExpr:
395  case scSMaxExpr:
396  case scUMinExpr:
397  case scSMinExpr:
398  return cast<SCEVMinMaxExpr>(this)->getType();
399  case scSequentialUMinExpr:
400  return cast<SCEVSequentialMinMaxExpr>(this)->getType();
401  case scAddExpr:
402  return cast<SCEVAddExpr>(this)->getType();
403  case scUDivExpr:
404  return cast<SCEVUDivExpr>(this)->getType();
405  case scUnknown:
406  return cast<SCEVUnknown>(this)->getType();
407  case scCouldNotCompute:
408  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
409  }
410  llvm_unreachable("Unknown SCEV kind!");
411 }
412 
413 bool SCEV::isZero() const {
414  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
415  return SC->getValue()->isZero();
416  return false;
417 }
418 
419 bool SCEV::isOne() const {
420  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
421  return SC->getValue()->isOne();
422  return false;
423 }
424 
425 bool SCEV::isAllOnesValue() const {
426  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
427  return SC->getValue()->isMinusOne();
428  return false;
429 }
430 
431 bool SCEV::isNonConstantNegative() const {
432  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
433  if (!Mul) return false;
434 
435  // If there is a constant factor, it will be first.
436  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
437  if (!SC) return false;
438 
439  // Return true if the value is negative; this matches things like (-42 * V).
440  return SC->getAPInt().isNegative();
441 }
442 
443 SCEVCouldNotCompute::SCEVCouldNotCompute() :
444  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}
445 
446 bool SCEVCouldNotCompute::classof(const SCEV *S) {
447  return S->getSCEVType() == scCouldNotCompute;
448 }
449 
450 const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
451  FoldingSetNodeID ID;
452  ID.AddInteger(scConstant);
453  ID.AddPointer(V);
454  void *IP = nullptr;
455  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
456  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
457  UniqueSCEVs.InsertNode(S, IP);
458  return S;
459 }
460 
461 const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
462  return getConstant(ConstantInt::get(getContext(), Val));
463 }
464 
465 const SCEV *
466 ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
467  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
468  return getConstant(ConstantInt::get(ITy, V, isSigned));
469 }
470 
471 SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
472  const SCEV *op, Type *ty)
473  : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
474  Operands[0] = op;
475 }
476 
477 SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
478  Type *ITy)
479  : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
480  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
481  "Must be a non-bit-width-changing pointer-to-integer cast!");
482 }
483 
484 SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
485  SCEVTypes SCEVTy, const SCEV *op,
486  Type *ty)
487  : SCEVCastExpr(ID, SCEVTy, op, ty) {}
488 
489 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
490  Type *ty)
491  : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
492  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
493  "Cannot truncate non-integer value!");
494 }
495 
496 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
497  const SCEV *op, Type *ty)
498  : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
499  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
500  "Cannot zero extend non-integer value!");
501 }
502 
503 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
504  const SCEV *op, Type *ty)
505  : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
506  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
507  "Cannot sign extend non-integer value!");
508 }
509 
510 void SCEVUnknown::deleted() {
511  // Clear this SCEVUnknown from various maps.
512  SE->forgetMemoizedResults(this);
513 
514  // Remove this SCEVUnknown from the uniquing map.
515  SE->UniqueSCEVs.RemoveNode(this);
516 
517  // Release the value.
518  setValPtr(nullptr);
519 }
520 
521 void SCEVUnknown::allUsesReplacedWith(Value *New) {
522  // Remove this SCEVUnknown from the uniquing map.
523  SE->UniqueSCEVs.RemoveNode(this);
524 
525  // Update this SCEVUnknown to point to the new value. This is needed
526  // because there may still be outstanding SCEVs which still point to
527  // this SCEVUnknown.
528  setValPtr(New);
529 }
530 
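// The next three helpers recognize the classic constant-expression idioms that
// front-ends use to encode type sizes, alignments and field offsets in a
// target-independent way. For example (illustrative IR), sizeof(T) is usually
// expressed as the byte offset of element 1 of a null-based GEP over T:
//
//   ptrtoint (T* getelementptr (T, T* null, i32 1) to i64)
//
// which is exactly the shape that isSizeOf() matches below.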
531 bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
532  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
533  if (VCE->getOpcode() == Instruction::PtrToInt)
534  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
535  if (CE->getOpcode() == Instruction::GetElementPtr &&
536  CE->getOperand(0)->isNullValue() &&
537  CE->getNumOperands() == 2)
538  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
539  if (CI->isOne()) {
540  AllocTy = cast<GEPOperator>(CE)->getSourceElementType();
541  return true;
542  }
543 
544  return false;
545 }
546 
547 bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
548  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
549  if (VCE->getOpcode() == Instruction::PtrToInt)
550  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
551  if (CE->getOpcode() == Instruction::GetElementPtr &&
552  CE->getOperand(0)->isNullValue()) {
553  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
554  if (StructType *STy = dyn_cast<StructType>(Ty))
555  if (!STy->isPacked() &&
556  CE->getNumOperands() == 3 &&
557  CE->getOperand(1)->isNullValue()) {
558  if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
559  if (CI->isOne() &&
560  STy->getNumElements() == 2 &&
561  STy->getElementType(0)->isIntegerTy(1)) {
562  AllocTy = STy->getElementType(1);
563  return true;
564  }
565  }
566  }
567 
568  return false;
569 }
570 
571 bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
572  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
573  if (VCE->getOpcode() == Instruction::PtrToInt)
574  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
575  if (CE->getOpcode() == Instruction::GetElementPtr &&
576  CE->getNumOperands() == 3 &&
577  CE->getOperand(0)->isNullValue() &&
578  CE->getOperand(1)->isNullValue()) {
579  Type *Ty = cast<GEPOperator>(CE)->getSourceElementType();
580  // Ignore vector types here so that ScalarEvolutionExpander doesn't
581  // emit getelementptrs that index into vectors.
582  if (Ty->isStructTy() || Ty->isArrayTy()) {
583  CTy = Ty;
584  FieldNo = CE->getOperand(2);
585  return true;
586  }
587  }
588 
589  return false;
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // SCEV Utilities
594 //===----------------------------------------------------------------------===//
595 
596 /// Compare the two values \p LV and \p RV in terms of their "complexity" where
597 /// "complexity" is a partial (and somewhat ad-hoc) relation used to order
598 /// operands in SCEV expressions. \p EqCache is a set of pairs of values that
599 /// have been previously deemed to be "equally complex" by this routine. It is
600 /// intended to avoid exponential time complexity in cases like:
601 ///
602 /// %a = f(%x, %y)
603 /// %b = f(%a, %a)
604 /// %c = f(%b, %b)
605 ///
606 /// %d = f(%x, %y)
607 /// %e = f(%d, %d)
608 /// %f = f(%e, %e)
609 ///
610 /// CompareValueComplexity(%f, %c)
611 ///
612 /// Since we do not continue running this routine on expression trees once we
613 /// have seen unequal values, there is no need to track them in the cache.
614 static int
615 CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
616  const LoopInfo *const LI, Value *LV, Value *RV,
617  unsigned Depth) {
618  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
619  return 0;
620 
621  // Order pointer values after integer values. This helps SCEVExpander form
622  // GEPs.
623  bool LIsPointer = LV->getType()->isPointerTy(),
624  RIsPointer = RV->getType()->isPointerTy();
625  if (LIsPointer != RIsPointer)
626  return (int)LIsPointer - (int)RIsPointer;
627 
628  // Compare getValueID values.
629  unsigned LID = LV->getValueID(), RID = RV->getValueID();
630  if (LID != RID)
631  return (int)LID - (int)RID;
632 
633  // Sort arguments by their position.
634  if (const auto *LA = dyn_cast<Argument>(LV)) {
635  const auto *RA = cast<Argument>(RV);
636  unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
637  return (int)LArgNo - (int)RArgNo;
638  }
639 
640  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
641  const auto *RGV = cast<GlobalValue>(RV);
642 
643  const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
644  auto LT = GV->getLinkage();
645  return !(GlobalValue::isPrivateLinkage(LT) ||
646  GlobalValue::isInternalLinkage(LT));
647  };
648 
649  // Use the names to distinguish the two values, but only if the
650  // names are semantically important.
651  if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
652  return LGV->getName().compare(RGV->getName());
653  }
654 
655  // For instructions, compare their loop depth, and their operand count. This
656  // is pretty loose.
657  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
658  const auto *RInst = cast<Instruction>(RV);
659 
660  // Compare loop depths.
661  const BasicBlock *LParent = LInst->getParent(),
662  *RParent = RInst->getParent();
663  if (LParent != RParent) {
664  unsigned LDepth = LI->getLoopDepth(LParent),
665  RDepth = LI->getLoopDepth(RParent);
666  if (LDepth != RDepth)
667  return (int)LDepth - (int)RDepth;
668  }
669 
670  // Compare the number of operands.
671  unsigned LNumOps = LInst->getNumOperands(),
672  RNumOps = RInst->getNumOperands();
673  if (LNumOps != RNumOps)
674  return (int)LNumOps - (int)RNumOps;
675 
676  for (unsigned Idx : seq(0u, LNumOps)) {
677  int Result =
678  CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
679  RInst->getOperand(Idx), Depth + 1);
680  if (Result != 0)
681  return Result;
682  }
683  }
684 
685  EqCacheValue.unionSets(LV, RV);
686  return 0;
687 }
688 
689 // Return negative, zero, or positive, if LHS is less than, equal to, or greater
690 // than RHS, respectively. A three-way result allows recursive comparisons to be
691 // more efficient.
692 // If the max analysis depth was reached, return None, assuming we do not know
693 // if they are equivalent for sure.
694 static Optional<int>
695 CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
696  EquivalenceClasses<const Value *> &EqCacheValue,
697  const LoopInfo *const LI, const SCEV *LHS,
698  const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
699  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
700  if (LHS == RHS)
701  return 0;
702 
703  // Primarily, sort the SCEVs by their getSCEVType().
704  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
705  if (LType != RType)
706  return (int)LType - (int)RType;
707 
708  if (EqCacheSCEV.isEquivalent(LHS, RHS))
709  return 0;
710 
711  if (Depth > MaxSCEVCompareDepth)
712  return None;
713 
714  // Aside from the getSCEVType() ordering, the particular ordering
715  // isn't very important except that it's beneficial to be consistent,
716  // so that (a + b) and (b + a) don't end up as different expressions.
717  switch (LType) {
718  case scUnknown: {
719  const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
720  const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
721 
722  int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
723  RU->getValue(), Depth + 1);
724  if (X == 0)
725  EqCacheSCEV.unionSets(LHS, RHS);
726  return X;
727  }
728 
729  case scConstant: {
730  const SCEVConstant *LC = cast<SCEVConstant>(LHS);
731  const SCEVConstant *RC = cast<SCEVConstant>(RHS);
732 
733  // Compare constant values.
734  const APInt &LA = LC->getAPInt();
735  const APInt &RA = RC->getAPInt();
736  unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
737  if (LBitWidth != RBitWidth)
738  return (int)LBitWidth - (int)RBitWidth;
739  return LA.ult(RA) ? -1 : 1;
740  }
741 
742  case scAddRecExpr: {
743  const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
744  const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
745 
746  // There is always a dominance between two recs that are used by one SCEV,
747  // so we can safely sort recs by loop header dominance. We require such
748  // order in getAddExpr.
749  const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
750  if (LLoop != RLoop) {
751  const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
752  assert(LHead != RHead && "Two loops share the same header?");
753  if (DT.dominates(LHead, RHead))
754  return 1;
755  else
756  assert(DT.dominates(RHead, LHead) &&
757  "No dominance between recurrences used by one SCEV?");
758  return -1;
759  }
760 
761  // Addrec complexity grows with operand count.
762  unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
763  if (LNumOps != RNumOps)
764  return (int)LNumOps - (int)RNumOps;
765 
766  // Lexicographically compare.
767  for (unsigned i = 0; i != LNumOps; ++i) {
768  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
769  LA->getOperand(i), RA->getOperand(i), DT,
770  Depth + 1);
771  if (X != 0)
772  return X;
773  }
774  EqCacheSCEV.unionSets(LHS, RHS);
775  return 0;
776  }
777 
778  case scAddExpr:
779  case scMulExpr:
780  case scSMaxExpr:
781  case scUMaxExpr:
782  case scSMinExpr:
783  case scUMinExpr:
784  case scSequentialUMinExpr: {
785  const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
786  const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
787 
788  // Lexicographically compare n-ary expressions.
789  unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
790  if (LNumOps != RNumOps)
791  return (int)LNumOps - (int)RNumOps;
792 
793  for (unsigned i = 0; i != LNumOps; ++i) {
794  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
795  LC->getOperand(i), RC->getOperand(i), DT,
796  Depth + 1);
797  if (X != 0)
798  return X;
799  }
800  EqCacheSCEV.unionSets(LHS, RHS);
801  return 0;
802  }
803 
804  case scUDivExpr: {
805  const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
806  const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
807 
808  // Lexicographically compare udiv expressions.
809  auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
810  RC->getLHS(), DT, Depth + 1);
811  if (X != 0)
812  return X;
813  X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
814  RC->getRHS(), DT, Depth + 1);
815  if (X == 0)
816  EqCacheSCEV.unionSets(LHS, RHS);
817  return X;
818  }
819 
820  case scPtrToInt:
821  case scTruncate:
822  case scZeroExtend:
823  case scSignExtend: {
824  const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
825  const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
826 
827  // Compare cast expressions by operand.
828  auto X =
829  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
830  RC->getOperand(), DT, Depth + 1);
831  if (X == 0)
832  EqCacheSCEV.unionSets(LHS, RHS);
833  return X;
834  }
835 
836  case scCouldNotCompute:
837  llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
838  }
839  llvm_unreachable("Unknown SCEV kind!");
840 }
841 
842 /// Given a list of SCEV objects, order them by their complexity, and group
843 /// objects of the same complexity together by value. When this routine is
844 /// finished, we know that any duplicates in the vector are consecutive and that
845 /// complexity is monotonically increasing.
846 ///
847 /// Note that we take special precautions to ensure that we get deterministic
848 /// results from this routine. In other words, we don't want the results of
849 /// this to depend on where the addresses of various SCEV objects happened to
850 /// land in memory.
851 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
852  LoopInfo *LI, DominatorTree &DT) {
853  if (Ops.size() < 2) return; // Noop
854 
855  EquivalenceClasses<const SCEV *> EqCacheSCEV;
856  EquivalenceClasses<const Value *> EqCacheValue;
857 
858  // Whether LHS has provably less complexity than RHS.
859  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
860  auto Complexity =
861  CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
862  return Complexity && *Complexity < 0;
863  };
864  if (Ops.size() == 2) {
865  // This is the common case, which also happens to be trivially simple.
866  // Special case it.
867  const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
868  if (IsLessComplex(RHS, LHS))
869  std::swap(LHS, RHS);
870  return;
871  }
872 
873  // Do the rough sort by complexity.
874  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
875  return IsLessComplex(LHS, RHS);
876  });
877 
878  // Now that we are sorted by complexity, group elements of the same
879  // complexity. Note that this is, at worst, N^2, but the vector is likely to
880  // be extremely short in practice. Note that we take this approach because we
881  // do not want to depend on the addresses of the objects we are grouping.
882  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
883  const SCEV *S = Ops[i];
884  unsigned Complexity = S->getSCEVType();
885 
886  // If there are any objects of the same complexity and same value as this
887  // one, group them.
888  for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
889  if (Ops[j] == S) { // Found a duplicate.
890  // Move it to immediately after i'th element.
891  std::swap(Ops[i+1], Ops[j]);
892  ++i; // no need to rescan it.
893  if (i == e-2) return; // Done!
894  }
895  }
896  }
897 }
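// For example (illustrative), given operands that print as
//   (%a + %b), 5, (%a + %b), %c
// GroupByComplexity sorts the constant first (scConstant has the lowest
// SCEVTypes value) and makes the two identical (%a + %b) expressions
// adjacent:
//   5, (%a + %b), (%a + %b), %c
// so that callers such as getAddExpr can merge duplicates in a single linear
// scan.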
898 
899 /// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
900 /// least HugeExprThreshold nodes).
901 static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
902  return any_of(Ops, [](const SCEV *S) {
903  return S->getExpressionSize() >= HugeExprThreshold;
904  });
905 }
906 
907 //===----------------------------------------------------------------------===//
908 // Simple SCEV method implementations
909 //===----------------------------------------------------------------------===//
910 
911 /// Compute BC(It, K). The result has width W. Assume, K > 0.
912 static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
913  ScalarEvolution &SE,
914  Type *ResultTy) {
915  // Handle the simplest case efficiently.
916  if (K == 1)
917  return SE.getTruncateOrZeroExtend(It, ResultTy);
918 
919  // We are using the following formula for BC(It, K):
920  //
921  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
922  //
923  // Suppose, W is the bitwidth of the return value. We must be prepared for
924  // overflow. Hence, we must assure that the result of our computation is
925  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
926  // safe in modular arithmetic.
927  //
928  // However, this code doesn't use exactly that formula; the formula it uses
929  // is something like the following, where T is the number of factors of 2 in
930  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
931  // exponentiation:
932  //
933  // BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
934  //
935  // This formula is trivially equivalent to the previous formula. However,
936  // this formula can be implemented much more efficiently. The trick is that
937  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
938  // arithmetic. To do exact division in modular arithmetic, all we have
939  // to do is multiply by the inverse. Therefore, this step can be done at
940  // width W.
941  //
942  // The next issue is how to safely do the division by 2^T. The way this
943  // is done is by doing the multiplication step at a width of at least W + T
944  // bits. This way, the bottom W+T bits of the product are accurate. Then,
945  // when we perform the division by 2^T (which is equivalent to a right shift
946  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
947  // truncated out after the division by 2^T.
948  //
949  // In comparison to just directly using the first formula, this technique
950  // is much more efficient; using the first formula requires W * K bits,
951  // but this formula needs less than W + K bits. Also, the first formula requires
952  // a division step, whereas this formula only requires multiplies and shifts.
953  //
954  // It doesn't matter whether the subtraction step is done in the calculation
955  // width or the input iteration count's width; if the subtraction overflows,
956  // the result must be zero anyway. We prefer here to do it in the width of
957  // the induction variable because it helps a lot for certain cases; CodeGen
958  // isn't smart enough to ignore the overflow, which leads to much less
959  // efficient code if the width of the subtraction is wider than the native
960  // register width.
961  //
962  // (It's possible to not widen at all by pulling out factors of 2 before
963  // the multiplication; for example, K=2 can be calculated as
964  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
965  // extra arithmetic, so it's not an obvious win, and it gets
966  // much more complicated for K > 3.)
967 
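  // Worked example (illustrative): W = 8, K = 3, It = 5, i.e. BC(5, 3) = 10.
  //   K! = 6 = 2^1 * 3            =>  T = 1, OddFactorial = 3
  //   CalculationBits = W + T = 9
  //   Dividend = 5 * 4 * 3 = 60   (computed at width 9)
  //   60 / 2^T = 30, truncated to 8 bits
  //   the inverse of 3 mod 2^8 is 171, and 30 * 171 = 5130 == 10 (mod 256),
  // which is the expected value of BC(5, 3).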
968  // Protection from insane SCEVs; this bound is conservative,
969  // but it probably doesn't matter.
970  if (K > 1000)
971  return SE.getCouldNotCompute();
972 
973  unsigned W = SE.getTypeSizeInBits(ResultTy);
974 
975  // Calculate K! / 2^T and T; we divide out the factors of two before
976  // multiplying for calculating K! / 2^T to avoid overflow.
977  // Other overflow doesn't matter because we only care about the bottom
978  // W bits of the result.
979  APInt OddFactorial(W, 1);
980  unsigned T = 1;
981  for (unsigned i = 3; i <= K; ++i) {
982  APInt Mult(W, i);
983  unsigned TwoFactors = Mult.countTrailingZeros();
984  T += TwoFactors;
985  Mult.lshrInPlace(TwoFactors);
986  OddFactorial *= Mult;
987  }
988 
989  // We need at least W + T bits for the multiplication step
990  unsigned CalculationBits = W + T;
991 
992  // Calculate 2^T, at width T+W.
993  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);
994 
995  // Calculate the multiplicative inverse of K! / 2^T;
996  // this multiplication factor will perform the exact division by
997  // K! / 2^T.
998  APInt Mod = APInt::getSignedMinValue(W+1);
999  APInt MultiplyFactor = OddFactorial.zext(W+1);
1000  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
1001  MultiplyFactor = MultiplyFactor.trunc(W);
1002 
1003  // Calculate the product, at width T+W
1004  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
1005  CalculationBits);
1006  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
1007  for (unsigned i = 1; i != K; ++i) {
1008  const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
1009  Dividend = SE.getMulExpr(Dividend,
1010  SE.getTruncateOrZeroExtend(S, CalculationTy));
1011  }
1012 
1013  // Divide by 2^T
1014  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));
1015 
1016  // Truncate the result, and divide by K! / 2^T.
1017 
1018  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
1019  SE.getTruncateOrZeroExtend(DivResult, ResultTy));
1020 }
1021 
1022 /// Return the value of this chain of recurrences at the specified iteration
1023 /// number. We can evaluate this recurrence by multiplying each element in the
1024 /// chain by the binomial coefficient corresponding to it. In other words, we
1025 /// can evaluate {A,+,B,+,C,+,D} as:
1026 ///
1027 /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
1028 ///
1029 /// where BC(It, k) stands for binomial coefficient.
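///
/// For instance (illustrative), the recurrence {0,+,1,+,2} evaluates at
/// iteration It to
///   0*BC(It,0) + 1*BC(It,1) + 2*BC(It,2) = It + It*(It-1) = It*It.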
1030 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
1031  ScalarEvolution &SE) const {
1032  return evaluateAtIteration(makeArrayRef(op_begin(), op_end()), It, SE);
1033 }
1034 
1035 const SCEV *
1036 SCEVAddRecExpr::evaluateAtIteration(ArrayRef<const SCEV *> Operands,
1037  const SCEV *It, ScalarEvolution &SE) {
1038  assert(Operands.size() > 0);
1039  const SCEV *Result = Operands[0];
1040  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
1041  // The computation is correct in the face of overflow provided that the
1042  // multiplication is performed _after_ the evaluation of the binomial
1043  // coefficient.
1044  const SCEV *Coeff = BinomialCoefficient(It, i, SE, Result->getType());
1045  if (isa<SCEVCouldNotCompute>(Coeff))
1046  return Coeff;
1047 
1048  Result = SE.getAddExpr(Result, SE.getMulExpr(Operands[i], Coeff));
1049  }
1050  return Result;
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // SCEV Expression folder implementations
1055 //===----------------------------------------------------------------------===//
1056 
1057 const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
1058  unsigned Depth) {
1059  assert(Depth <= 1 &&
1060  "getLosslessPtrToIntExpr() should self-recurse at most once.");
1061 
1062  // We could be called with an integer-typed operand during SCEV rewrites.
1063  // Since the operand is an integer already, just perform zext/trunc/self cast.
1064  if (!Op->getType()->isPointerTy())
1065  return Op;
1066 
1067  // What would be an ID for such a SCEV cast expression?
1068  FoldingSetNodeID ID;
1069  ID.AddInteger(scPtrToInt);
1070  ID.AddPointer(Op);
1071 
1072  void *IP = nullptr;
1073 
1074  // Is there already an expression for such a cast?
1075  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1076  return S;
1077 
1078  // It isn't legal for optimizations to construct new ptrtoint expressions
1079  // for non-integral pointers.
1080  if (getDataLayout().isNonIntegralPointerType(Op->getType()))
1081  return getCouldNotCompute();
1082 
1083  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
1084 
1085  // We can only trivially model ptrtoint if SCEV's effective (integer) type
1086  // is sufficiently wide to represent all possible pointer values.
1087  // We could theoretically teach SCEV to truncate wider pointers, but
1088  // that isn't implemented for now.
1089  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
1090  getDataLayout().getTypeSizeInBits(IntPtrTy))
1091  return getCouldNotCompute();
1092 
1093  // If not, is this expression something we can't reduce any further?
1094  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
1095  // Perform some basic constant folding. If the operand of the ptr2int cast
1096  // is a null pointer, don't create a ptr2int SCEV expression (that will be
1097  // left as-is), but produce a zero constant.
1098  // NOTE: We could handle a more general case, but lack motivational cases.
1099  if (isa<ConstantPointerNull>(U->getValue()))
1100  return getZero(IntPtrTy);
1101 
1102  // Create an explicit cast node.
1103  // We can reuse the existing insert position since if we get here,
1104  // we won't have made any changes which would invalidate it.
1105  SCEV *S = new (SCEVAllocator)
1106  SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
1107  UniqueSCEVs.InsertNode(S, IP);
1108  registerUser(S, Op);
1109  return S;
1110  }
1111 
1112  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
1113  "non-SCEVUnknown's.");
1114 
1115  // Otherwise, we've got some expression that is more complex than just a
1116  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
1117  // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
1118  // only, and the expressions must otherwise be integer-typed.
1119  // So sink the cast down to the SCEVUnknown's.
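  //
  // For example (illustrative), for a pointer-typed expression such as
  //   (%base + (8 * %i))
  // the rewriter below produces the integer-typed expression
  //   ((8 * %i) + (ptrtoint i8* %base to i64))
  // where the ptrtoint wraps only the SCEVUnknown %base (assuming an i64
  // index type here).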
1120 
1121  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
1122  /// which computes a pointer-typed value, and rewrites the whole expression
1123  /// tree so that *all* the computations are done on integers, and the only
1124  /// pointer-typed operands in the expression are SCEVUnknown.
1125  class SCEVPtrToIntSinkingRewriter
1126  : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
1127  using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
1128 
1129  public:
1130  SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
1131 
1132  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
1133  SCEVPtrToIntSinkingRewriter Rewriter(SE);
1134  return Rewriter.visit(Scev);
1135  }
1136 
1137  const SCEV *visit(const SCEV *S) {
1138  Type *STy = S->getType();
1139  // If the expression is not pointer-typed, just keep it as-is.
1140  if (!STy->isPointerTy())
1141  return S;
1142  // Else, recursively sink the cast down into it.
1143  return Base::visit(S);
1144  }
1145 
1146  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
1147  SmallVector<const SCEV *, 2> Operands;
1148  bool Changed = false;
1149  for (auto *Op : Expr->operands()) {
1150  Operands.push_back(visit(Op));
1151  Changed |= Op != Operands.back();
1152  }
1153  return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
1154  }
1155 
1156  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
1157  SmallVector<const SCEV *, 2> Operands;
1158  bool Changed = false;
1159  for (auto *Op : Expr->operands()) {
1160  Operands.push_back(visit(Op));
1161  Changed |= Op != Operands.back();
1162  }
1163  return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
1164  }
1165 
1166  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
1167  assert(Expr->getType()->isPointerTy() &&
1168  "Should only reach pointer-typed SCEVUnknown's.");
1169  return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
1170  }
1171  };
1172 
1173  // And actually perform the cast sinking.
1174  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
1175  assert(IntOp->getType()->isIntegerTy() &&
1176  "We must have succeeded in sinking the cast, "
1177  "and ending up with an integer-typed expression!");
1178  return IntOp;
1179 }
1180 
1181 const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
1182  assert(Ty->isIntegerTy() && "Target type must be an integer type!");
1183 
1184  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
1185  if (isa<SCEVCouldNotCompute>(IntOp))
1186  return IntOp;
1187 
1188  return getTruncateOrZeroExtend(IntOp, Ty);
1189 }
1190 
1191 const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
1192  unsigned Depth) {
1193  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
1194  "This is not a truncating conversion!");
1195  assert(isSCEVable(Ty) &&
1196  "This is not a conversion to a SCEVable type!");
1197  assert(!Op->getType()->isPointerTy() && "Can't truncate pointer!");
1198  Ty = getEffectiveSCEVType(Ty);
1199 
1200  FoldingSetNodeID ID;
1201  ID.AddInteger(scTruncate);
1202  ID.AddPointer(Op);
1203  ID.AddPointer(Ty);
1204  void *IP = nullptr;
1205  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1206 
1207  // Fold if the operand is constant.
1208  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1209  return getConstant(
1210  cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));
1211 
1212  // trunc(trunc(x)) --> trunc(x)
1213  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
1214  return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);
1215 
1216  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
1217  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1218  return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);
1219 
1220  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
1221  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1222  return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);
1223 
1224  if (Depth > MaxCastDepth) {
1225  SCEV *S =
1226  new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
1227  UniqueSCEVs.InsertNode(S, IP);
1228  registerUser(S, Op);
1229  return S;
1230  }
1231 
1232  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
1233  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
1234  // if after transforming we have at most one truncate, not counting truncates
1235  // that replace other casts.
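  // For example (illustrative), trunc((zext %a) + (zext %b)) to i32 becomes
  // (trunc(zext %a)) + (trunc(zext %b)); assuming %a and %b are narrower than
  // i32, both inner casts fold into plain zero-extends, so no new truncate
  // survives and the distribution below is performed.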
1236  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
1237  auto *CommOp = cast<SCEVCommutativeExpr>(Op);
1238  SmallVector<const SCEV *, 4> Operands;
1239  unsigned numTruncs = 0;
1240  for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
1241  ++i) {
1242  const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
1243  if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
1244  isa<SCEVTruncateExpr>(S))
1245  numTruncs++;
1246  Operands.push_back(S);
1247  }
1248  if (numTruncs < 2) {
1249  if (isa<SCEVAddExpr>(Op))
1250  return getAddExpr(Operands);
1251  else if (isa<SCEVMulExpr>(Op))
1252  return getMulExpr(Operands);
1253  else
1254  llvm_unreachable("Unexpected SCEV type for Op.");
1255  }
1256  // Although we checked at the beginning that ID is not in the cache, it is
1257  // possible that ID was inserted into the cache by the recursive calls and
1258  // modifications above. So if we find it, just return it.
1259  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
1260  return S;
1261  }
1262 
1263  // If the input value is a chrec scev, truncate the chrec's operands.
1264  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
1265  SmallVector<const SCEV *, 4> Operands;
1266  for (const SCEV *Op : AddRec->operands())
1267  Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
1268  return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
1269  }
1270 
1271  // Return zero if truncating to known zeros.
1272  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
1273  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
1274  return getZero(Ty);
1275 
1276  // The cast wasn't folded; create an explicit cast node. We can reuse
1277  // the existing insert position since if we get here, we won't have
1278  // made any changes which would invalidate it.
1279  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
1280  Op, Ty);
1281  UniqueSCEVs.InsertNode(S, IP);
1282  registerUser(S, Op);
1283  return S;
1284 }
1285 
1286 // Get the limit of a recurrence such that incrementing by Step cannot cause
1287 // signed overflow as long as the value of the recurrence within the
1288 // loop does not exceed this limit before incrementing.
1289 static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
1290  ICmpInst::Predicate *Pred,
1291  ScalarEvolution *SE) {
1292  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1293  if (SE->isKnownPositive(Step)) {
1294  *Pred = ICmpInst::ICMP_SLT;
1295  return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
1296  SE->getSignedRangeMax(Step));
1297  }
1298  if (SE->isKnownNegative(Step)) {
1299  *Pred = ICmpInst::ICMP_SGT;
1300  return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
1301  SE->getSignedRangeMin(Step));
1302  }
1303  return nullptr;
1304 }
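// For instance (illustrative), for an i8 recurrence whose step is known to lie
// in [1, 4], this returns SINT_MAX - 4 = 123 with predicate ICMP_SLT: as long
// as the current value is signed-less-than 123, adding at most 4 cannot exceed
// 127 and therefore cannot signed-overflow.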
1305 
1306 // Get the limit of a recurrence such that incrementing by Step cannot cause
1307 // unsigned overflow as long as the value of the recurrence within the loop does
1308 // not exceed this limit before incrementing.
1309 static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
1310  ICmpInst::Predicate *Pred,
1311  ScalarEvolution *SE) {
1312  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
1313  *Pred = ICmpInst::ICMP_ULT;
1314 
1315  return SE->getConstant(APInt::getMinValue(BitWidth) -
1316  SE->getUnsignedRangeMax(Step));
1317 }
1318 
1319 namespace {
1320 
1321 struct ExtendOpTraitsBase {
1322  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
1323  unsigned);
1324 };
1325 
1326 // Used to make code generic over signed and unsigned overflow.
1327 template <typename ExtendOp> struct ExtendOpTraits {
1328  // Members present:
1329  //
1330  // static const SCEV::NoWrapFlags WrapType;
1331  //
1332  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
1333  //
1334  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1335  // ICmpInst::Predicate *Pred,
1336  // ScalarEvolution *SE);
1337 };
1338 
1339 template <>
1340 struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
1341  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;
1342 
1343  static const GetExtendExprTy GetExtendExpr;
1344 
1345  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1346  ICmpInst::Predicate *Pred,
1347  ScalarEvolution *SE) {
1348  return getSignedOverflowLimitForStep(Step, Pred, SE);
1349  }
1350 };
1351 
1352 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1353  SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;
1354 
1355 template <>
1356 struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
1357  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;
1358 
1359  static const GetExtendExprTy GetExtendExpr;
1360 
1361  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
1362  ICmpInst::Predicate *Pred,
1363  ScalarEvolution *SE) {
1364  return getUnsignedOverflowLimitForStep(Step, Pred, SE);
1365  }
1366 };
1367 
1368 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
1369  SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
1370 
1371 } // end anonymous namespace
1372 
1373 // The recurrence AR has been shown to have no signed/unsigned wrap or something
1374 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
1375 // easily prove NSW/NUW for its preincrement or postincrement sibling. This
1376 // allows normalizing a sign/zero extended AddRec as such: {sext/zext(Step +
1377 // Start),+,Step} => {(Step + sext/zext(Start),+,Step} As a result, the
1378 // expression "Step + sext/zext(PreIncAR)" is congruent with
1379 // "sext/zext(PostIncAR)"
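//
// Illustrative example: if {%n,+,4} is known <nuw> and the backedge is known
// to be taken at least once, then for the post-increment recurrence
// {%n + 4,+,4} we may rewrite zext({%n + 4,+,4}) as {zext(%n) + 4,+,4} in the
// wider type, because %n plus one step of 4 is already known not to wrap.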
1380 template <typename ExtendOpTy>
1381 static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
1382  ScalarEvolution *SE, unsigned Depth) {
1383  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1384  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1385 
1386  const Loop *L = AR->getLoop();
1387  const SCEV *Start = AR->getStart();
1388  const SCEV *Step = AR->getStepRecurrence(*SE);
1389 
1390  // Check for a simple looking step prior to loop entry.
1391  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1392  if (!SA)
1393  return nullptr;
1394 
1395  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
1396  // subtraction is expensive. For this purpose, perform a quick and dirty
1397  // difference, by checking for Step in the operand list.
1398  SmallVector<const SCEV *, 4> DiffOps;
1399  for (const SCEV *Op : SA->operands())
1400  if (Op != Step)
1401  DiffOps.push_back(Op);
1402 
1403  if (DiffOps.size() == SA->getNumOperands())
1404  return nullptr;
1405 
1406  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
1407  // `Step`:
1408 
1409  // 1. NSW/NUW flags on the step increment.
1410  auto PreStartFlags =
1411  ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
1412  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
1413  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1414  SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1415 
1416  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
1417  // "S+X does not sign/unsign-overflow".
1418  //
1419 
1420  const SCEV *BECount = SE->getBackedgeTakenCount(L);
1421  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
1422  !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
1423  return PreStart;
1424 
1425  // 2. Direct overflow check on the step operation's expression.
1426  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1427  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1428  const SCEV *OperandExtendedStart =
1429  SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
1430  (SE->*GetExtendExpr)(Step, WideTy, Depth));
1431  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
1432  if (PreAR && AR->getNoWrapFlags(WrapType)) {
1433  // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
1434  // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
1435  // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
1436  SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
1437  }
1438  return PreStart;
1439  }
1440 
1441  // 3. Loop precondition.
1442  ICmpInst::Predicate Pred;
1443  const SCEV *OverflowLimit =
1444  ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);
1445 
1446  if (OverflowLimit &&
1447  SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
1448  return PreStart;
1449 
1450  return nullptr;
1451 }
1452 
1453 // Get the normalized zero or sign extended expression for this AddRec's Start.
1454 template <typename ExtendOpTy>
1455 static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
1456  ScalarEvolution *SE,
1457  unsigned Depth) {
1458  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;
1459 
1460  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
1461  if (!PreStart)
1462  return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);
1463 
1464  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
1465  Depth),
1466  (SE->*GetExtendExpr)(PreStart, Ty, Depth));
1467 }
1468 
1469 // Try to prove away overflow by looking at "nearby" add recurrences. A
1470 // motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
1471 // does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
1472 //
1473 // Formally:
1474 //
1475 // {S,+,X} == {S-T,+,X} + T
1476 // => Ext({S,+,X}) == Ext({S-T,+,X} + T)
1477 //
1478 // If ({S-T,+,X} + T) does not overflow ... (1)
1479 //
1480 // RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
1481 //
1482 // If {S-T,+,X} does not overflow ... (2)
1483 //
1484 // RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
1485 // == {Ext(S-T)+Ext(T),+,Ext(X)}
1486 //
1487 // If (S-T)+T does not overflow ... (3)
1488 //
1489 // RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
1490 // == {Ext(S),+,Ext(X)} == LHS
1491 //
1492 // Thus, if (1), (2) and (3) are true for some T, then
1493 // Ext({S,+,X}) == {Ext(S),+,Ext(X)}
1494 //
1495 // (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
1496 // does not overflow" restricted to the 0th iteration. Therefore we only need
1497 // to check for (1) and (2).
1498 //
1499 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
1500 // is `Delta` (defined below).
1501 template <typename ExtendOpTy>
1502 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
1503  const SCEV *Step,
1504  const Loop *L) {
1505  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
1506 
1507  // We restrict `Start` to a constant to prevent SCEV from spending too much
1508  // time here. It is correct (but more expensive) to continue with a
1509  // non-constant `Start` and do a general SCEV subtraction to compute
1510  // `PreStart` below.
1511  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
1512  if (!StartC)
1513  return false;
1514 
1515  APInt StartAI = StartC->getAPInt();
1516 
1517  for (unsigned Delta : {-2, -1, 1, 2}) {
1518  const SCEV *PreStart = getConstant(StartAI - Delta);
1519 
1521  ID.AddInteger(scAddRecExpr);
1522  ID.AddPointer(PreStart);
1523  ID.AddPointer(Step);
1524  ID.AddPointer(L);
1525  void *IP = nullptr;
1526  const auto *PreAR =
1527  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1528 
1529  // Give up if we don't already have the add recurrence we need because
1530  // actually constructing an add recurrence is relatively expensive.
1531  if (PreAR && PreAR->getNoWrapFlags(WrapType)) { // proves (2)
1532  const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
1533  ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
1534  const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
1535  DeltaS, &Pred, this);
1536  if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
1537  return true;
1538  }
1539  }
1540 
1541  return false;
1542 }
1543 
1544 // Finds an integer D for an expression (C + x + y + ...) such that the top
1545 // level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
1546 // unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
1547 // maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
1548 // the (C + x + y + ...) expression is \p WholeAddExpr.
1549 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1550  const SCEVConstant *ConstantTerm,
1551  const SCEVAddExpr *WholeAddExpr) {
1552  const APInt &C = ConstantTerm->getAPInt();
1553  const unsigned BitWidth = C.getBitWidth();
1554  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1555  uint32_t TZ = BitWidth;
1556  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1557  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1558  if (TZ) {
1559  // Set D to be as many least significant bits of C as possible while still
1560  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1561  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1562  }
1563  return APInt(BitWidth, 0);
1564 }
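// For example (illustrative): with C = 21 (0b10101) and operands x, y whose
// sum has at least 4 trailing zero bits, D = C mod 2^4 = 5. Then
// (C - D + x + y) = (16 + x + y) still has at least 4 trailing zero bits, and
// re-adding D = 5 only fills known-zero low bits, so the top-level addition
// cannot wrap in either the signed or the unsigned sense.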
1565 
1566 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1567 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1568 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1569 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1570 static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
1571  const APInt &ConstantStart,
1572  const SCEV *Step) {
1573  const unsigned BitWidth = ConstantStart.getBitWidth();
1574  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1575  if (TZ)
1576  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1577  : ConstantStart;
1578  return APInt(BitWidth, 0);
1579 }
1580 
1581 const SCEV *
1582 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1583  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1584  "This is not an extending conversion!");
1585  assert(isSCEVable(Ty) &&
1586  "This is not a conversion to a SCEVable type!");
1587  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1588  Ty = getEffectiveSCEVType(Ty);
1589 
1590  // Fold if the operand is constant.
1591  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1592  return getConstant(
1593  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1594 
1595  // zext(zext(x)) --> zext(x)
1596  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1597  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1598 
1599  // Before doing any expensive analysis, check to see if we've already
1600  // computed a SCEV for this Op and Ty.
1601  FoldingSetNodeID ID;
1602  ID.AddInteger(scZeroExtend);
1603  ID.AddPointer(Op);
1604  ID.AddPointer(Ty);
1605  void *IP = nullptr;
1606  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1607  if (Depth > MaxCastDepth) {
1608  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1609  Op, Ty);
1610  UniqueSCEVs.InsertNode(S, IP);
1611  registerUser(S, Op);
1612  return S;
1613  }
1614 
1615  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1616  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1617  // It's possible the bits taken off by the truncate were all zero bits. If
1618  // so, we should be able to simplify this further.
1619  const SCEV *X = ST->getOperand();
1620  ConstantRange CR = getUnsignedRange(X);
1621  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1622  unsigned NewBits = getTypeSizeInBits(Ty);
1623  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1624  CR.zextOrTrunc(NewBits)))
1625  return getTruncateOrZeroExtend(X, Ty, Depth);
1626  }
1627 
1628  // If the input value is a chrec scev, and we can prove that the value
1629  // did not overflow the old, smaller, value, we can zero extend all of the
1630  // operands (often constants). This allows analysis of something like
1631  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
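  // For instance, once the i8 addrec for X above is known <nuw>, the cast
  // (zext i8 {0,+,1}<nuw> to i32) folds to the i32 addrec {0,+,1}<nuw>
  // instead of remaining an opaque zext of a narrow recurrence.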
1632  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1633  if (AR->isAffine()) {
1634  const SCEV *Start = AR->getStart();
1635  const SCEV *Step = AR->getStepRecurrence(*this);
1636  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1637  const Loop *L = AR->getLoop();
1638 
1639  if (!AR->hasNoUnsignedWrap()) {
1640  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1641  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1642  }
1643 
1644  // If we have special knowledge that this addrec won't overflow,
1645  // we don't need to do any further analysis.
1646  if (AR->hasNoUnsignedWrap())
1647  return getAddRecExpr(
1648  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1649  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1650 
1651  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1652  // Note that this serves two purposes: It filters out loops that are
1653  // simply not analyzable, and it covers the case where this code is
1654  // being called from within backedge-taken count analysis, such that
1655  // attempting to ask for the backedge-taken count would likely result
1656  // in infinite recursion. In the latter case, the analysis code will
1657  // cope with a conservative value, and it will take care to purge
1658  // that value once it has finished.
1659  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1660  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1661  // Manually compute the final value for AR, checking for overflow.
1662 
1663  // Check whether the backedge-taken count can be losslessly casted to
1664  // the addrec's type. The count is always unsigned.
1665  const SCEV *CastedMaxBECount =
1666  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1667  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1668  CastedMaxBECount, MaxBECount->getType(), Depth);
1669  if (MaxBECount == RecastedMaxBECount) {
1670  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1671  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1672  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1673  SCEV::FlagAnyWrap, Depth + 1);
1674  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1675  SCEV::FlagAnyWrap,
1676  Depth + 1),
1677  WideTy, Depth + 1);
1678  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1679  const SCEV *WideMaxBECount =
1680  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1681  const SCEV *OperandExtendedAdd =
1682  getAddExpr(WideStart,
1683  getMulExpr(WideMaxBECount,
1684  getZeroExtendExpr(Step, WideTy, Depth + 1),
1685  SCEV::FlagAnyWrap, Depth + 1),
1686  SCEV::FlagAnyWrap, Depth + 1);
1687  if (ZAdd == OperandExtendedAdd) {
1688  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1689  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1690  // Return the expression with the addrec on the outside.
1691  return getAddRecExpr(
1692  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1693  Depth + 1),
1694  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1695  AR->getNoWrapFlags());
1696  }
1697  // Similar to above, only this time treat the step value as signed.
1698  // This covers loops that count down.
1699  OperandExtendedAdd =
1700  getAddExpr(WideStart,
1701  getMulExpr(WideMaxBECount,
1702  getSignExtendExpr(Step, WideTy, Depth + 1),
1703  SCEV::FlagAnyWrap, Depth + 1),
1704  SCEV::FlagAnyWrap, Depth + 1);
1705  if (ZAdd == OperandExtendedAdd) {
1706  // Cache knowledge of AR NW, which is propagated to this AddRec.
1707  // Negative step causes unsigned wrap, but it still can't self-wrap.
1708  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1709  // Return the expression with the addrec on the outside.
1710  return getAddRecExpr(
1711  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1712  Depth + 1),
1713  getSignExtendExpr(Step, Ty, Depth + 1), L,
1714  AR->getNoWrapFlags());
1715  }
1716  }
1717  }
1718 
1719  // Normally, in the cases we can prove no-overflow via a
1720  // backedge guarding condition, we can also compute a backedge
1721  // taken count for the loop. The exceptions are assumptions and
1722  // guards present in the loop -- SCEV is not great at exploiting
1723  // these to compute max backedge taken counts, but can still use
1724  // these to prove lack of overflow. Use this fact to avoid
1725  // doing extra work that may not pay off.
1726  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1727  !AC.assumptions().empty()) {
1728 
1729  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1730  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1731  if (AR->hasNoUnsignedWrap()) {
1732  // Same as nuw case above - duplicated here to avoid a compile time
1733  // issue. It's not clear whether the order of checks matters, but
1734  // it's one of two possible causes for a change which was reverted.
1735  // Be conservative for the moment.
1736  return getAddRecExpr(
1737  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1738  Depth + 1),
1739  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1740  AR->getNoWrapFlags());
1741  }
1742 
1743  // For a negative step, we can extend the operands iff doing so only
1744  // traverses values in the range zext([0,UINT_MAX]).
1745  if (isKnownNegative(Step)) {
1746  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1747  getSignedRangeMin(Step));
1748  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1749  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1750  // Cache knowledge of AR NW, which is propagated to this
1751  // AddRec. Negative step causes unsigned wrap, but it
1752  // still can't self-wrap.
1753  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1754  // Return the expression with the addrec on the outside.
1755  return getAddRecExpr(
1756  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1757  Depth + 1),
1758  getSignExtendExpr(Step, Ty, Depth + 1), L,
1759  AR->getNoWrapFlags());
1760  }
1761  }
1762  }
1763 
1764  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1765  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1766  // where D maximizes the number of trailing zeros of (C - D + Step * n)
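  // For example (illustrative values), in i8 with {14,+,8}: Step has three
  // trailing zeros, so D = 6 and zext({14,+,8}) becomes
  // (zext(6) + zext({8,+,8}))<nuw><nsw>; the residual addrec only produces
  // multiples of 8, so adding 6 back in cannot carry or wrap.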
1767  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1768  const APInt &C = SC->getAPInt();
1769  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1770  if (D != 0) {
1771  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1772  const SCEV *SResidual =
1773  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1774  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1775  return getAddExpr(SZExtD, SZExtR,
1776  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1777  Depth + 1);
1778  }
1779  }
1780 
1781  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1782  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1783  return getAddRecExpr(
1784  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1785  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1786  }
1787  }
1788 
1789  // zext(A % B) --> zext(A) % zext(B)
1790  {
1791  const SCEV *LHS;
1792  const SCEV *RHS;
1793  if (matchURem(Op, LHS, RHS))
1794  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1795  getZeroExtendExpr(RHS, Ty, Depth + 1));
1796  }
1797 
1798  // zext(A / B) --> zext(A) / zext(B).
1799  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1800  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1801  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1802 
1803  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1804  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1805  if (SA->hasNoUnsignedWrap()) {
1806  // If the addition does not unsign overflow then we can, by definition,
1807  // commute the zero extension with the addition operation.
1808  SmallVector<const SCEV *, 4> Ops;
1809  for (const auto *Op : SA->operands())
1810  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1811  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1812  }
1813 
1814  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1815  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1816  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1817  //
1818  // Often address arithmetics contain expressions like
1819  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1820  // This transformation is useful while proving that such expressions are
1821  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
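    // For the example above, D = 1, so (zext (5 + (4 * X))) becomes
    // 1 + (zext (4 + (4 * X))); likewise (zext (7 + (4 * X))) becomes
    // 3 + (zext (4 + (4 * X))), making their constant difference of 2 explicit.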
1822  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1823  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1824  if (D != 0) {
1825  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1826  const SCEV *SResidual =
1827  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1828  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1829  return getAddExpr(SZExtD, SZExtR,
1830  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1831  Depth + 1);
1832  }
1833  }
1834  }
1835 
1836  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1837  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1838  if (SM->hasNoUnsignedWrap()) {
1839  // If the multiply does not unsign overflow then we can, by definition,
1840  // commute the zero extension with the multiply operation.
1841  SmallVector<const SCEV *, 4> Ops;
1842  for (const auto *Op : SM->operands())
1843  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1844  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1845  }
1846 
1847  // zext(2^K * (trunc X to iN)) to iM ->
1848  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1849  //
1850  // Proof:
1851  //
1852  // zext(2^K * (trunc X to iN)) to iM
1853  // = zext((trunc X to iN) << K) to iM
1854  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1855  // (because shl removes the top K bits)
1856  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1857  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1858  //
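  // For instance, with K = 2, N = 16 and M = 32:
  //   zext(4 * (trunc X to i16)) to i32
  //     == (4 * (zext (trunc X to i14) to i32))<nuw>.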
1859  if (SM->getNumOperands() == 2)
1860  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1861  if (MulLHS->getAPInt().isPowerOf2())
1862  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1863  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1864  MulLHS->getAPInt().logBase2();
1865  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1866  return getMulExpr(
1867  getZeroExtendExpr(MulLHS, Ty),
1868  getZeroExtendExpr(
1869  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1870  SCEV::FlagNUW, Depth + 1);
1871  }
1872  }
1873 
1874  // The cast wasn't folded; create an explicit cast node.
1875  // Recompute the insert position, as it may have been invalidated.
1876  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1877  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1878  Op, Ty);
1879  UniqueSCEVs.InsertNode(S, IP);
1880  registerUser(S, Op);
1881  return S;
1882 }
1883 
1884 const SCEV *
1885 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1886  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1887  "This is not an extending conversion!");
1888  assert(isSCEVable(Ty) &&
1889  "This is not a conversion to a SCEVable type!");
1890  assert(!Op->getType()->isPointerTy() && "Can't extend pointer!");
1891  Ty = getEffectiveSCEVType(Ty);
1892 
1893  // Fold if the operand is constant.
1894  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1895  return getConstant(
1896  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1897 
1898  // sext(sext(x)) --> sext(x)
1899  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1900  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1901 
1902  // sext(zext(x)) --> zext(x)
1903  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1904  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1905 
1906  // Before doing any expensive analysis, check to see if we've already
1907  // computed a SCEV for this Op and Ty.
1908  FoldingSetNodeID ID;
1909  ID.AddInteger(scSignExtend);
1910  ID.AddPointer(Op);
1911  ID.AddPointer(Ty);
1912  void *IP = nullptr;
1913  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1914  // Limit recursion depth.
1915  if (Depth > MaxCastDepth) {
1916  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1917  Op, Ty);
1918  UniqueSCEVs.InsertNode(S, IP);
1919  registerUser(S, Op);
1920  return S;
1921  }
1922 
1923  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1924  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1925  // It's possible the bits taken off by the truncate were all sign bits. If
1926  // so, we should be able to simplify this further.
1927  const SCEV *X = ST->getOperand();
1928  ConstantRange CR = getSignedRange(X);
1929  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1930  unsigned NewBits = getTypeSizeInBits(Ty);
1931  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1932  CR.sextOrTrunc(NewBits)))
1933  return getTruncateOrSignExtend(X, Ty, Depth);
1934  }
1935 
1936  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1937  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1938  if (SA->hasNoSignedWrap()) {
1939  // If the addition does not sign overflow then we can, by definition,
1940  // commute the sign extension with the addition operation.
1941  SmallVector<const SCEV *, 4> Ops;
1942  for (const auto *Op : SA->operands())
1943  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1944  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1945  }
1946 
1947  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1948  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1949  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1950  //
1951  // For instance, this will bring two seemingly different expressions:
1952  // 1 + sext(5 + 20 * %x + 24 * %y) and
1953  // sext(6 + 20 * %x + 24 * %y)
1954  // to the same form:
1955  // 2 + sext(4 + 20 * %x + 24 * %y)
1956  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1957  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1958  if (D != 0) {
1959  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1960  const SCEV *SResidual =
1961  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
1962  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1963  return getAddExpr(SSExtD, SSExtR,
1964  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1965  Depth + 1);
1966  }
1967  }
1968  }
1969  // If the input value is a chrec scev, and we can prove that the value
1970  // did not overflow the old, smaller, value, we can sign extend all of the
1971  // operands (often constants). This allows analysis of something like
1972  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1973  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1974  if (AR->isAffine()) {
1975  const SCEV *Start = AR->getStart();
1976  const SCEV *Step = AR->getStepRecurrence(*this);
1977  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1978  const Loop *L = AR->getLoop();
1979 
1980  if (!AR->hasNoSignedWrap()) {
1981  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1982  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1983  }
1984 
1985  // If we have special knowledge that this addrec won't overflow,
1986  // we don't need to do any further analysis.
1987  if (AR->hasNoSignedWrap())
1988  return getAddRecExpr(
1989  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1990  getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1991 
1992  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1993  // Note that this serves two purposes: It filters out loops that are
1994  // simply not analyzable, and it covers the case where this code is
1995  // being called from within backedge-taken count analysis, such that
1996  // attempting to ask for the backedge-taken count would likely result
1997  // in infinite recursion. In the latter case, the analysis code will
1998  // cope with a conservative value, and it will take care to purge
1999  // that value once it has finished.
2000  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
2001  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
2002  // Manually compute the final value for AR, checking for
2003  // overflow.
2004 
2005  // Check whether the backedge-taken count can be losslessly casted to
2006  // the addrec's type. The count is always unsigned.
2007  const SCEV *CastedMaxBECount =
2008  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
2009  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
2010  CastedMaxBECount, MaxBECount->getType(), Depth);
2011  if (MaxBECount == RecastedMaxBECount) {
2012  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
2013  // Check whether Start+Step*MaxBECount has no signed overflow.
2014  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
2015  SCEV::FlagAnyWrap, Depth + 1);
2016  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2017  SCEV::FlagAnyWrap,
2018  Depth + 1),
2019  WideTy, Depth + 1);
2020  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2021  const SCEV *WideMaxBECount =
2022  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2023  const SCEV *OperandExtendedAdd =
2024  getAddExpr(WideStart,
2025  getMulExpr(WideMaxBECount,
2026  getSignExtendExpr(Step, WideTy, Depth + 1),
2027  SCEV::FlagAnyWrap, Depth + 1),
2028  SCEV::FlagAnyWrap, Depth + 1);
2029  if (SAdd == OperandExtendedAdd) {
2030  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2031  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2032  // Return the expression with the addrec on the outside.
2033  return getAddRecExpr(
2034  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2035  Depth + 1),
2036  getSignExtendExpr(Step, Ty, Depth + 1), L,
2037  AR->getNoWrapFlags());
2038  }
2039  // Similar to above, only this time treat the step value as unsigned.
2040  // This covers loops that count up with an unsigned step.
2041  OperandExtendedAdd =
2042  getAddExpr(WideStart,
2043  getMulExpr(WideMaxBECount,
2044  getZeroExtendExpr(Step, WideTy, Depth + 1),
2045  SCEV::FlagAnyWrap, Depth + 1),
2046  SCEV::FlagAnyWrap, Depth + 1);
2047  if (SAdd == OperandExtendedAdd) {
2048  // If AR wraps around then
2049  //
2050  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2051  // => SAdd != OperandExtendedAdd
2052  //
2053  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2054  // (SAdd == OperandExtendedAdd => AR is NW)
2055 
2056  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2057 
2058  // Return the expression with the addrec on the outside.
2059  return getAddRecExpr(
2060  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2061  Depth + 1),
2062  getZeroExtendExpr(Step, Ty, Depth + 1), L,
2063  AR->getNoWrapFlags());
2064  }
2065  }
2066  }
2067 
2068  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2069  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2070  if (AR->hasNoSignedWrap()) {
2071  // Same as nsw case above - duplicated here to avoid a compile time
2072  // issue. It's not clear whether the order of checks matters, but
2073  // it's one of two possible causes for a change which was reverted.
2074  // Be conservative for the moment.
2075  return getAddRecExpr(
2076  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2077  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2078  }
2079 
2080  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2081  // if D + (C - D + Step * n) could be proven to not signed wrap
2082  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2083  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2084  const APInt &C = SC->getAPInt();
2085  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2086  if (D != 0) {
2087  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2088  const SCEV *SResidual =
2089  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2090  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2091  return getAddExpr(SSExtD, SSExtR,
2092  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2093  Depth + 1);
2094  }
2095  }
2096 
2097  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2098  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2099  return getAddRecExpr(
2100  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2101  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2102  }
2103  }
2104 
2105  // If the input value is provably positive and we could not simplify
2106  // away the sext, build a zext instead.
2107  if (isKnownNonNegative(Op))
2108  return getZeroExtendExpr(Op, Ty, Depth + 1);
2109 
2110  // The cast wasn't folded; create an explicit cast node.
2111  // Recompute the insert position, as it may have been invalidated.
2112  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2113  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2114  Op, Ty);
2115  UniqueSCEVs.InsertNode(S, IP);
2116  registerUser(S, { Op });
2117  return S;
2118 }
2119 
2120 const SCEV *ScalarEvolution::getCastExpr(SCEVTypes Kind, const SCEV *Op,
2121  Type *Ty) {
2122  switch (Kind) {
2123  case scTruncate:
2124  return getTruncateExpr(Op, Ty);
2125  case scZeroExtend:
2126  return getZeroExtendExpr(Op, Ty);
2127  case scSignExtend:
2128  return getSignExtendExpr(Op, Ty);
2129  case scPtrToInt:
2130  return getPtrToIntExpr(Op, Ty);
2131  default:
2132  llvm_unreachable("Not a SCEV cast expression!");
2133  }
2134 }
2135 
2136 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2137 /// unspecified bits out to the given type.
2138 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2139  Type *Ty) {
2140  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2141  "This is not an extending conversion!");
2142  assert(isSCEVable(Ty) &&
2143  "This is not a conversion to a SCEVable type!");
2144  Ty = getEffectiveSCEVType(Ty);
2145 
2146  // Sign-extend negative constants.
2147  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2148  if (SC->getAPInt().isNegative())
2149  return getSignExtendExpr(Op, Ty);
2150 
2151  // Peel off a truncate cast.
2152  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2153  const SCEV *NewOp = T->getOperand();
2154  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2155  return getAnyExtendExpr(NewOp, Ty);
2156  return getTruncateOrNoop(NewOp, Ty);
2157  }
2158 
2159  // Next try a zext cast. If the cast is folded, use it.
2160  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2161  if (!isa<SCEVZeroExtendExpr>(ZExt))
2162  return ZExt;
2163 
2164  // Next try a sext cast. If the cast is folded, use it.
2165  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2166  if (!isa<SCEVSignExtendExpr>(SExt))
2167  return SExt;
2168 
2169  // Force the cast to be folded into the operands of an addrec.
2170  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2171  SmallVector<const SCEV *, 4> Ops;
2172  for (const SCEV *Op : AR->operands())
2173  Ops.push_back(getAnyExtendExpr(Op, Ty));
2174  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2175  }
2176 
2177  // If the expression is obviously signed, use the sext cast value.
2178  if (isa<SCEVSMaxExpr>(Op))
2179  return SExt;
2180 
2181  // Absent any other information, use the zext cast value.
2182  return ZExt;
2183 }
2184 
2185 /// Process the given Ops list, which is a list of operands to be added under
2186 /// the given scale, update the given map. This is a helper function for
2187 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2188 /// that would form an add expression like this:
2189 ///
2190 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2191 ///
2192 /// where A and B are constants, update the map with these values:
2193 ///
2194 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2195 ///
2196 /// and add 13 + A*B*29 to AccumulatedConstant.
2197 /// This will allow getAddRecExpr to produce this:
2198 ///
2199 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2200 ///
2201 /// This form often exposes folding opportunities that are hidden in
2202 /// the original operand list.
2203 ///
2204 /// Return true iff it appears that any interesting folding opportunities
2205 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2206 /// the common case where no interesting opportunities are present, and
2207 /// is also used as a check to avoid infinite recursion.
2208 static bool
2209 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2210  SmallVectorImpl<const SCEV *> &NewOps,
2211  APInt &AccumulatedConstant,
2212  const SCEV *const *Ops, size_t NumOperands,
2213  const APInt &Scale,
2214  ScalarEvolution &SE) {
2215  bool Interesting = false;
2216 
2217  // Iterate over the add operands. They are sorted, with constants first.
2218  unsigned i = 0;
2219  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2220  ++i;
2221  // Pull a buried constant out to the outside.
2222  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2223  Interesting = true;
2224  AccumulatedConstant += Scale * C->getAPInt();
2225  }
2226 
2227  // Next comes everything else. We're especially interested in multiplies
2228  // here, but they're in the middle, so just visit the rest with one loop.
2229  for (; i != NumOperands; ++i) {
2230  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2231  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2232  APInt NewScale =
2233  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2234  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2235  // A multiplication of a constant with another add; recurse.
2236  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2237  Interesting |=
2238  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2239  Add->op_begin(), Add->getNumOperands(),
2240  NewScale, SE);
2241  } else {
2242  // A multiplication of a constant with some other value. Update
2243  // the map.
2244  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2245  const SCEV *Key = SE.getMulExpr(MulOps);
2246  auto Pair = M.insert({Key, NewScale});
2247  if (Pair.second) {
2248  NewOps.push_back(Pair.first->first);
2249  } else {
2250  Pair.first->second += NewScale;
2251  // The map already had an entry for this value, which may indicate
2252  // a folding opportunity.
2253  Interesting = true;
2254  }
2255  }
2256  } else {
2257  // An ordinary operand. Update the map.
2258  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2259  M.insert({Ops[i], Scale});
2260  if (Pair.second) {
2261  NewOps.push_back(Pair.first->first);
2262  } else {
2263  Pair.first->second += Scale;
2264  // The map already had an entry for this value, which may indicate
2265  // a folding opportunity.
2266  Interesting = true;
2267  }
2268  }
2269  }
2270 
2271  return Interesting;
2272 }
2273 
2274 bool ScalarEvolution::willNotOverflow(Instruction::BinaryOps BinOp, bool Signed,
2275  const SCEV *LHS, const SCEV *RHS) {
2276  const SCEV *(ScalarEvolution::*Operation)(const SCEV *, const SCEV *,
2277  SCEV::NoWrapFlags, unsigned);
2278  switch (BinOp) {
2279  default:
2280  llvm_unreachable("Unsupported binary op");
2281  case Instruction::Add:
2282  Operation = &ScalarEvolution::getAddExpr;
2283  break;
2284  case Instruction::Sub:
2285  Operation = &ScalarEvolution::getMinusSCEV;
2286  break;
2287  case Instruction::Mul:
2288  Operation = &ScalarEvolution::getMulExpr;
2289  break;
2290  }
2291 
2292  const SCEV *(ScalarEvolution::*Extension)(const SCEV *, Type *, unsigned) =
2293  Signed ? &ScalarEvolution::getSignExtendExpr
2294  : &ScalarEvolution::getZeroExtendExpr;
2295 
2296  // Check ext(LHS op RHS) == ext(LHS) op ext(RHS)
2297  auto *NarrowTy = cast<IntegerType>(LHS->getType());
2298  auto *WideTy =
2299  IntegerType::get(NarrowTy->getContext(), NarrowTy->getBitWidth() * 2);
2300 
2301  const SCEV *A = (this->*Extension)(
2302  (this->*Operation)(LHS, RHS, SCEV::FlagAnyWrap, 0), WideTy, 0);
2303  const SCEV *B = (this->*Operation)((this->*Extension)(LHS, WideTy, 0),
2304  (this->*Extension)(RHS, WideTy, 0),
2305  SCEV::FlagAnyWrap, 0);
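  // For instance (illustrative constants), for an unsigned i8 add of 100 and
  // 100: A = zext(i8 200) = 200 and B = 100 + 100 = 200 in i16, so A == B and
  // the add cannot wrap; for 200 and 100, A = zext(i8 44) = 44 while B = 300,
  // so no flag can be deduced.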
2306  return A == B;
2307 }
2308 
2309 std::pair<SCEV::NoWrapFlags, bool /*Deduced*/>
2310 ScalarEvolution::getStrengthenedNoWrapFlagsFromBinOp(
2311  const OverflowingBinaryOperator *OBO) {
2312  SCEV::NoWrapFlags Flags = SCEV::NoWrapFlags::FlagAnyWrap;
2313 
2314  if (OBO->hasNoUnsignedWrap())
2315  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2316  if (OBO->hasNoSignedWrap())
2317  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2318 
2319  bool Deduced = false;
2320 
2321  if (OBO->hasNoUnsignedWrap() && OBO->hasNoSignedWrap())
2322  return {Flags, Deduced};
2323 
2324  if (OBO->getOpcode() != Instruction::Add &&
2325  OBO->getOpcode() != Instruction::Sub &&
2326  OBO->getOpcode() != Instruction::Mul)
2327  return {Flags, Deduced};
2328 
2329  const SCEV *LHS = getSCEV(OBO->getOperand(0));
2330  const SCEV *RHS = getSCEV(OBO->getOperand(1));
2331 
2332  if (!OBO->hasNoUnsignedWrap() &&
2333  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2334  /* Signed */ false, LHS, RHS)) {
2335  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2336  Deduced = true;
2337  }
2338 
2339  if (!OBO->hasNoSignedWrap() &&
2340  willNotOverflow((Instruction::BinaryOps)OBO->getOpcode(),
2341  /* Signed */ true, LHS, RHS)) {
2342  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2343  Deduced = true;
2344  }
2345 
2346  return {Flags, Deduced};
2347 }
2348 
2349 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2350 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2351 // can't-overflow flags for the operation if possible.
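// For instance (illustrative values), for an i8 add with a constant operand
// of 10, the guaranteed no-signed-wrap region is [-128, 118); if the signed
// range of the other operand is contained in it (say [0, 51)), the addition
// can be marked <nsw> below even though the caller did not pass that flag.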
2352 static SCEV::NoWrapFlags
2353 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2354  const ArrayRef<const SCEV *> Ops,
2355  SCEV::NoWrapFlags Flags) {
2356  using namespace std::placeholders;
2357 
2358  using OBO = OverflowingBinaryOperator;
2359 
2360  bool CanAnalyze =
2361  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2362  (void)CanAnalyze;
2363  assert(CanAnalyze && "don't call from other places!");
2364 
2365  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2366  SCEV::NoWrapFlags SignOrUnsignWrap =
2367  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2368 
2369  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2370  auto IsKnownNonNegative = [&](const SCEV *S) {
2371  return SE->isKnownNonNegative(S);
2372  };
2373 
2374  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2375  Flags =
2376  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2377 
2378  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2379 
2380  if (SignOrUnsignWrap != SignOrUnsignMask &&
2381  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2382  isa<SCEVConstant>(Ops[0])) {
2383 
2384  auto Opcode = [&] {
2385  switch (Type) {
2386  case scAddExpr:
2387  return Instruction::Add;
2388  case scMulExpr:
2389  return Instruction::Mul;
2390  default:
2391  llvm_unreachable("Unexpected SCEV op.");
2392  }
2393  }();
2394 
2395  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2396 
2397  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2398  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2399  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2400  Opcode, C, OBO::NoSignedWrap);
2401  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2402  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2403  }
2404 
2405  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2406  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2407  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2408  Opcode, C, OBO::NoUnsignedWrap);
2409  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2410  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2411  }
2412  }
2413 
2414  // <0,+,nonnegative><nw> is also nuw
2415  // TODO: Add corresponding nsw case
2416  if (Type == scAddRecExpr && ScalarEvolution::hasFlags(Flags, SCEV::FlagNW) &&
2417  !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) && Ops.size() == 2 &&
2418  Ops[0]->isZero() && IsKnownNonNegative(Ops[1]))
2419  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2420 
2421  // both (udiv X, Y) * Y and Y * (udiv X, Y) are always NUW
2422  if (Type == scMulExpr && !ScalarEvolution::hasFlags(Flags, SCEV::FlagNUW) &&
2423  Ops.size() == 2) {
2424  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[0]))
2425  if (UDiv->getOperand(1) == Ops[1])
2426  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2427  if (auto *UDiv = dyn_cast<SCEVUDivExpr>(Ops[1]))
2428  if (UDiv->getOperand(1) == Ops[0])
2429  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2430  }
2431 
2432  return Flags;
2433 }
2434 
2435 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2436  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2437 }
2438 
2439 /// Get a canonical add expression, or something simpler if possible.
2440 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2441  SCEV::NoWrapFlags OrigFlags,
2442  unsigned Depth) {
2443  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2444  "only nuw or nsw allowed");
2445  assert(!Ops.empty() && "Cannot get empty add!");
2446  if (Ops.size() == 1) return Ops[0];
2447 #ifndef NDEBUG
2448  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2449  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2450  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2451  "SCEVAddExpr operand types don't match!");
2452  unsigned NumPtrs = count_if(
2453  Ops, [](const SCEV *Op) { return Op->getType()->isPointerTy(); });
2454  assert(NumPtrs <= 1 && "add has at most one pointer operand");
2455 #endif
2456 
2457  // Sort by complexity, this groups all similar expression types together.
2458  GroupByComplexity(Ops, &LI, DT);
2459 
2460  // If there are any constants, fold them together.
2461  unsigned Idx = 0;
2462  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2463  ++Idx;
2464  assert(Idx < Ops.size());
2465  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2466  // We found two constants, fold them together!
2467  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2468  if (Ops.size() == 2) return Ops[0];
2469  Ops.erase(Ops.begin()+1); // Erase the folded element
2470  LHSC = cast<SCEVConstant>(Ops[0]);
2471  }
2472 
2473  // If we are left with a constant zero being added, strip it off.
2474  if (LHSC->getValue()->isZero()) {
2475  Ops.erase(Ops.begin());
2476  --Idx;
2477  }
2478 
2479  if (Ops.size() == 1) return Ops[0];
2480  }
2481 
2482  // Delay expensive flag strengthening until necessary.
2483  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2484  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2485  };
2486 
2487  // Limit recursion calls depth.
2488  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2489  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2490 
2491  if (SCEV *S = findExistingSCEVInCache(scAddExpr, Ops)) {
2492  // Don't strengthen flags if we have no new information.
2493  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2494  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2495  Add->setNoWrapFlags(ComputeFlags(Ops));
2496  return S;
2497  }
2498 
2499  // Okay, check to see if the same value occurs in the operand list more than
2500  // once. If so, merge them together into a multiply expression. Since we
2501  // sorted the list, these values are required to be adjacent.
2502  Type *Ty = Ops[0]->getType();
2503  bool FoundMatch = false;
2504  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2505  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2506  // Scan ahead to count how many equal operands there are.
2507  unsigned Count = 2;
2508  while (i+Count != e && Ops[i+Count] == Ops[i])
2509  ++Count;
2510  // Merge the values into a multiply.
2511  const SCEV *Scale = getConstant(Ty, Count);
2512  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2513  if (Ops.size() == Count)
2514  return Mul;
2515  Ops[i] = Mul;
2516  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2517  --i; e -= Count - 1;
2518  FoundMatch = true;
2519  }
2520  if (FoundMatch)
2521  return getAddExpr(Ops, OrigFlags, Depth + 1);
2522 
2523  // Check for truncates. If all the operands are truncated from the same
2524  // type, see if factoring out the truncate would permit the result to be
2525  // folded. E.g., n*trunc(x) + m*trunc(y) --> trunc(ext(n)*x + ext(m)*y)
2526  // if the contents of the resulting outer trunc fold to something simple.
2527  auto FindTruncSrcType = [&]() -> Type * {
2528  // We're ultimately looking to fold an addrec of truncs and muls of only
2529  // constants and truncs, so if we find any other types of SCEV
2530  // as operands of the addrec then we bail and return nullptr here.
2531  // Otherwise, we return the type of the operand of a trunc that we find.
2532  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2533  return T->getOperand()->getType();
2534  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2535  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2536  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2537  return T->getOperand()->getType();
2538  }
2539  return nullptr;
2540  };
2541  if (auto *SrcType = FindTruncSrcType()) {
2542  SmallVector<const SCEV *, 8> LargeOps;
2543  bool Ok = true;
2544  // Check all the operands to see if they can be represented in the
2545  // source type of the truncate.
2546  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2547  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2548  if (T->getOperand()->getType() != SrcType) {
2549  Ok = false;
2550  break;
2551  }
2552  LargeOps.push_back(T->getOperand());
2553  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2554  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2555  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2556  SmallVector<const SCEV *, 8> LargeMulOps;
2557  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2558  if (const SCEVTruncateExpr *T =
2559  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2560  if (T->getOperand()->getType() != SrcType) {
2561  Ok = false;
2562  break;
2563  }
2564  LargeMulOps.push_back(T->getOperand());
2565  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2566  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2567  } else {
2568  Ok = false;
2569  break;
2570  }
2571  }
2572  if (Ok)
2573  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2574  } else {
2575  Ok = false;
2576  break;
2577  }
2578  }
2579  if (Ok) {
2580  // Evaluate the expression in the larger type.
2581  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2582  // If it folds to something simple, use it. Otherwise, don't.
2583  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2584  return getTruncateExpr(Fold, Ty);
2585  }
2586  }
2587 
2588  if (Ops.size() == 2) {
2589  // Check if we have an expression of the form ((X + C1) - C2), where C1 and
2590  // C2 can be folded in a way that allows retaining wrapping flags of (X +
2591  // C1).
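    // For example (illustrative constants), folding (-3) + (X + 5)<nuw><nsw>
    // to (X + 2) can keep both flags: 2 u<= 5, so NUW is preserved, and 2 has
    // the same sign as 5 with a smaller magnitude, so NSW is preserved too.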
2592  const SCEV *A = Ops[0];
2593  const SCEV *B = Ops[1];
2594  auto *AddExpr = dyn_cast<SCEVAddExpr>(B);
2595  auto *C = dyn_cast<SCEVConstant>(A);
2596  if (AddExpr && C && isa<SCEVConstant>(AddExpr->getOperand(0))) {
2597  auto C1 = cast<SCEVConstant>(AddExpr->getOperand(0))->getAPInt();
2598  auto C2 = C->getAPInt();
2599  SCEV::NoWrapFlags PreservedFlags = SCEV::FlagAnyWrap;
2600 
2601  APInt ConstAdd = C1 + C2;
2602  auto AddFlags = AddExpr->getNoWrapFlags();
2603  // Adding a smaller constant is NUW if the original AddExpr was NUW.
2604  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNUW) &&
2605  ConstAdd.ule(C1)) {
2606  PreservedFlags =
2607  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNUW);
2608  }
2609 
2610  // Adding a constant with the same sign and small magnitude is NSW, if the
2611  // original AddExpr was NSW.
2612  if (ScalarEvolution::hasFlags(AddFlags, SCEV::FlagNSW) &&
2613  C1.isSignBitSet() == ConstAdd.isSignBitSet() &&
2614  ConstAdd.abs().ule(C1.abs())) {
2615  PreservedFlags =
2616  ScalarEvolution::setFlags(PreservedFlags, SCEV::FlagNSW);
2617  }
2618 
2619  if (PreservedFlags != SCEV::FlagAnyWrap) {
2620  SmallVector<const SCEV *, 4> NewOps(AddExpr->operands());
2621  NewOps[0] = getConstant(ConstAdd);
2622  return getAddExpr(NewOps, PreservedFlags);
2623  }
2624  }
2625  }
2626 
2627  // Canonicalize (-1 * urem X, Y) + X --> (Y * X/Y)
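  // For example, (-1 * (X urem 4)) + X becomes 4 * (X /u 4).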
2628  if (Ops.size() == 2) {
2629  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[0]);
2630  if (Mul && Mul->getNumOperands() == 2 &&
2631  Mul->getOperand(0)->isAllOnesValue()) {
2632  const SCEV *X;
2633  const SCEV *Y;
2634  if (matchURem(Mul->getOperand(1), X, Y) && X == Ops[1]) {
2635  return getMulExpr(Y, getUDivExpr(X, Y));
2636  }
2637  }
2638  }
2639 
2640  // Skip past any other cast SCEVs.
2641  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2642  ++Idx;
2643 
2644  // If there are add operands they would be next.
2645  if (Idx < Ops.size()) {
2646  bool DeletedAdd = false;
2647  // If the original flags and all inlined SCEVAddExprs are NUW, use the
2648  // common NUW flag for the expression after inlining. Other flags cannot be
2649  // preserved, because they may depend on the original order of operations.
2650  SCEV::NoWrapFlags CommonFlags = maskFlags(OrigFlags, SCEV::FlagNUW);
2651  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2652  if (Ops.size() > AddOpsInlineThreshold ||
2653  Add->getNumOperands() > AddOpsInlineThreshold)
2654  break;
2655  // If we have an add, expand the add operands onto the end of the operands
2656  // list.
2657  Ops.erase(Ops.begin()+Idx);
2658  Ops.append(Add->op_begin(), Add->op_end());
2659  DeletedAdd = true;
2660  CommonFlags = maskFlags(CommonFlags, Add->getNoWrapFlags());
2661  }
2662 
2663  // If we deleted at least one add, we added operands to the end of the list,
2664  // and they are not necessarily sorted. Recurse to resort and resimplify
2665  // any operands we just acquired.
2666  if (DeletedAdd)
2667  return getAddExpr(Ops, CommonFlags, Depth + 1);
2668  }
2669 
2670  // Skip over the add expression until we get to a multiply.
2671  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2672  ++Idx;
2673 
2674  // Check to see if there are any folding opportunities present with
2675  // operands multiplied by constant values.
2676  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2677  uint64_t BitWidth = getTypeSizeInBits(Ty);
2678  DenseMap<const SCEV *, APInt> M;
2679  SmallVector<const SCEV *, 8> NewOps;
2680  APInt AccumulatedConstant(BitWidth, 0);
2681  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2682  Ops.data(), Ops.size(),
2683  APInt(BitWidth, 1), *this)) {
2684  struct APIntCompare {
2685  bool operator()(const APInt &LHS, const APInt &RHS) const {
2686  return LHS.ult(RHS);
2687  }
2688  };
2689 
2690  // Some interesting folding opportunity is present, so it's worthwhile to
2691  // re-generate the operands list. Group the operands by constant scale,
2692  // to avoid multiplying by the same constant scale multiple times.
2693  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2694  for (const SCEV *NewOp : NewOps)
2695  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2696  // Re-generate the operands list.
2697  Ops.clear();
2698  if (AccumulatedConstant != 0)
2699  Ops.push_back(getConstant(AccumulatedConstant));
2700  for (auto &MulOp : MulOpLists) {
2701  if (MulOp.first == 1) {
2702  Ops.push_back(getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1));
2703  } else if (MulOp.first != 0) {
2704  Ops.push_back(getMulExpr(
2705  getConstant(MulOp.first),
2706  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2707  SCEV::FlagAnyWrap, Depth + 1));
2708  }
2709  }
2710  if (Ops.empty())
2711  return getZero(Ty);
2712  if (Ops.size() == 1)
2713  return Ops[0];
2714  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2715  }
2716  }
2717 
2718  // If we are adding something to a multiply expression, make sure the
2719  // something is not already an operand of the multiply. If so, merge it into
2720  // the multiply.
2721  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2722  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2723  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2724  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2725  if (isa<SCEVConstant>(MulOpSCEV))
2726  continue;
2727  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2728  if (MulOpSCEV == Ops[AddOp]) {
2729  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2730  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2731  if (Mul->getNumOperands() != 2) {
2732  // If the multiply has more than two operands, we must get the
2733  // Y*Z term.
2734  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2735  Mul->op_begin()+MulOp);
2736  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2737  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2738  }
2739  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2740  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2741  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2742  SCEV::FlagAnyWrap, Depth + 1);
2743  if (Ops.size() == 2) return OuterMul;
2744  if (AddOp < Idx) {
2745  Ops.erase(Ops.begin()+AddOp);
2746  Ops.erase(Ops.begin()+Idx-1);
2747  } else {
2748  Ops.erase(Ops.begin()+Idx);
2749  Ops.erase(Ops.begin()+AddOp-1);
2750  }
2751  Ops.push_back(OuterMul);
2752  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2753  }
2754 
2755  // Check this multiply against other multiplies being added together.
2756  for (unsigned OtherMulIdx = Idx+1;
2757  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2758  ++OtherMulIdx) {
2759  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2760  // If MulOp occurs in OtherMul, we can fold the two multiplies
2761  // together.
2762  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2763  OMulOp != e; ++OMulOp)
2764  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2765  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2766  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2767  if (Mul->getNumOperands() != 2) {
2768  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2769  Mul->op_begin()+MulOp);
2770  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2771  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2772  }
2773  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2774  if (OtherMul->getNumOperands() != 2) {
2775  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2776  OtherMul->op_begin()+OMulOp);
2777  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2778  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2779  }
2780  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2781  const SCEV *InnerMulSum =
2782  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2783  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2784  SCEV::FlagAnyWrap, Depth + 1);
2785  if (Ops.size() == 2) return OuterMul;
2786  Ops.erase(Ops.begin()+Idx);
2787  Ops.erase(Ops.begin()+OtherMulIdx-1);
2788  Ops.push_back(OuterMul);
2789  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2790  }
2791  }
2792  }
2793  }
2794 
2795  // If there are any add recurrences in the operands list, see if any other
2796  // added values are loop invariant. If so, we can fold them into the
2797  // recurrence.
2798  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2799  ++Idx;
2800 
2801  // Scan over all recurrences, trying to fold loop invariants into them.
2802  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2803  // Scan all of the other operands to this add and add them to the vector if
2804  // they are loop invariant w.r.t. the recurrence.
2805  SmallVector<const SCEV *, 8> LIOps;
2806  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2807  const Loop *AddRecLoop = AddRec->getLoop();
2808  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2809  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2810  LIOps.push_back(Ops[i]);
2811  Ops.erase(Ops.begin()+i);
2812  --i; --e;
2813  }
2814 
2815  // If we found some loop invariants, fold them into the recurrence.
2816  if (!LIOps.empty()) {
2817  // Compute nowrap flags for the addition of the loop-invariant ops and
2818  // the addrec. Temporarily push it as an operand for that purpose. These
2819  // flags are valid in the scope of the addrec only.
2820  LIOps.push_back(AddRec);
2821  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2822  LIOps.pop_back();
2823 
2824  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2825  LIOps.push_back(AddRec->getStart());
2826 
2827  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2828 
2829  // It is not in general safe to propagate flags valid on an add within
2830  // the addrec scope to one outside it. We must prove that the inner
2831  // scope is guaranteed to execute if the outer one does to be able to
2832  // safely propagate. We know the program is undefined if poison is
2833  // produced on the inner scoped addrec. We also know that *for this use*
2834  // the outer scoped add can't overflow (because of the flags we just
2835  // computed for the inner scoped add) without the program being undefined.
2836  // Proving that entry to the outer scope necessitates entry to the inner
2837  // scope, thus proves the program undefined if the flags would be violated
2838  // in the outer scope.
2839  SCEV::NoWrapFlags AddFlags = Flags;
2840  if (AddFlags != SCEV::FlagAnyWrap) {
2841  auto *DefI = getDefiningScopeBound(LIOps);
2842  auto *ReachI = &*AddRecLoop->getHeader()->begin();
2843  if (!isGuaranteedToTransferExecutionTo(DefI, ReachI))
2844  AddFlags = SCEV::FlagAnyWrap;
2845  }
2846  AddRecOps[0] = getAddExpr(LIOps, AddFlags, Depth + 1);
2847 
2848  // Build the new addrec. Propagate the NUW and NSW flags if both the
2849  // outer add and the inner addrec are guaranteed to have no overflow.
2850  // Always propagate NW.
2851  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2852  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2853 
2854  // If all of the other operands were loop invariant, we are done.
2855  if (Ops.size() == 1) return NewRec;
2856 
2857  // Otherwise, add the folded AddRec by the non-invariant parts.
2858  for (unsigned i = 0;; ++i)
2859  if (Ops[i] == AddRec) {
2860  Ops[i] = NewRec;
2861  break;
2862  }
2863  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2864  }
2865 
2866  // Okay, if there weren't any loop invariants to be folded, check to see if
2867  // there are multiple AddRec's with the same loop induction variable being
2868  // added together. If so, we can fold them.
2869  for (unsigned OtherIdx = Idx+1;
2870  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2871  ++OtherIdx) {
2872  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2873  // so that the 1st found AddRecExpr is dominated by all others.
2874  assert(DT.dominates(
2875  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2876  AddRec->getLoop()->getHeader()) &&
2877  "AddRecExprs are not sorted in reverse dominance order?");
2878  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2879  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2880  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2881  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2882  ++OtherIdx) {
2883  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2884  if (OtherAddRec->getLoop() == AddRecLoop) {
2885  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2886  i != e; ++i) {
2887  if (i >= AddRecOps.size()) {
2888  AddRecOps.append(OtherAddRec->op_begin()+i,
2889  OtherAddRec->op_end());
2890  break;
2891  }
2892  SmallVector<const SCEV *, 2> TwoOps = {
2893  AddRecOps[i], OtherAddRec->getOperand(i)};
2894  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2895  }
2896  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2897  }
2898  }
2899  // Step size has changed, so we cannot guarantee no self-wraparound.
2900  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2901  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2902  }
2903  }
2904 
2905  // Otherwise couldn't fold anything into this recurrence. Move onto the
2906  // next one.
2907  }
2908 
2909  // Okay, it looks like we really DO need an add expr. Check to see if we
2910  // already have one, otherwise create a new one.
2911  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2912 }
2913 
2914 const SCEV *
2915 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2916  SCEV::NoWrapFlags Flags) {
2917  FoldingSetNodeID ID;
2918  ID.AddInteger(scAddExpr);
2919  for (const SCEV *Op : Ops)
2920  ID.AddPointer(Op);
2921  void *IP = nullptr;
2922  SCEVAddExpr *S =
2923  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2924  if (!S) {
2925  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2926  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2927  S = new (SCEVAllocator)
2928  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2929  UniqueSCEVs.InsertNode(S, IP);
2930  registerUser(S, Ops);
2931  }
2932  S->setNoWrapFlags(Flags);
2933  return S;
2934 }
2935 
2936 const SCEV *
2937 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2938  const Loop *L, SCEV::NoWrapFlags Flags) {
2939  FoldingSetNodeID ID;
2940  ID.AddInteger(scAddRecExpr);
2941  for (const SCEV *Op : Ops)
2942  ID.AddPointer(Op);
2943  ID.AddPointer(L);
2944  void *IP = nullptr;
2945  SCEVAddRecExpr *S =
2946  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2947  if (!S) {
2948  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2949  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2950  S = new (SCEVAllocator)
2951  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2952  UniqueSCEVs.InsertNode(S, IP);
2953  LoopUsers[L].push_back(S);
2954  registerUser(S, Ops);
2955  }
2956  setNoWrapFlags(S, Flags);
2957  return S;
2958 }
2959 
2960 const SCEV *
2961 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2962  SCEV::NoWrapFlags Flags) {
2963  FoldingSetNodeID ID;
2964  ID.AddInteger(scMulExpr);
2965  for (const SCEV *Op : Ops)
2966  ID.AddPointer(Op);
2967  void *IP = nullptr;
2968  SCEVMulExpr *S =
2969  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2970  if (!S) {
2971  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2972  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2973  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2974  O, Ops.size());
2975  UniqueSCEVs.InsertNode(S, IP);
2976  registerUser(S, Ops);
2977  }
2978  S->setNoWrapFlags(Flags);
2979  return S;
2980 }
2981 
2982 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2983  uint64_t k = i*j;
2984  if (j > 1 && k / j != i) Overflow = true;
2985  return k;
2986 }
2987 
2988 /// Compute the result of "n choose k", the binomial coefficient. If an
2989 /// intermediate computation overflows, Overflow will be set and the return will
2990 /// be garbage. Overflow is not cleared on absence of overflow.
2991 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2992  // We use the multiplicative formula:
2993  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2994  // At each iteration, we multiply by the next term of the numerator and divide
2995  // by the next term of the denominator. This division will always produce an
2996  // integral result, and helps reduce the chance of overflow in the
2997  // intermediate computations. However, we can still overflow even when the
2998  // final result would fit.
2999 
3000  if (n == 0 || n == k) return 1;
3001  if (k > n) return 0;
3002 
3003  if (k > n/2)
3004  k = n-k;
3005 
3006  uint64_t r = 1;
3007  for (uint64_t i = 1; i <= k; ++i) {
3008  r = umul_ov(r, n-(i-1), Overflow);
3009  r /= i;
3010  }
3011  return r;
3012 }
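// Illustration (added for exposition; not in the original source): a quick
// sanity check of the multiplicative formula above. For Choose(5, 2) the loop
// runs i = 1..2:
//   i = 1:  r = 1 * (5 - 0) = 5;   r /= 1  ->  r = 5
//   i = 2:  r = 5 * (5 - 1) = 20;  r /= 2  ->  r = 10
// which matches C(5,2) = 10. Note that Overflow is sticky: umul_ov() only ever
// sets it, so callers must initialize it to false and test it after the last
// call, exactly as the AddRec-multiplication code below does.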
3013 
3014 /// Determine if any of the operands in this SCEV are a constant or if
3015 /// any of the add or multiply expressions in this SCEV contain a constant.
3016 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
3017  struct FindConstantInAddMulChain {
3018  bool FoundConstant = false;
3019 
3020  bool follow(const SCEV *S) {
3021  FoundConstant |= isa<SCEVConstant>(S);
3022  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
3023  }
3024 
3025  bool isDone() const {
3026  return FoundConstant;
3027  }
3028  };
3029 
3030  FindConstantInAddMulChain F;
3031  SCEVTraversal<FindConstantInAddMulChain> ST(F);
3032  ST.visitAll(StartExpr);
3033  return F.FoundConstant;
3034 }
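// Illustration (added for exposition; not in the original source):
// containsConstantInAddMulChain((2 + %x) * %y) returns true, because the walk
// descends through the outer mul into the inner add and finds the SCEVConstant
// 2. containsConstantInAddMulChain(%x * zext(3 + %y)) returns false: the
// traversal stops at the zext, since follow() only recurses into add and mul
// nodes.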
3035 
3036 /// Get a canonical multiply expression, or something simpler if possible.
3037 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
3038  SCEV::NoWrapFlags OrigFlags,
3039  unsigned Depth) {
3040  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
3041  "only nuw or nsw allowed");
3042  assert(!Ops.empty() && "Cannot get empty mul!");
3043  if (Ops.size() == 1) return Ops[0];
3044 #ifndef NDEBUG
3045  Type *ETy = Ops[0]->getType();
3046  assert(!ETy->isPointerTy());
3047  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3048  assert(Ops[i]->getType() == ETy &&
3049  "SCEVMulExpr operand types don't match!");
3050 #endif
3051 
3052  // Sort by complexity, this groups all similar expression types together.
3053  GroupByComplexity(Ops, &LI, DT);
3054 
3055  // If there are any constants, fold them together.
3056  unsigned Idx = 0;
3057  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3058  ++Idx;
3059  assert(Idx < Ops.size());
3060  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3061  // We found two constants, fold them together!
3062  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
3063  if (Ops.size() == 2) return Ops[0];
3064  Ops.erase(Ops.begin()+1); // Erase the folded element
3065  LHSC = cast<SCEVConstant>(Ops[0]);
3066  }
3067 
3068  // If we have a multiply of zero, it will always be zero.
3069  if (LHSC->getValue()->isZero())
3070  return LHSC;
3071 
3072  // If we are left with a constant one being multiplied, strip it off.
3073  if (LHSC->getValue()->isOne()) {
3074  Ops.erase(Ops.begin());
3075  --Idx;
3076  }
3077 
3078  if (Ops.size() == 1)
3079  return Ops[0];
3080  }
3081 
3082  // Delay expensive flag strengthening until necessary.
3083  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
3084  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
3085  };
3086 
3087  // Limit recursion depth.
3088  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
3089  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3090 
3091  if (SCEV *S = findExistingSCEVInCache(scMulExpr, Ops)) {
3092  // Don't strengthen flags if we have no new information.
3093  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
3094  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
3095  Mul->setNoWrapFlags(ComputeFlags(Ops));
3096  return S;
3097  }
3098 
3099  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3100  if (Ops.size() == 2) {
3101  // C1*(C2+V) -> C1*C2 + C1*V
3102  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
3103  // If any of Add's ops are Adds or Muls with a constant, apply this
3104  // transformation as well.
3105  //
3106  // TODO: There are some cases where this transformation is not
3107  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
3108  // this transformation should be narrowed down.
3109  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
3110  return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
3111  SCEV::FlagAnyWrap, Depth + 1),
3112  getMulExpr(LHSC, Add->getOperand(1),
3113  SCEV::FlagAnyWrap, Depth + 1),
3114  SCEV::FlagAnyWrap, Depth + 1);
3115 
3116  if (Ops[0]->isAllOnesValue()) {
3117  // If we have a mul by -1 of an add, try distributing the -1 among the
3118  // add operands.
3119  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
3120  SmallVector<const SCEV *, 4> NewOps;
3121  bool AnyFolded = false;
3122  for (const SCEV *AddOp : Add->operands()) {
3123  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
3124  Depth + 1);
3125  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
3126  NewOps.push_back(Mul);
3127  }
3128  if (AnyFolded)
3129  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
3130  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
3131  // Negation preserves a recurrence's no self-wrap property.
3132  SmallVector<const SCEV *, 4> Operands;
3133  for (const SCEV *AddRecOp : AddRec->operands())
3134  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
3135  Depth + 1));
3136 
3137  return getAddRecExpr(Operands, AddRec->getLoop(),
3138  AddRec->getNoWrapFlags(SCEV::FlagNW));
3139  }
3140  }
3141  }
3142  }
3143 
3144  // Skip over the add expression until we get to a multiply.
3145  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
3146  ++Idx;
3147 
3148  // If there are mul operands inline them all into this expression.
3149  if (Idx < Ops.size()) {
3150  bool DeletedMul = false;
3151  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
3152  if (Ops.size() > MulOpsInlineThreshold)
3153  break;
3154  // If we have a mul, expand the mul operands onto the end of the
3155  // operands list.
3156  Ops.erase(Ops.begin()+Idx);
3157  Ops.append(Mul->op_begin(), Mul->op_end());
3158  DeletedMul = true;
3159  }
3160 
3161  // If we deleted at least one mul, we added operands to the end of the
3162  // list, and they are not necessarily sorted. Recurse to resort and
3163  // resimplify any operands we just acquired.
3164  if (DeletedMul)
3165  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3166  }
3167 
3168  // If there are any add recurrences in the operands list, see if any other
3169  // added values are loop invariant. If so, we can fold them into the
3170  // recurrence.
3171  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
3172  ++Idx;
3173 
3174  // Scan over all recurrences, trying to fold loop invariants into them.
3175  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
3176  // Scan all of the other operands to this mul and add them to the vector
3177  // if they are loop invariant w.r.t. the recurrence.
3178  SmallVector<const SCEV *, 8> LIOps;
3179  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
3180  const Loop *AddRecLoop = AddRec->getLoop();
3181  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3182  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
3183  LIOps.push_back(Ops[i]);
3184  Ops.erase(Ops.begin()+i);
3185  --i; --e;
3186  }
3187 
3188  // If we found some loop invariants, fold them into the recurrence.
3189  if (!LIOps.empty()) {
3190  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
3191  SmallVector<const SCEV *, 4> NewOps;
3192  NewOps.reserve(AddRec->getNumOperands());
3193  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
3194  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
3195  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
3196  SCEV::FlagAnyWrap, Depth + 1));
3197 
3198  // Build the new addrec. Propagate the NUW and NSW flags if both the
3199  // outer mul and the inner addrec are guaranteed to have no overflow.
3200  //
3201  // No self-wrap cannot be guaranteed after changing the step size, but
3202  // will be inferred if either NUW or NSW is true.
3203  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
3204  const SCEV *NewRec = getAddRecExpr(
3205  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
3206 
3207  // If all of the other operands were loop invariant, we are done.
3208  if (Ops.size() == 1) return NewRec;
3209 
3210  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3211  for (unsigned i = 0;; ++i)
3212  if (Ops[i] == AddRec) {
3213  Ops[i] = NewRec;
3214  break;
3215  }
3216  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3217  }
3218 
3219  // Okay, if there weren't any loop invariants to be folded, check to see
3220  // if there are multiple AddRec's with the same loop induction variable
3221  // being multiplied together. If so, we can fold them.
3222 
3223  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3224  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3225  //   choose(x, 2x-y)*choose(2x-y, x-z)*A_{y-z}*B_z
3226  // ]]],+,...up to x=2n}.
3227  // Note that the arguments to choose() are always integers with values
3228  // known at compile time, never SCEV objects.
3229  //
3230  // The implementation avoids pointless extra computations when the two
3231  // addrec's are of different length (mathematically, it's equivalent to
3232  // an infinite stream of zeros on the right).
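  // Illustration (added for exposition; not in the original source): for two
  // affine recurrences over the same loop, X = {a,+,b} and Y = {c,+,d}, the
  // values at iteration i are X(i) = a + b*i and Y(i) = c + d*i, so
  //   X(i)*Y(i) = a*c + (a*d + b*c)*i + b*d*i^2.
  // Rewritten with binomial coefficients (i^2 = C(i,1) + 2*C(i,2)), this is
  // the quadratic recurrence
  //   {a*c, +, a*d + b*c + b*d, +, 2*b*d},
  // which is exactly what the triple loop below produces for the 2x2 case.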
3233  bool OpsModified = false;
3234  for (unsigned OtherIdx = Idx+1;
3235  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3236  ++OtherIdx) {
3237  const SCEVAddRecExpr *OtherAddRec =
3238  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3239  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3240  continue;
3241 
3242  // Limit max number of arguments to avoid creation of unreasonably big
3243  // SCEVAddRecs with very complex operands.
3244  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3245  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3246  continue;
3247 
3248  bool Overflow = false;
3249  Type *Ty = AddRec->getType();
3250  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3251  SmallVector<const SCEV*, 7> AddRecOps;
3252  for (int x = 0, xe = AddRec->getNumOperands() +
3253  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
3254  SmallVector<const SCEV *, 7> SumOps;
3255  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3256  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3257  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3258  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3259  z < ze && !Overflow; ++z) {
3260  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3261  uint64_t Coeff;
3262  if (LargerThan64Bits)
3263  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3264  else
3265  Coeff = Coeff1*Coeff2;
3266  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3267  const SCEV *Term1 = AddRec->getOperand(y-z);
3268  const SCEV *Term2 = OtherAddRec->getOperand(z);
3269  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3270  SCEV::FlagAnyWrap, Depth + 1));
3271  }
3272  }
3273  if (SumOps.empty())
3274  SumOps.push_back(getZero(Ty));
3275  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3276  }
3277  if (!Overflow) {
3278  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
3279  SCEV::FlagAnyWrap);
3280  if (Ops.size() == 2) return NewAddRec;
3281  Ops[Idx] = NewAddRec;
3282  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3283  OpsModified = true;
3284  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3285  if (!AddRec)
3286  break;
3287  }
3288  }
3289  if (OpsModified)
3290  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3291 
3292  // Otherwise couldn't fold anything into this recurrence. Move onto the
3293  // next one.
3294  }
3295 
3296  // Okay, it looks like we really DO need a mul expr. Check to see if we
3297  // already have one, otherwise create a new one.
3298  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3299 }
3300 
3301 /// Represents an unsigned remainder expression based on unsigned division.
3302 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
3303  const SCEV *RHS) {
3304  assert(getEffectiveSCEVType(LHS->getType()) ==
3305  getEffectiveSCEVType(RHS->getType()) &&
3306  "SCEVURemExpr operand types don't match!");
3307 
3308  // Short-circuit easy cases
3309  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3310  // If constant is one, the result is trivial
3311  if (RHSC->getValue()->isOne())
3312  return getZero(LHS->getType()); // X urem 1 --> 0
3313 
3314  // If constant is a power of two, fold into a zext(trunc(LHS)).
3315  if (RHSC->getAPInt().isPowerOf2()) {
3316  Type *FullTy = LHS->getType();
3317  Type *TruncTy =
3318  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3319  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3320  }
3321  }
3322 
3323  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3324  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3325  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3326  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3327 }
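// Illustration (added for exposition; not in the original source), assuming a
// 32-bit %x:
//   %x urem 1  -->  0
//   %x urem 8  -->  (zext i3 (trunc i32 %x to i3) to i32)   (power of two)
//   %x urem 7  -->  ((-7 * (%x /u 7)) + %x)                 (generic fallback)
// The last form is just the identity x - (x /u y) * y spelled with SCEV's
// add/mul/udiv nodes; the printed operand order may differ after
// canonicalization.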
3328 
3329 /// Get a canonical unsigned division expression, or something simpler if
3330 /// possible.
3331 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
3332  const SCEV *RHS) {
3333  assert(!LHS->getType()->isPointerTy() &&
3334  "SCEVUDivExpr operand can't be pointer!");
3335  assert(LHS->getType() == RHS->getType() &&
3336  "SCEVUDivExpr operand types don't match!");
3337 
3338  FoldingSetNodeID ID;
3339  ID.AddInteger(scUDivExpr);
3340  ID.AddPointer(LHS);
3341  ID.AddPointer(RHS);
3342  void *IP = nullptr;
3343  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3344  return S;
3345 
3346  // 0 udiv Y == 0
3347  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS))
3348  if (LHSC->getValue()->isZero())
3349  return LHS;
3350 
3351  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3352  if (RHSC->getValue()->isOne())
3353  return LHS; // X udiv 1 --> x
3354  // If the denominator is zero, the result of the udiv is undefined. Don't
3355  // try to analyze it, because the resolution chosen here may differ from
3356  // the resolution chosen in other parts of the compiler.
3357  if (!RHSC->getValue()->isZero()) {
3358  // Determine if the division can be folded into the operands of
3359  // the LHS.
3360  // TODO: Generalize this to non-constants by using known-bits information.
3361  Type *Ty = LHS->getType();
3362  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3363  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3364  // For non-power-of-two values, effectively round the value up to the
3365  // nearest power of two.
3366  if (!RHSC->getAPInt().isPowerOf2())
3367  ++MaxShiftAmt;
3368  IntegerType *ExtTy =
3369  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3370  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3371  if (const SCEVConstant *Step =
3372  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3373  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3374  const APInt &StepInt = Step->getAPInt();
3375  const APInt &DivInt = RHSC->getAPInt();
3376  if (!StepInt.urem(DivInt) &&
3377  getZeroExtendExpr(AR, ExtTy) ==
3378  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3379  getZeroExtendExpr(Step, ExtTy),
3380  AR->getLoop(), SCEV::FlagAnyWrap)) {
3381  SmallVector<const SCEV *, 4> Operands;
3382  for (const SCEV *Op : AR->operands())
3383  Operands.push_back(getUDivExpr(Op, RHS));
3384  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3385  }
3386  /// Get a canonical UDivExpr for a recurrence.
3387  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3388  // We can currently only fold X%N if X is constant.
3389  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3390  if (StartC && !DivInt.urem(StepInt) &&
3391  getZeroExtendExpr(AR, ExtTy) ==
3392  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3393  getZeroExtendExpr(Step, ExtTy),
3394  AR->getLoop(), SCEV::FlagAnyWrap)) {
3395  const APInt &StartInt = StartC->getAPInt();
3396  const APInt &StartRem = StartInt.urem(StepInt);
3397  if (StartRem != 0) {
3398  const SCEV *NewLHS =
3399  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3400  AR->getLoop(), SCEV::FlagNW);
3401  if (LHS != NewLHS) {
3402  LHS = NewLHS;
3403 
3404  // Reset the ID to include the new LHS, and check if it is
3405  // already cached.
3406  ID.clear();
3407  ID.AddInteger(scUDivExpr);
3408  ID.AddPointer(LHS);
3409  ID.AddPointer(RHS);
3410  IP = nullptr;
3411  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3412  return S;
3413  }
3414  }
3415  }
3416  }
3417  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3418  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
3419  SmallVector<const SCEV *, 4> Operands;
3420  for (const SCEV *Op : M->operands())
3421  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3422  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3423  // Find an operand that's safely divisible.
3424  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3425  const SCEV *Op = M->getOperand(i);
3426  const SCEV *Div = getUDivExpr(Op, RHSC);
3427  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3428  Operands = SmallVector<const SCEV *, 4>(M->operands());
3429  Operands[i] = Div;
3430  return getMulExpr(Operands);
3431  }
3432  }
3433  }
3434 
3435  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3436  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3437  if (auto *DivisorConstant =
3438  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3439  bool Overflow = false;
3440  APInt NewRHS =
3441  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3442  if (Overflow) {
3443  return getConstant(RHSC->getType(), 0, false);
3444  }
3445  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3446  }
3447  }
3448 
3449  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3450  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
3451  SmallVector<const SCEV *, 4> Operands;
3452  for (const SCEV *Op : A->operands())
3453  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3454  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3455  Operands.clear();
3456  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3457  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3458  if (isa<SCEVUDivExpr>(Op) ||
3459  getMulExpr(Op, RHS) != A->getOperand(i))
3460  break;
3461  Operands.push_back(Op);
3462  }
3463  if (Operands.size() == A->getNumOperands())
3464  return getAddExpr(Operands);
3465  }
3466  }
3467 
3468  // Fold if both operands are constant.
3469  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3470  Constant *LHSCV = LHSC->getValue();
3471  Constant *RHSCV = RHSC->getValue();
3472  return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3473  RHSCV)));
3474  }
3475  }
3476  }
3477 
3478  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3479  // changes). Make sure we get a new one.
3480  IP = nullptr;
3481  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3482  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3483  LHS, RHS);
3484  UniqueSCEVs.InsertNode(S, IP);
3485  registerUser(S, {LHS, RHS});
3486  return S;
3487 }
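// Illustration (added for exposition; not in the original source): the kind of
// folds the constant-RHS path above performs, assuming the no-overflow checks
// (comparison against the zero-extended form) succeed:
//   {0,+,4}<%loop> /u 2  -->  {0,+,2}<%loop>   ({X,+,N}/C with N/C, X/C foldable)
//   (8 * %x) /u 4        -->  2 * %x           ((A*B)/C with B/C foldable)
//   (%x /u 3) /u 5       -->  %x /u 15         ((A/B)/C, B*C does not overflow)
// When nothing matches, the expression is simply uniqued as a new SCEVUDivExpr.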
3488 
3489 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3490  APInt A = C1->getAPInt().abs();
3491  APInt B = C2->getAPInt().abs();
3492  uint32_t ABW = A.getBitWidth();
3493  uint32_t BBW = B.getBitWidth();
3494 
3495  if (ABW > BBW)
3496  B = B.zext(ABW);
3497  else if (ABW < BBW)
3498  A = A.zext(BBW);
3499 
3500  return APIntOps::GreatestCommonDivisor(A, B);
3501 }
3502 
3503 /// Get a canonical unsigned division expression, or something simpler if
3504 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3505 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3506 /// it's not exact because the udiv may be clearing bits.
3507 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
3508  const SCEV *RHS) {
3509  // TODO: we could try to find factors in all sorts of things, but for now we
3510  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3511  // end of this file for inspiration.
3512 
3513  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3514  if (!Mul || !Mul->hasNoUnsignedWrap())
3515  return getUDivExpr(LHS, RHS);
3516 
3517  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3518  // If the mulexpr multiplies by a constant, then that constant must be the
3519  // first element of the mulexpr.
3520  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3521  if (LHSCst == RHSCst) {
3522  SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
3523  return getMulExpr(Operands);
3524  }
3525 
3526  // We can't just assume that LHSCst divides RHSCst cleanly, it could be
3527  // that there's a factor provided by one of the other terms. We need to
3528  // check.
3529  APInt Factor = gcd(LHSCst, RHSCst);
3530  if (!Factor.isIntN(1)) {
3531  LHSCst =
3532  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3533  RHSCst =
3534  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
3535  SmallVector<const SCEV *, 2> Operands;
3536  Operands.push_back(LHSCst);
3537  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3538  LHS = getMulExpr(Operands);
3539  RHS = RHSCst;
3540  Mul = dyn_cast<SCEVMulExpr>(LHS);
3541  if (!Mul)
3542  return getUDivExactExpr(LHS, RHS);
3543  }
3544  }
3545  }
3546 
3547  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3548  if (Mul->getOperand(i) == RHS) {
3549  SmallVector<const SCEV *, 2> Operands;
3550  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3551  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3552  return getMulExpr(Operands);
3553  }
3554  }
3555 
3556  return getUDivExpr(LHS, RHS);
3557 }
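// Illustration (added for exposition; not in the original source): with a
// no-unsigned-wrap multiply on the left-hand side,
//   (6 * %x) /u(exact) 4   first cancels gcd(6, 4) = 2, giving
//   (3 * %x) /u(exact) 2,  which no longer simplifies and falls back to a
//                          plain udiv;
//   (%y * %x) /u(exact) %y drops the matching operand and yields %x.
// Without the exactness guarantee these cancellations would be unsound,
// because a plain udiv may discard low bits.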
3558 
3559 /// Get an add recurrence expression for the specified loop. Simplify the
3560 /// expression as much as possible.
3561 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3562  const Loop *L,
3563  SCEV::NoWrapFlags Flags) {
3564  SmallVector<const SCEV *, 4> Operands;
3565  Operands.push_back(Start);
3566  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3567  if (StepChrec->getLoop() == L) {
3568  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3569  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3570  }
3571 
3572  Operands.push_back(Step);
3573  return getAddRecExpr(Operands, L, Flags);
3574 }
3575 
3576 /// Get an add recurrence expression for the specified loop. Simplify the
3577 /// expression as much as possible.
3578 const SCEV *
3579 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
3580  const Loop *L, SCEV::NoWrapFlags Flags) {
3581  if (Operands.size() == 1) return Operands[0];
3582 #ifndef NDEBUG
3583  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3584  for (unsigned i = 1, e = Operands.size(); i != e; ++i) {
3585  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
3586  "SCEVAddRecExpr operand types don't match!");
3587  assert(!Operands[i]->getType()->isPointerTy() && "Step must be integer");
3588  }
3589  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
3590  assert(isLoopInvariant(Operands[i], L) &&
3591  "SCEVAddRecExpr operand is not loop-invariant!");
3592 #endif
3593 
3594  if (Operands.back()->isZero()) {
3595  Operands.pop_back();
3596  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3597  }
3598 
3599  // It's tempting to want to call getConstantMaxBackedgeTakenCount here and
3600  // use that information to infer NUW and NSW flags. However, computing a
3601  // BE count requires calling getAddRecExpr, so we may not yet have a
3602  // meaningful BE count at this point (and if we don't, we'd be stuck
3603  // with a SCEVCouldNotCompute as the cached BE count).
3604 
3605  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3606 
3607  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3608  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3609  const Loop *NestedLoop = NestedAR->getLoop();
3610  if (L->contains(NestedLoop)
3611  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3612  : (!NestedLoop->contains(L) &&
3613  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3614  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3615  Operands[0] = NestedAR->getStart();
3616  // AddRecs require their operands be loop-invariant with respect to their
3617  // loops. Don't perform this transformation if it would break this
3618  // requirement.
3619  bool AllInvariant = all_of(
3620  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3621 
3622  if (AllInvariant) {
3623  // Create a recurrence for the outer loop with the same step size.
3624  //
3625  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3626  // inner recurrence has the same property.
3627  SCEV::NoWrapFlags OuterFlags =
3628  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3629 
3630  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3631  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3632  return isLoopInvariant(Op, NestedLoop);
3633  });
3634 
3635  if (AllInvariant) {
3636  // Ok, both add recurrences are valid after the transformation.
3637  //
3638  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3639  // the outer recurrence has the same property.
3640  SCEV::NoWrapFlags InnerFlags =
3641  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3642  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3643  }
3644  }
3645  // Reset Operands to its original state.
3646  Operands[0] = NestedAR;
3647  }
3648  }
3649 
3650  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3651  // already have one, otherwise create a new one.
3652  return getOrCreateAddRecExpr(Operands, L, Flags);
3653 }
3654 
3655 const SCEV *
3656 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
3657  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3658  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3659  // getSCEV(Base)->getType() has the same address space as Base->getType()
3660  // because SCEV::getType() preserves the address space.
3661  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3662  const bool AssumeInBoundsFlags = [&]() {
3663  if (!GEP->isInBounds())
3664  return false;
3665 
3666  // We'd like to propagate flags from the IR to the corresponding SCEV nodes,
3667  // but to do that, we have to ensure that said flag is valid in the entire
3668  // defined scope of the SCEV.
3669  auto *GEPI = dyn_cast<Instruction>(GEP);
3670  // TODO: non-instructions have global scope. We might be able to prove
3671  // some global scope cases
3672  return GEPI && isSCEVExprNeverPoison(GEPI);
3673  }();
3674 
3675  SCEV::NoWrapFlags OffsetWrap =
3676  AssumeInBoundsFlags ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3677 
3678  Type *CurTy = GEP->getType();
3679  bool FirstIter = true;
3680  SmallVector<const SCEV *, 4> Offsets;
3681  for (const SCEV *IndexExpr : IndexExprs) {
3682  // Compute the (potentially symbolic) offset in bytes for this index.
3683  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3684  // For a struct, add the member offset.
3685  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3686  unsigned FieldNo = Index->getZExtValue();
3687  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3688  Offsets.push_back(FieldOffset);
3689 
3690  // Update CurTy to the type of the field at Index.
3691  CurTy = STy->getTypeAtIndex(Index);
3692  } else {
3693  // Update CurTy to its element type.
3694  if (FirstIter) {
3695  assert(isa<PointerType>(CurTy) &&
3696  "The first index of a GEP indexes a pointer");
3697  CurTy = GEP->getSourceElementType();
3698  FirstIter = false;
3699  } else {
3700  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3701  }
3702  // For an array, add the element offset, explicitly scaled.
3703  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3704  // Getelementptr indices are signed.
3705  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3706 
3707  // Multiply the index by the element size to compute the element offset.
3708  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3709  Offsets.push_back(LocalOffset);
3710  }
3711  }
3712 
3713  // Handle degenerate case of GEP without offsets.
3714  if (Offsets.empty())
3715  return BaseExpr;
3716 
3717  // Add the offsets together, assuming nsw if inbounds.
3718  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3719  // Add the base address and the offset. We cannot use the nsw flag, as the
3720  // base address is unsigned. However, if we know that the offset is
3721  // non-negative, we can use nuw.
3722  SCEV::NoWrapFlags BaseWrap = AssumeInBoundsFlags && isKnownNonNegative(Offset)
3723  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
3724  auto *GEPExpr = getAddExpr(BaseExpr, Offset, BaseWrap);
3725  assert(BaseExpr->getType() == GEPExpr->getType() &&
3726  "GEP should not change type mid-flight.");
3727  return GEPExpr;
3728 }
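// Illustration (added for exposition; not in the original source): for
//   %p = getelementptr inbounds {i64, [10 x i32]},
//          {i64, [10 x i32]}* %base, i64 %i, i32 1, i64 %j
// on a typical 64-bit data layout, the loop above collects the offsets
//   (48 * %i), 8, and (4 * %j)
// (whole-struct stride, field offset, scaled array index). They are added
// with nsw because the GEP is inbounds, and the result is %base plus that
// offset, tagged nuw only when the offset is known to be non-negative.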
3729 
3730 SCEV *ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3731  ArrayRef<const SCEV *> Ops) {
3732  FoldingSetNodeID ID;
3733  ID.AddInteger(SCEVType);
3734  for (const SCEV *Op : Ops)
3735  ID.AddPointer(Op);
3736  void *IP = nullptr;
3737  return UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3738 }
3739 
3740 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
3741  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3742  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3743 }
3744 
3745 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
3746  SmallVectorImpl<const SCEV *> &Ops) {
3747  assert(SCEVMinMaxExpr::isMinMaxType(Kind) && "Not a SCEVMinMaxExpr!");
3748  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3749  if (Ops.size() == 1) return Ops[0];
3750 #ifndef NDEBUG
3751  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3752  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
3753  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3754  "Operand types don't match!");
3755  assert(Ops[0]->getType()->isPointerTy() ==
3756  Ops[i]->getType()->isPointerTy() &&
3757  "min/max should be consistently pointerish");
3758  }
3759 #endif
3760 
3761  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3762  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3763 
3764  // Sort by complexity, this groups all similar expression types together.
3765  GroupByComplexity(Ops, &LI, DT);
3766 
3767  // Check if we have created the same expression before.
3768  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops)) {
3769  return S;
3770  }
3771 
3772  // If there are any constants, fold them together.
3773  unsigned Idx = 0;
3774  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3775  ++Idx;
3776  assert(Idx < Ops.size());
3777  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3778  if (Kind == scSMaxExpr)
3779  return APIntOps::smax(LHS, RHS);
3780  else if (Kind == scSMinExpr)
3781  return APIntOps::smin(LHS, RHS);
3782  else if (Kind == scUMaxExpr)
3783  return APIntOps::umax(LHS, RHS);
3784  else if (Kind == scUMinExpr)
3785  return APIntOps::umin(LHS, RHS);
3786  llvm_unreachable("Unknown SCEV min/max opcode");
3787  };
3788 
3789  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3790  // We found two constants, fold them together!
3791  ConstantInt *Fold = ConstantInt::get(
3792  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3793  Ops[0] = getConstant(Fold);
3794  Ops.erase(Ops.begin()+1); // Erase the folded element
3795  if (Ops.size() == 1) return Ops[0];
3796  LHSC = cast<SCEVConstant>(Ops[0]);
3797  }
3798 
3799  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3800  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3801 
3802  if (IsMax ? IsMinV : IsMaxV) {
3803  // If we are left with a constant minimum(/maximum)-int, strip it off.
3804  Ops.erase(Ops.begin());
3805  --Idx;
3806  } else if (IsMax ? IsMaxV : IsMinV) {
3807  // If we have a max(/min) with a constant maximum(/minimum)-int,
3808  // it will always be the extremum.
3809  return LHSC;
3810  }
3811 
3812  if (Ops.size() == 1) return Ops[0];
3813  }
3814 
3815  // Find the first operation of the same kind
3816  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3817  ++Idx;
3818 
3819  // Check to see if one of the operands is of the same kind. If so, expand its
3820  // operands onto our operand list, and recurse to simplify.
3821  if (Idx < Ops.size()) {
3822  bool DeletedAny = false;
3823  while (Ops[Idx]->getSCEVType() == Kind) {
3824  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3825  Ops.erase(Ops.begin()+Idx);
3826  Ops.append(SMME->op_begin(), SMME->op_end());
3827  DeletedAny = true;
3828  }
3829 
3830  if (DeletedAny)
3831  return getMinMaxExpr(Kind, Ops);
3832  }
3833 
3834  // Okay, check to see if the same value occurs in the operand list twice. If
3835  // so, delete one. Since we sorted the list, these values are required to
3836  // be adjacent.
3837  llvm::CmpInst::Predicate GEPred =
3838  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
3839  llvm::CmpInst::Predicate LEPred =
3840  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
3841  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3842  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3843  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3844  if (Ops[i] == Ops[i + 1] ||
3845  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3846  // X op Y op Y --> X op Y
3847  // X op Y --> X, if we know X, Y are ordered appropriately
3848  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3849  --i;
3850  --e;
3851  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3852  Ops[i + 1])) {
3853  // X op Y --> Y, if we know X, Y are ordered appropriately
3854  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3855  --i;
3856  --e;
3857  }
3858  }
3859 
3860  if (Ops.size() == 1) return Ops[0];
3861 
3862  assert(!Ops.empty() && "Reduced smax down to nothing!");
3863 
3864  // Okay, it looks like we really DO need an expr. Check to see if we
3865  // already have one, otherwise create a new one.
3866  FoldingSetNodeID ID;
3867  ID.AddInteger(Kind);
3868  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3869  ID.AddPointer(Ops[i]);
3870  void *IP = nullptr;
3871  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
3872  if (ExistingSCEV)
3873  return ExistingSCEV;
3874  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3875  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3876  SCEV *S = new (SCEVAllocator)
3877  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3878 
3879  UniqueSCEVs.InsertNode(S, IP);
3880  registerUser(S, Ops);
3881  return S;
3882 }
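// Illustration (added for exposition; not in the original source): the
// simplifications above in action on (u|s)max:
//   umax(3, umax(%x, 5))  -->  umax(5, %x)   (inner umax inlined, 3 umax 5 folded)
//   umax(%x, %x, 7)       -->  umax(7, %x)   (duplicate operand removed)
//   umax(%x, 0)           -->  %x            (0 is the identity for umax)
//   smax(%x, INT_MAX)     -->  INT_MAX       (maximum-int short-circuits)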
3883 
3884 namespace {
3885 
3886 class SCEVSequentialMinMaxDeduplicatingVisitor final
3887  : public SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor,
3888  Optional<const SCEV *>> {
3889  using RetVal = Optional<const SCEV *>;
3890  using Base = SCEVVisitor<SCEVSequentialMinMaxDeduplicatingVisitor, RetVal>;
3891 
3892  ScalarEvolution &SE;
3893  const SCEVTypes RootKind; // Must be a sequential min/max expression.
3894  const SCEVTypes NonSequentialRootKind; // Non-sequential variant of RootKind.
3895  SmallPtrSet<const SCEV *, 16> SeenOps;
3896 
3897  bool canRecurseInto(SCEVTypes Kind) const {
3898  // We can only recurse into the SCEV expression of the same effective type
3899  // as the type of our root SCEV expression.
3900  return RootKind == Kind || NonSequentialRootKind == Kind;
3901  };
3902 
3903  RetVal visitAnyMinMaxExpr(const SCEV *S) {
3904  assert((isa<SCEVMinMaxExpr>(S) || isa<SCEVSequentialMinMaxExpr>(S)) &&
3905  "Only for min/max expressions.");
3906  SCEVTypes Kind = S->getSCEVType();
3907 
3908  if (!canRecurseInto(Kind))
3909  return S;
3910 
3911  auto *NAry = cast<SCEVNAryExpr>(S);
3912  SmallVector<const SCEV *> NewOps;
3913  bool Changed =
3914  visit(Kind, makeArrayRef(NAry->op_begin(), NAry->op_end()), NewOps);
3915 
3916  if (!Changed)
3917  return S;
3918  if (NewOps.empty())
3919  return None;
3920 
3921  return isa<SCEVSequentialMinMaxExpr>(S)
3922  ? SE.getSequentialMinMaxExpr(Kind, NewOps)
3923  : SE.getMinMaxExpr(Kind, NewOps);
3924  }
3925 
3926  RetVal visit(const SCEV *S) {
3927  // Has the whole operand been seen already?
3928  if (!SeenOps.insert(S).second)
3929  return None;
3930  return Base::visit(S);
3931  }
3932 
3933 public:
3934  SCEVSequentialMinMaxDeduplicatingVisitor(ScalarEvolution &SE,
3935  SCEVTypes RootKind)
3936  : SE(SE), RootKind(RootKind),
3937  NonSequentialRootKind(
3938  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
3939  RootKind)) {}
3940 
3941  bool /*Changed*/ visit(SCEVTypes Kind, ArrayRef<const SCEV *> OrigOps,
3942  SmallVectorImpl<const SCEV *> &NewOps) {
3943  bool Changed = false;
3944  SmallVector<const SCEV *> Ops;
3945  Ops.reserve(OrigOps.size());
3946 
3947  for (const SCEV *Op : OrigOps) {
3948  RetVal NewOp = visit(Op);
3949  if (NewOp != Op)
3950  Changed = true;
3951  if (NewOp)
3952  Ops.emplace_back(*NewOp);
3953  }
3954 
3955  if (Changed)
3956  NewOps = std::move(Ops);
3957  return Changed;
3958  }
3959 
3960  RetVal visitConstant(const SCEVConstant *Constant) { return Constant; }
3961 
3962  RetVal visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) { return Expr; }
3963 
3964  RetVal visitTruncateExpr(const SCEVTruncateExpr *Expr) { return Expr; }
3965 
3966  RetVal visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { return Expr; }
3967 
3968  RetVal visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { return Expr; }
3969 
3970  RetVal visitAddExpr(const SCEVAddExpr *Expr) { return Expr; }
3971 
3972  RetVal visitMulExpr(const SCEVMulExpr *Expr) { return Expr; }
3973 
3974  RetVal visitUDivExpr(const SCEVUDivExpr *Expr) { return Expr; }
3975 
3976  RetVal visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
3977 
3978  RetVal visitSMaxExpr(const SCEVSMaxExpr *Expr) {
3979  return visitAnyMinMaxExpr(Expr);
3980  }
3981 
3982  RetVal visitUMaxExpr(const SCEVUMaxExpr *Expr) {
3983  return visitAnyMinMaxExpr(Expr);
3984  }
3985 
3986  RetVal visitSMinExpr(const SCEVSMinExpr *Expr) {
3987  return visitAnyMinMaxExpr(Expr);
3988  }
3989 
3990  RetVal visitUMinExpr(const SCEVUMinExpr *Expr) {
3991  return visitAnyMinMaxExpr(Expr);
3992  }
3993 
3994  RetVal visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
3995  return visitAnyMinMaxExpr(Expr);
3996  }
3997 
3998  RetVal visitUnknown(const SCEVUnknown *Expr) { return Expr; }
3999 
4000  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *Expr) { return Expr; }
4001 };
4002 
4003 } // namespace
4004 
4005 const SCEV *
4006 ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
4007  SmallVectorImpl<const SCEV *> &Ops) {
4008  assert(SCEVSequentialMinMaxExpr::isSequentialMinMaxType(Kind) &&
4009  "Not a SCEVSequentialMinMaxExpr!");
4010  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
4011  if (Ops.size() == 1)
4012  return Ops[0];
4013  if (Ops.size() == 2 &&
4014  any_of(Ops, [](const SCEV *Op) { return isa<SCEVConstant>(Op); }))
4015  return getMinMaxExpr(
4016  SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
4017  Ops);
4018 #ifndef NDEBUG
4019  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
4020  for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
4021  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
4022  "Operand types don't match!");
4023  assert(Ops[0]->getType()->isPointerTy() ==
4024  Ops[i]->getType()->isPointerTy() &&
4025  "min/max should be consistently pointerish");
4026  }
4027 #endif
4028 
4029  // Note that SCEVSequentialMinMaxExpr is *NOT* commutative,
4030  // so we can *NOT* do any kind of sorting of the expressions!
4031 
4032  // Check if we have created the same expression before.
4033  if (const SCEV *S = findExistingSCEVInCache(Kind, Ops))
4034  return S;
4035 
4036  // FIXME: there are *some* simplifications that we can do here.
4037 
4038  // Keep only the first instance of an operand.
4039  {
4040  SCEVSequentialMinMaxDeduplicatingVisitor Deduplicator(*this, Kind);
4041  bool Changed = Deduplicator.visit(Kind, Ops, Ops);
4042  if (Changed)
4043  return getSequentialMinMaxExpr(Kind, Ops);
4044  }
4045 
4046  // Check to see if one of the operands is of the same kind. If so, expand its
4047  // operands onto our operand list, and recurse to simplify.
4048  {
4049  unsigned Idx = 0;
4050  bool DeletedAny = false;
4051  while (Idx < Ops.size()) {
4052  if (Ops[Idx]->getSCEVType() != Kind) {
4053  ++Idx;
4054  continue;
4055  }
4056  const auto *SMME = cast<SCEVSequentialMinMaxExpr>(Ops[Idx]);
4057  Ops.erase(Ops.begin() + Idx);
4058  Ops.insert(Ops.begin() + Idx, SMME->op_begin(), SMME->op_end());
4059  DeletedAny = true;
4060  }
4061 
4062  if (DeletedAny)
4063  return getSequentialMinMaxExpr(Kind, Ops);
4064  }
4065 
4066  // Okay, it looks like we really DO need an expr. Check to see if we
4067  // already have one, otherwise create a new one.
4068  FoldingSetNodeID ID;
4069  ID.AddInteger(Kind);
4070  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
4071  ID.AddPointer(Ops[i]);
4072  void *IP = nullptr;
4073  const SCEV *ExistingSCEV = UniqueSCEVs.FindNodeOrInsertPos(ID, IP);
4074  if (ExistingSCEV)
4075  return ExistingSCEV;
4076 
4077  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
4078  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
4079  SCEV *S = new (SCEVAllocator)
4080  SCEVSequentialMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
4081 
4082  UniqueSCEVs.InsertNode(S, IP);
4083  registerUser(S, Ops);
4084  return S;
4085 }
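// Illustration (added for exposition; not in the original source): why the
// operands above are deduplicated but never reordered. Roughly speaking,
// umin_seq evaluates its operands left to right and stops at the first zero,
// so umin_seq(%a, %b) may ignore whether %b is poison whenever %a is zero,
// while a plain umin(%a, %b) is poison if either operand is. Reordering
// operands could therefore change which values are allowed to be poison, so
// only exact duplicates and nested umin_seq operands are flattened,
// preserving first-occurrence order.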
4086 
4087 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4088  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4089  return getSMaxExpr(Ops);
4090 }
4091 
4092 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4093  return getMinMaxExpr(scSMaxExpr, Ops);
4094 }
4095 
4096 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
4097  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
4098  return getUMaxExpr(Ops);
4099 }
4100 
4101 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
4102  return getMinMaxExpr(scUMaxExpr, Ops);
4103 }
4104 
4105 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
4106  const SCEV *RHS) {
4107  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4108  return getSMinExpr(Ops);
4109 }
4110 
4111 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
4112  return getMinMaxExpr(scSMinExpr, Ops);
4113 }
4114 
4115 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS, const SCEV *RHS,
4116  bool Sequential) {
4117  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4118  return getUMinExpr(Ops, Sequential);
4119 }
4120 
4121 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops,
4122  bool Sequential) {
4123  return Sequential ? getSequentialMinMaxExpr(scSequentialUMinExpr, Ops)
4124  : getMinMaxExpr(scUMinExpr, Ops);
4125 }
4126 
4127 const SCEV *
4128 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
4129  ScalableVectorType *ScalableTy) {
4130  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
4131  Constant *One = ConstantInt::get(IntTy, 1);
4132  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
4133  // Note that the expression we created is the final expression; we don't
4134  // want to simplify it any further. Also, if we call a normal getSCEV(),
4135  // we'll end up in an endless recursion. So just create an SCEVUnknown.
4136  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
4137 }
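// Illustration (added for exposition; not in the original source): for
// IntTy = i64 and ScalableTy = <vscale x 4 x i32>, the constant built above is
// the classic "sizeof via gep" idiom,
//   ptrtoint (<vscale x 4 x i32>* getelementptr (<vscale x 4 x i32>,
//             <vscale x 4 x i32>* null, i64 1) to i64)
// i.e. the byte distance from a null pointer to element 1, which the backend
// later lowers to a multiple of vscale. It is wrapped in a SCEVUnknown
// precisely so that SCEV does not try to fold it any further.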
4138 
4139 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
4140  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
4141  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
4142  // We can bypass creating a target-independent constant expression and then
4143  // folding it back into a ConstantInt. This is just a compile-time
4144  // optimization.
4145  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
4146 }
4147 
4148 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
4149  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
4150  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
4151  // We can bypass creating a target-independent constant expression and then
4152  // folding it back into a ConstantInt. This is just a compile-time
4153  // optimization.
4154  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
4155 }
4156 
4157 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
4158  StructType *STy,
4159  unsigned FieldNo) {
4160  // We can bypass creating a target-independent constant expression and then
4161  // folding it back into a ConstantInt. This is just a compile-time
4162  // optimization.
4163  return getConstant(
4164  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
4165 }
4166 
4167 const SCEV *ScalarEvolution::getUnknown(Value *V) {
4168  // Don't attempt to do anything other than create a SCEVUnknown object
4169  // here. createSCEV only calls getUnknown after checking for all other
4170  // interesting possibilities, and any other code that calls getUnknown
4171  // is doing so in order to hide a value from SCEV canonicalization.
4172 
4173  FoldingSetNodeID ID;
4174  ID.AddInteger(scUnknown);
4175  ID.AddPointer(V);
4176  void *IP = nullptr;
4177  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
4178  assert(cast<SCEVUnknown>(S)->getValue() == V &&
4179  "Stale SCEVUnknown in uniquing map!");
4180  return S;
4181  }
4182  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
4183  FirstUnknown);
4184  FirstUnknown = cast<SCEVUnknown>(S);
4185  UniqueSCEVs.InsertNode(S, IP);
4186  return S;
4187 }
4188 
4189 //===----------------------------------------------------------------------===//
4190 // Basic SCEV Analysis and PHI Idiom Recognition Code
4191 //
4192 
4193 /// Test if values of the given type are analyzable within the SCEV
4194 /// framework. This primarily includes integer types, and it can optionally
4195 /// include pointer types if the ScalarEvolution class has access to
4196 /// target-specific information.
4197 bool ScalarEvolution::isSCEVable(Type *Ty) const {
4198  // Integers and pointers are always SCEVable.
4199  return Ty->isIntOrPtrTy();
4200 }
4201 
4202 /// Return the size in bits of the specified type, for which isSCEVable must
4203 /// return true.
4204 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
4205  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4206  if (Ty->isPointerTy())
4207  return getDataLayout().getIndexTypeSizeInBits(Ty);
4208  return getDataLayout().getTypeSizeInBits(Ty);
4209 }
4210 
4211 /// Return a type with the same bitwidth as the given type and which represents
4212 /// how SCEV will treat the given type, for which isSCEVable must return
4213 /// true. For pointer types, this is the pointer index sized integer type.
4214 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
4215  assert(isSCEVable(Ty) && "Type is not SCEVable!");
4216 
4217  if (Ty->isIntegerTy())
4218  return Ty;
4219 
4220  // The only other supported type is pointer.
4221  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
4222  return getDataLayout().getIndexType(Ty);
4223 }
4224 
4225 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
4226  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
4227 }
4228 
4229 bool ScalarEvolution::instructionCouldExistWithOperands(const SCEV *A,
4230  const SCEV *B) {
4231  /// For a valid use point to exist, the defining scope of one operand
4232  /// must dominate the other.
4233  bool PreciseA, PreciseB;
4234  auto *ScopeA = getDefiningScopeBound({A}, PreciseA);
4235  auto *ScopeB = getDefiningScopeBound({B}, PreciseB);
4236  if (!PreciseA || !PreciseB)
4237  // Can't tell.
4238  return false;
4239  return (ScopeA == ScopeB) || DT.dominates(ScopeA, ScopeB) ||
4240  DT.dominates(ScopeB, ScopeA);
4241 }
4242 
4243 
4244 const SCEV *ScalarEvolution::getCouldNotCompute() {
4245  return CouldNotCompute.get();
4246 }
4247 
4248 bool ScalarEvolution::checkValidity(const SCEV *S) const {
4249  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
4250  auto *SU = dyn_cast<SCEVUnknown>(S);
4251  return SU && SU->getValue() == nullptr;
4252  });
4253 
4254  return !ContainsNulls;
4255 }
4256 
4257 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
4258  HasRecMapType::iterator I = HasRecMap.find(S);
4259  if (I != HasRecMap.end())
4260  return I->second;
4261 
4262  bool FoundAddRec =
4263  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
4264  HasRecMap.insert({S, FoundAddRec});
4265  return FoundAddRec;
4266 }
4267 
4268 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
4269 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
4270 /// offset I, then return {S', I}, else return {\p S, nullptr}.
4271 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
4272  const auto *Add = dyn_cast<SCEVAddExpr>(S);
4273  if (!Add)
4274  return {S, nullptr};
4275 
4276  if (Add->getNumOperands() != 2)
4277  return {S, nullptr};
4278 
4279  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
4280  if (!ConstOp)
4281  return {S, nullptr};
4282 
4283  return {Add->getOperand(1), ConstOp->getValue()};
4284 }
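// Illustration (added for exposition; not in the original source):
//   splitAddExpr(5 + %x)       == { %x, 5 }                 (constant peeled off)
//   splitAddExpr(%x + %y)      == { (%x + %y), nullptr }    (no leading constant)
//   splitAddExpr(1 + %x + %y)  == { (1 + %x + %y), nullptr } (three operands,
//                                                             left untouched)
// Only the exact two-operand "constant + something" shape is split; anything
// else is returned unchanged with a null offset.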
4285 
4286 /// Return the ValueOffsetPair set for \p S. \p S can be represented
4287 /// by the value and offset from any ValueOffsetPair in the set.
4288 ScalarEvolution::ValueOffsetPairSetVector *
4289 ScalarEvolution::getSCEVValues(const SCEV *S) {
4290  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
4291  if (SI == ExprValueMap.end())
4292  return nullptr;
4293 #ifndef NDEBUG
4294  if (VerifySCEVMap) {
4295  // Check there is no dangling Value in the set returned.
4296  for (const auto &VE : SI->second)
4297  assert(ValueExprMap.count(VE.first));
4298  }
4299 #endif
4300  return &SI->second;
4301 }
4302 
4303 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
4304 /// cannot be used separately. eraseValueFromMap should be used to remove
4305 /// V from ValueExprMap and ExprValueMap at the same time.
4306 void ScalarEvolution::eraseValueFromMap(Value *V) {
4307  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4308  if (I != ValueExprMap.end()) {
4309  const SCEV *S = I->second;
4310  // Remove {V, 0} from the set of ExprValueMap[S]
4311  if (auto *SV = getSCEVValues(S))
4312  SV->remove({V, nullptr});
4313 
4314  // Remove {V, Offset} from the set of ExprValueMap[Stripped]
4315  const SCEV *Stripped;
4316  ConstantInt *Offset;
4317  std::tie(Stripped, Offset) = splitAddExpr(S);
4318  if (Offset != nullptr) {
4319  if (auto *SV = getSCEVValues(Stripped))
4320  SV->remove({V, Offset});
4321  }
4322  ValueExprMap.erase(V);
4323  }
4324 }
4325 
4326 void ScalarEvolution::insertValueToMap(Value *V, const SCEV *S) {
4327  // A recursive query may have already computed the SCEV. It should be
4328  // equivalent, but may not necessarily be exactly the same, e.g. due to lazily
4329  // inferred nowrap flags.
4330  auto It = ValueExprMap.find_as(V);
4331  if (It == ValueExprMap.end()) {
4332  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4333  ExprValueMap[S].insert({V, nullptr});
4334  }
4335 }
4336 
4337 /// Return an existing SCEV if it exists, otherwise analyze the expression and
4338 /// create a new one.
4339 const SCEV *ScalarEvolution::getSCEV(Value *V) {
4340  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4341 
4342  const SCEV *S = getExistingSCEV(V);
4343  if (S == nullptr) {
4344  S = createSCEV(V);
4345  // During PHI resolution, it is possible to create two SCEVs for the same
4346  // V, so it is needed to double check whether V->S is inserted into
4347  // ValueExprMap before insert S->{V, 0} into ExprValueMap.
4348  std::pair<ValueExprMapType::iterator, bool> Pair =
4349  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
4350  if (Pair.second) {
4351  ExprValueMap[S].insert({V, nullptr});
4352 
4353  // If S == Stripped + Offset, add Stripped -> {V, Offset} into
4354  // ExprValueMap.
4355  const SCEV *Stripped = S;
4356  ConstantInt *Offset = nullptr;
4357  std::tie(Stripped, Offset) = splitAddExpr(S);
4358  // If stripped is SCEVUnknown, don't bother to save
4359  // Stripped -> {V, offset}. It doesn't simplify and sometimes even
4360  // increase the complexity of the expansion code.
4361  // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
4362  // because it may generate add/sub instead of GEP in SCEV expansion.
4363  if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
4364  !isa<GetElementPtrInst>(V))
4365  ExprValueMap[Stripped].insert({V, Offset});
4366  }
4367  }
4368  return S;
4369 }
4370 
4371 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
4372  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
4373 
4374  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
4375  if (I != ValueExprMap.end()) {
4376  const SCEV *S = I->second;
4377  assert(checkValidity(S) &&
4378  "existing SCEV has not been properly invalidated");
4379  return S;
4380  }
4381  return nullptr;
4382 }
4383 
4384 /// Return a SCEV corresponding to -V = -1*V
4385 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
4386  SCEV::NoWrapFlags Flags) {
4387  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4388  return getConstant(
4389  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
4390 
4391  Type *Ty = V->getType();
4392  Ty = getEffectiveSCEVType(Ty);
4393  return getMulExpr(V, getMinusOne(Ty), Flags);
4394 }
4395 
4396 /// If Expr computes ~A, return A else return nullptr
4397 static const SCEV *MatchNotExpr(const SCEV *Expr) {
4398  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
4399  if (!Add || Add->getNumOperands() != 2 ||
4400  !Add->getOperand(0)->isAllOnesValue())
4401  return nullptr;
4402 
4403  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
4404  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
4405  !AddRHS->getOperand(0)->isAllOnesValue())
4406  return nullptr;
4407 
4408  return AddRHS->getOperand(1);
4409 }
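// Illustration (added for exposition; not in the original source): SCEV has no
// dedicated "not" node, so ~%a is built as (-1 + (-1 * %a)). MatchNotExpr
// recognizes exactly that shape and hands back %a; it returns nullptr for
// anything else (e.g. (-2 + (-1 * %a)) or (-1 + (2 * %a))), which is what
// getNotSCEV below relies on when it folds ~smin(~%x, ~%y) into smax(%x, %y).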
4410 
4411 /// Return a SCEV corresponding to ~V = -1-V
4412 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
4413  assert(!V->getType()->isPointerTy() && "Can't negate pointer");
4414 
4415  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
4416  return getConstant(
4417  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
4418 
4419  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
4420  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
4421  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
4422  SmallVector<const SCEV *, 2> MatchedOperands;
4423  for (const SCEV *Operand : MME->operands()) {
4424  const SCEV *Matched = MatchNotExpr(Operand);
4425  if (!Matched)
4426  return (const SCEV *)nullptr;
4427  MatchedOperands.push_back(Matched);
4428  }
4429  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
4430  MatchedOperands);
4431  };
4432  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
4433  return Replaced;
4434  }
4435 
4436  Type *Ty = V->getType();
4437  Ty = getEffectiveSCEVType(Ty);
4438  return getMinusSCEV(getMinusOne(Ty), V);
4439 }
4440 
4441 const SCEV *ScalarEvolution::removePointerBase(const SCEV *P) {
4442  assert(P->getType()->isPointerTy());
4443 
4444  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
4445  // The base of an AddRec is the first operand.
4446  SmallVector<const SCEV *> Ops{AddRec->operands()};
4447  Ops[0] = removePointerBase(Ops[0]);
4448  // Don't try to transfer nowrap flags for now. We could in some cases
4449  // (for example, if pointer operand of the AddRec is a SCEVUnknown).
4450  return getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
4451  }
4452  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
4453  // The base of an Add is the pointer operand.
4454  SmallVector<const SCEV *> Ops{Add->operands()};
4455  const SCEV **PtrOp = nullptr;
4456  for (const SCEV *&AddOp : Ops) {
4457  if (AddOp->getType()->isPointerTy()) {
4458  assert(!PtrOp && "Cannot have multiple pointer ops");
4459  PtrOp = &AddOp;
4460  }
4461  }
4462  *PtrOp = removePointerBase(*PtrOp);
4463  // Don't try to transfer nowrap flags for now. We could in some cases
4464  // (for example, if the pointer operand of the Add is a SCEVUnknown).
4465  return getAddExpr(Ops);
4466  }
4467  // Any other expression must be a pointer base.
4468  return getZero(P->getType());
4469 }
4470 
4471 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
4472  SCEV::NoWrapFlags Flags,
4473  unsigned Depth) {
4474  // Fast path: X - X --> 0.
4475  if (LHS == RHS)
4476  return getZero(LHS->getType());
4477 
4478  // If we subtract two pointers with different pointer bases, bail.
4479  // Eventually, we're going to add an assertion to getMulExpr that we
4480  // can't multiply by a pointer.
4481  if (RHS->getType()->isPointerTy()) {
4482  if (!LHS->getType()->isPointerTy() ||
4483  getPointerBase(LHS) != getPointerBase(RHS))
4484  return getCouldNotCompute();
4485  LHS = removePointerBase(LHS);
4486  RHS = removePointerBase(RHS);
4487  }
4488 
4489  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4490  // makes it so that we cannot make much use of NUW.
4491  auto AddFlags = SCEV::FlagAnyWrap;
4492  const bool RHSIsNotMinSigned =
4493  !getSignedRangeMin(RHS).isMinSignedValue();
4494  if (hasFlags(Flags, SCEV::FlagNSW)) {
4495  // Let M be the minimum representable signed value. Then (-1)*RHS
4496  // signed-wraps if and only if RHS is M. That can happen even for
4497  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4498  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4499  // (-1)*RHS, we need to prove that RHS != M.
4500  //
4501  // If LHS is non-negative and we know that LHS - RHS does not
4502  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4503  // either by proving that RHS > M or that LHS >= 0.
4504  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4505  AddFlags = SCEV::FlagNSW;
4506  }
4507  }
4508 
4509  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4510  // RHS is NSW and LHS >= 0.
4511  //
4512  // The difficulty here is that the NSW flag may have been proven
4513  // relative to a loop that is to be found in a recurrence in LHS and
4514  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4515  // larger scope than intended.
4516  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4517 
4518  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4519 }
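// Illustration (added for exposition; not in the original source): the NSW
// reasoning above in concrete terms for i8, where M = -128. The subtraction
// -1 - M = 127 does not overflow, but the rewritten form -1 + (-1 * M) does,
// because -1 * -128 wraps back to -128. So nsw is only carried over to the
// add/mul form when RHS is known to differ from M (RHSIsNotMinSigned), or
// when LHS is known non-negative, which independently rules that case out.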
4520 
4521 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
4522  unsigned Depth) {
4523  Type *SrcTy = V->getType();
4524  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4525  "Cannot truncate or zero extend with non-integer arguments!");
4526  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4527  return V; // No conversion
4528  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4529  return getTruncateExpr(V, Ty, Depth);
4530  return getZeroExtendExpr(V, Ty, Depth);
4531 }
4532 
4533 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
4534  unsigned Depth) {
4535  Type *SrcTy = V->getType();
4536  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4537  "Cannot truncate or zero extend with non-integer arguments!");
4538  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4539  return V; // No conversion
4540  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4541  return getTruncateExpr(V, Ty, Depth);
4542  return getSignExtendExpr(V, Ty, Depth);
4543 }
4544 
4545 const SCEV *
4546 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
4547  Type *SrcTy = V->getType();
4548  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4549  "Cannot noop or zero extend with non-integer arguments!");
4550  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4551  "getNoopOrZeroExtend cannot truncate!");
4552  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4553  return V; // No conversion
4554  return getZeroExtendExpr(V, Ty);
4555 }
4556 
4557 const SCEV *
4558 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
4559  Type *SrcTy = V->getType();
4560  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4561  "Cannot noop or sign extend with non-integer arguments!");
4562  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4563  "getNoopOrSignExtend cannot truncate!");
4564  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4565  return V; // No conversion
4566  return getSignExtendExpr(V, Ty);
4567 }
4568 
4569 const SCEV *
4570 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
4571  Type *SrcTy = V->getType();
4572  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4573  "Cannot noop or any extend with non-integer arguments!");
4574  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4575  "getNoopOrAnyExtend cannot truncate!");
4576  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4577  return V; // No conversion
4578  return getAnyExtendExpr(V, Ty);
4579 }
4580 
4581 const SCEV *
4582 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
4583  Type *SrcTy = V->getType();
4584  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4585  "Cannot truncate or noop with non-integer arguments!");
4586  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4587  "getTruncateOrNoop cannot extend!");
4588  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4589  return V; // No conversion
4590  return getTruncateExpr(V, Ty);
4591 }
4592 
4593 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
4594  const SCEV *RHS) {
4595  const SCEV *PromotedLHS = LHS;
4596  const SCEV *PromotedRHS = RHS;
4597 
4598  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4599  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4600  else
4601  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4602 
4603  return getUMaxExpr(PromotedLHS, PromotedRHS);
4604 }
4605 
4606 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
4607  const SCEV *RHS,
4608  bool Sequential) {
4609  S