//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library. First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node). If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
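//
// For example, the canonical induction variable of a loop (call its header
// %loop) such as
//   for (i = 0; i != n; ++i) { ... }
// is represented by the recurrence {0,+,1}<%loop>, read as "starts at 0 and
// steps by 1 on every iteration of %loop"; a PHI whose evolution we cannot
// express this way is simply wrapped in a SCEVUnknown.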
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression. These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <map>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "scalar-evolution"

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
    MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                            cl::desc("Maximum number of iterations SCEV will "
                                     "symbolically execute a constant "
                                     "derived loop"),
                            cl::init(100));

// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
static cl::opt<bool> VerifySCEV(
    "verify-scev", cl::Hidden,
    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
static cl::opt<bool> VerifySCEVStrict(
    "verify-scev-strict", cl::Hidden,
    cl::desc("Enable stricter verification when -verify-scev is passed"));
static cl::opt<bool>
    VerifySCEVMap("verify-scev-maps", cl::Hidden,
                  cl::desc("Verify no dangling value in ScalarEvolution's "
                           "ExprValueMap (slow)"));

static cl::opt<bool> VerifyIR(
    "scev-verify-ir", cl::Hidden,
    cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
    cl::init(false));

static cl::opt<unsigned> MulOpsInlineThreshold(
    "scev-mulops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining multiplication operands into a SCEV"),
    cl::init(32));

static cl::opt<unsigned> AddOpsInlineThreshold(
    "scev-addops-inline-threshold", cl::Hidden,
    cl::desc("Threshold for inlining addition operands into a SCEV"),
    cl::init(500));

static cl::opt<unsigned> MaxSCEVCompareDepth(
    "scalar-evolution-max-scev-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV complexity comparisons"),
    cl::init(32));

static cl::opt<unsigned> MaxSCEVOperationsImplicationDepth(
    "scalar-evolution-max-scev-operations-implication-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive SCEV operations implication analysis"),
    cl::init(2));

static cl::opt<unsigned> MaxValueCompareDepth(
    "scalar-evolution-max-value-compare-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive value complexity comparisons"),
    cl::init(2));

static cl::opt<unsigned>
    MaxArithDepth("scalar-evolution-max-arith-depth", cl::Hidden,
                  cl::desc("Maximum depth of recursive arithmetics"),
                  cl::init(32));

static cl::opt<unsigned> MaxConstantEvolvingDepth(
    "scalar-evolution-max-constant-evolving-depth", cl::Hidden,
    cl::desc("Maximum depth of recursive constant evolving"), cl::init(32));

static cl::opt<unsigned>
    MaxCastDepth("scalar-evolution-max-cast-depth", cl::Hidden,
                 cl::desc("Maximum depth of recursive SExt/ZExt/Trunc"),
                 cl::init(8));

static cl::opt<unsigned>
    MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
                  cl::desc("Max coefficients in AddRec during evolving"),
                  cl::init(8));

static cl::opt<unsigned>
    HugeExprThreshold("scalar-evolution-huge-expr-threshold", cl::Hidden,
                      cl::desc("Size of the expression which is considered huge"),
                      cl::init(4096));

static cl::opt<bool> ClassifyExpressions(
    "scalar-evolution-classify-expressions", cl::Hidden, cl::init(true),
    cl::desc("When printing analysis, include information on every instruction"));

static cl::opt<bool> UseExpensiveRangeSharpening(
    "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
    cl::init(false),
    cl::desc("Use more powerful methods of sharpening expression ranges. May "
             "be costly in terms of compile time"));

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
LLVM_DUMP_METHOD void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}
#endif

void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
    return;
  case scPtrToInt: {
    const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
    const SCEV *Op = PtrToInt->getOperand();
    OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
       << *PtrToInt->getType() << ")";
    return;
  }
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->hasNoUnsignedWrap())
      OS << "nuw><";
    if (AR->hasNoSignedWrap())
      OS << "nsw><";
    if (AR->hasNoSelfWrap() &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = nullptr;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    case scUMinExpr:
      OpStr = " umin ";
      break;
    case scSMinExpr:
      OpStr = " smin ";
      break;
    default:
      llvm_unreachable("There are no other nary expression types.");
    }
    OS << "(";
    ListSeparator LS(OpStr);
    for (const SCEV *Op : NAry->operands())
      OS << LS << *Op;
    OS << ")";
    switch (NAry->getSCEVType()) {
    case scAddExpr:
    case scMulExpr:
      if (NAry->hasNoUnsignedWrap())
        OS << "<nuw>";
      if (NAry->hasNoSignedWrap())
        OS << "<nsw>";
      break;
    default:
      // Nothing to print for other nary expressions.
      break;
    }
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      FieldNo->printAsOperand(OS, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    U->getValue()->printAsOperand(OS, false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
  case scUMinExpr:
  case scSMinExpr:
    return cast<SCEVNAryExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isMinusOne();
  return false;
}

bool SCEV::isNonConstantNegative() const {
  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this);
  if (!Mul) return false;

  // If there is a constant factor, it will be first.
  const SCEVConstant *SC = dyn_cast<SCEVConstant>(Mul->getOperand(0));
  if (!SC) return false;

  // Return true if the value is negative; this matches things like (-42 * V).
  return SC->getAPInt().isNegative();
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute, 0) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt &Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
  IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
                           const SCEV *op, Type *ty)
    : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
  Operands[0] = op;
}

SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
                                   Type *ITy)
    : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
  assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
         "Must be a non-bit-width-changing pointer-to-integer cast!");
}

SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
                                           SCEVTypes SCEVTy, const SCEV *op,
                                           Type *ty)
    : SCEVCastExpr(ID, SCEVTy, op, ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
                                   Type *ty)
    : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, Type *ty)
    : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
  assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(nullptr);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}

bool SCEVUnknown::isSizeOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                            ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          Type *Ty =
              cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          Type *Ty =
              cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

/// Compare the two values \p LV and \p RV in terms of their "complexity" where
/// "complexity" is a partial (and somewhat ad-hoc) relation used to order
/// operands in SCEV expressions. \p EqCacheValue is a set of pairs of values
/// that have been previously deemed to be "equally complex" by this routine.
/// It is intended to avoid exponential time complexity in cases like:
///
///   %a = f(%x, %y)
///   %b = f(%a, %a)
///   %c = f(%b, %b)
///
///   %d = f(%x, %y)
///   %e = f(%d, %d)
///   %f = f(%e, %e)
///
///   CompareValueComplexity(%f, %c)
///
/// Since we do not continue running this routine on expression trees once we
/// have seen unequal values, there is no need to track them in the cache.
static int
CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
                       const LoopInfo *const LI, Value *LV, Value *RV,
                       unsigned Depth) {
  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
    return 0;

  // Order pointer values after integer values. This helps SCEVExpander form
  // GEPs.
  bool LIsPointer = LV->getType()->isPointerTy(),
       RIsPointer = RV->getType()->isPointerTy();
  if (LIsPointer != RIsPointer)
    return (int)LIsPointer - (int)RIsPointer;

  // Compare getValueID values.
  unsigned LID = LV->getValueID(), RID = RV->getValueID();
  if (LID != RID)
    return (int)LID - (int)RID;

  // Sort arguments by their position.
  if (const auto *LA = dyn_cast<Argument>(LV)) {
    const auto *RA = cast<Argument>(RV);
    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
    return (int)LArgNo - (int)RArgNo;
  }

  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) {
    const auto *RGV = cast<GlobalValue>(RV);

    const auto IsGVNameSemantic = [&](const GlobalValue *GV) {
      auto LT = GV->getLinkage();
      return !(GlobalValue::isPrivateLinkage(LT) ||
               GlobalValue::isInternalLinkage(LT));
    };

    // Use the names to distinguish the two values, but only if the
    // names are semantically important.
    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV))
      return LGV->getName().compare(RGV->getName());
  }

  // For instructions, compare their loop depth, and their operand count. This
  // is pretty loose.
  if (const auto *LInst = dyn_cast<Instruction>(LV)) {
    const auto *RInst = cast<Instruction>(RV);

    // Compare loop depths.
    const BasicBlock *LParent = LInst->getParent(),
                     *RParent = RInst->getParent();
    if (LParent != RParent) {
      unsigned LDepth = LI->getLoopDepth(LParent),
               RDepth = LI->getLoopDepth(RParent);
      if (LDepth != RDepth)
        return (int)LDepth - (int)RDepth;
    }

    // Compare the number of operands.
    unsigned LNumOps = LInst->getNumOperands(),
             RNumOps = RInst->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned Idx : seq(0u, LNumOps)) {
      int Result =
          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
                                 RInst->getOperand(Idx), Depth + 1);
      if (Result != 0)
        return Result;
    }
  }

  EqCacheValue.unionSets(LV, RV);
  return 0;
}

// Return negative, zero, or positive, if LHS is less than, equal to, or greater
// than RHS, respectively. A three-way result allows recursive comparisons to be
// more efficient.
// If the max analysis depth was reached, return None, assuming we do not know
// if they are equivalent for sure.
static Optional<int>
CompareSCEVComplexity(EquivalenceClasses<const SCEV *> &EqCacheSCEV,
                      EquivalenceClasses<const Value *> &EqCacheValue,
                      const LoopInfo *const LI, const SCEV *LHS,
                      const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) {
  // Fast-path: SCEVs are uniqued so we can do a quick equality check.
  if (LHS == RHS)
    return 0;

  // Primarily, sort the SCEVs by their getSCEVType().
  SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
  if (LType != RType)
    return (int)LType - (int)RType;

  if (EqCacheSCEV.isEquivalent(LHS, RHS))
    return 0;

  if (Depth > MaxSCEVCompareDepth)
    return None;

  // Aside from the getSCEVType() ordering, the particular ordering
  // isn't very important except that it's beneficial to be consistent,
  // so that (a + b) and (b + a) don't end up as different expressions.
  switch (LType) {
  case scUnknown: {
    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
                                   RU->getValue(), Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scConstant: {
    const SCEVConstant *LC = cast<SCEVConstant>(LHS);
    const SCEVConstant *RC = cast<SCEVConstant>(RHS);

    // Compare constant values.
    const APInt &LA = LC->getAPInt();
    const APInt &RA = RC->getAPInt();
    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
    if (LBitWidth != RBitWidth)
      return (int)LBitWidth - (int)RBitWidth;
    return LA.ult(RA) ? -1 : 1;
  }

  case scAddRecExpr: {
    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

    // There is always a dominance between two recs that are used by one SCEV,
    // so we can safely sort recs by loop header dominance. We require such
    // order in getAddExpr.
    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
    if (LLoop != RLoop) {
      const BasicBlock *LHead = LLoop->getHeader(), *RHead = RLoop->getHeader();
      assert(LHead != RHead && "Two loops share the same header?");
      if (DT.dominates(LHead, RHead))
        return 1;
      else
        assert(DT.dominates(RHead, LHead) &&
               "No dominance between recurrences used by one SCEV?");
      return -1;
    }

    // Addrec complexity grows with operand count.
    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    // Lexicographically compare.
    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LA->getOperand(i), RA->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }

  case scAddExpr:
  case scMulExpr:
  case scSMaxExpr:
  case scUMaxExpr:
  case scSMinExpr:
  case scUMinExpr: {
    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

    // Lexicographically compare n-ary expressions.
    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
    if (LNumOps != RNumOps)
      return (int)LNumOps - (int)RNumOps;

    for (unsigned i = 0; i != LNumOps; ++i) {
      auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                     LC->getOperand(i), RC->getOperand(i), DT,
                                     Depth + 1);
      if (X != 0)
        return X;
    }
    EqCacheSCEV.unionSets(LHS, RHS);
    return 0;
  }

  case scUDivExpr: {
    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

    // Lexicographically compare udiv expressions.
    auto X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
                                   RC->getLHS(), DT, Depth + 1);
    if (X != 0)
      return X;
    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
                              RC->getRHS(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scPtrToInt:
  case scTruncate:
  case scZeroExtend:
  case scSignExtend: {
    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

    // Compare cast expressions by operand.
    auto X =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(),
                              RC->getOperand(), DT, Depth + 1);
    if (X == 0)
      EqCacheSCEV.unionSets(LHS, RHS);
    return X;
  }

  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
  }
  llvm_unreachable("Unknown SCEV kind!");
}

/// Given a list of SCEV objects, order them by their complexity, and group
/// objects of the same complexity together by value. When this routine is
/// finished, we know that any duplicates in the vector are consecutive and that
/// complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine. In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
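///
/// For instance, given the add operands (%a + %b + 1 + %a), the constant 1 is
/// ordered first (constants have the lowest SCEV type) and the two %a's end up
/// adjacent, which is what later lets getAddExpr fold the duplicate pair into
/// a single (2 * %a) term.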
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI, DominatorTree &DT) {
  if (Ops.size() < 2) return;  // Noop

  EquivalenceClasses<const SCEV *> EqCacheSCEV;
  EquivalenceClasses<const Value *> EqCacheValue;

  // Whether LHS has provably less complexity than RHS.
  auto IsLessComplex = [&](const SCEV *LHS, const SCEV *RHS) {
    auto Complexity =
        CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LHS, RHS, DT);
    return Complexity && *Complexity < 0;
  };
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (IsLessComplex(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  llvm::stable_sort(Ops, [&](const SCEV *LHS, const SCEV *RHS) {
    return IsLessComplex(LHS, RHS);
  });

  // Now that we are sorted by complexity, group elements of the same
  // complexity. Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice. Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}

/// Returns true if \p Ops contains a huge SCEV (the subtree of S contains at
/// least HugeExprThreshold nodes).
static bool hasHugeExpression(ArrayRef<const SCEV *> Ops) {
  return any_of(Ops, [](const SCEV *S) {
    return S->getExpressionSize() >= HugeExprThreshold;
  });
}

//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// Compute BC(It, K). The result has width W. Assume K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       Type *ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value. We must be prepared for
  // overflow. Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W. Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula. However,
  // this formula can be implemented much more efficiently. The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic. To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse. Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T. The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits. This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate. Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway. We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
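  //
  // As a small worked example of the trick: take K = 3 and W = 8. K! = 6 =
  // 2^1 * 3, so T = 1 and the odd part K!/2^T is 3, whose multiplicative
  // inverse mod 2^8 is 171 (3 * 171 = 513 == 1 (mod 256)). For It = 5 the
  // product 5 * 4 * 3 = 60 is computed at W + T = 9 bits, shifted right by T
  // to give 30, truncated to 8 bits, and multiplied by 171:
  // 30 * 171 = 5130 == 10 (mod 256), which is indeed BC(5, 3) = 10.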

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult.lshrInPlace(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt::getOneBitSet(CalculationBits, T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// Return the value of this chain of recurrences at the specified iteration
/// number. We can evaluate this recurrence by multiplying each element in the
/// chain by the binomial coefficient corresponding to it. In other words, we
/// can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
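///
/// For example, evaluating {0,+,1,+,1} at iteration n gives
/// 0*BC(n, 0) + 1*BC(n, 1) + 1*BC(n, 2) = n + n*(n-1)/2, i.e. the running
/// sums 0, 1, 3, 6, 10, ... for n = 0, 1, 2, 3, 4.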
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getLosslessPtrToIntExpr(const SCEV *Op,
                                                     unsigned Depth) {
  assert(Depth <= 1 &&
         "getLosslessPtrToIntExpr() should self-recurse at most once.");

  // We could be called with an integer-typed operand during SCEV rewrites.
  // Since the operand is an integer already, just perform zext/trunc/self cast.
  if (!Op->getType()->isPointerTy())
    return Op;

  assert(!getDataLayout().isNonIntegralPointerType(Op->getType()) &&
         "Source pointer type must be integral for ptrtoint!");

  // What would be an ID for such a SCEV cast expression?
  FoldingSetNodeID ID;
  ID.AddInteger(scPtrToInt);
  ID.AddPointer(Op);

  void *IP = nullptr;

  // Is there already an expression for such a cast?
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
    return S;

  Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());

  // We can only model ptrtoint if SCEV's effective (integer) type
  // is sufficiently wide to represent all possible pointer values.
  if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(Op->getType())) !=
      getDataLayout().getTypeSizeInBits(IntPtrTy))
    return getCouldNotCompute();

  // If not, is this expression something we can't reduce any further?
  if (auto *U = dyn_cast<SCEVUnknown>(Op)) {
    // Perform some basic constant folding. If the operand of the ptr2int cast
    // is a null pointer, don't create a ptr2int SCEV expression (that will be
    // left as-is), but produce a zero constant.
    // NOTE: We could handle a more general case, but lack motivational cases.
    if (isa<ConstantPointerNull>(U->getValue()))
      return getZero(IntPtrTy);

    // Create an explicit cast node.
    // We can reuse the existing insert position since if we get here,
    // we won't have made any changes which would invalidate it.
    SCEV *S = new (SCEVAllocator)
        SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
    UniqueSCEVs.InsertNode(S, IP);
    addToLoopUseLists(S);
    return S;
  }

  assert(Depth == 0 && "getLosslessPtrToIntExpr() should not self-recurse for "
                       "non-SCEVUnknown's.");

  // Otherwise, we've got some expression that is more complex than just a
  // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
  // arbitrary expression; we want to have SCEVPtrToIntExpr of an SCEVUnknown
  // only, and the expressions must otherwise be integer-typed.
  // So sink the cast down to the SCEVUnknown's.
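  //
  // For example, a pointer-typed SCEV of the shape (8 + %ptr) is rewritten to
  // (8 + (ptrtoint i8* %ptr to i64)) (the types here are only illustrative),
  // so the one remaining pointer-typed node is the SCEVUnknown %ptr itself.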

  /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
  /// which computes a pointer-typed value, and rewrites the whole expression
  /// tree so that *all* the computations are done on integers, and the only
  /// pointer-typed operands in the expression are SCEVUnknown.
  class SCEVPtrToIntSinkingRewriter
      : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
    using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;

  public:
    SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}

    static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
      SCEVPtrToIntSinkingRewriter Rewriter(SE);
      return Rewriter.visit(Scev);
    }

    const SCEV *visit(const SCEV *S) {
      Type *STy = S->getType();
      // If the expression is not pointer-typed, just keep it as-is.
      if (!STy->isPointerTy())
        return S;
      // Else, recursively sink the cast down into it.
      return Base::visit(S);
    }

    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
      SmallVector<const SCEV *, 2> Operands;
      bool Changed = false;
      for (auto *Op : Expr->operands()) {
        Operands.push_back(visit(Op));
        Changed |= Op != Operands.back();
      }
      return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
    }

    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
      Type *ExprPtrTy = Expr->getType();
      assert(ExprPtrTy->isPointerTy() &&
             "Should only reach pointer-typed SCEVUnknown's.");
      return SE.getLosslessPtrToIntExpr(Expr, /*Depth=*/1);
    }
  };

  // And actually perform the cast sinking.
  const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
  assert(IntOp->getType()->isIntegerTy() &&
         "We must have succeeded in sinking the cast, "
         "and ending up with an integer-typed expression!");
  return IntOp;
}

const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty) {
  assert(Ty->isIntegerTy() && "Target type must be an integer type!");

  const SCEV *IntOp = getLosslessPtrToIntExpr(Op);
  if (isa<SCEVCouldNotCompute>(IntOp))
    return IntOp;

  return getTruncateOrZeroExtend(IntOp, Ty);
}

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
                                             unsigned Depth) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = nullptr;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
        cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty, Depth + 1);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty, Depth + 1);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty, Depth + 1);

  if (Depth > MaxCastDepth) {
    SCEV *S =
        new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty);
    UniqueSCEVs.InsertNode(S, IP);
    addToLoopUseLists(S);
    return S;
  }

  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
  // if after transforming we have at most one truncate, not counting truncates
  // that replace other casts.
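  // For example, truncating (zext(%a) + zext(%b)) from i64 back to the i32
  // width of %a and %b distributes to (trunc(zext(%a)) + trunc(zext(%b))),
  // which folds to (%a + %b): no truncates remain, so the transform is
  // clearly profitable.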
  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
    SmallVector<const SCEV *, 4> Operands;
    unsigned numTruncs = 0;
    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
         ++i) {
      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
      if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
          isa<SCEVTruncateExpr>(S))
        numTruncs++;
      Operands.push_back(S);
    }
    if (numTruncs < 2) {
      if (isa<SCEVAddExpr>(Op))
        return getAddExpr(Operands);
      else if (isa<SCEVMulExpr>(Op))
        return getMulExpr(Operands);
      else
        llvm_unreachable("Unexpected SCEV type for Op.");
    }
    // Although we checked in the beginning that ID is not in the cache, it is
    // possible that ID was inserted into the cache during the recursion above.
    // So if we find it, just return it.
    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (const SCEV *Op : AddRec->operands())
      Operands.push_back(getTruncateExpr(Op, Ty, Depth + 1));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // Return zero if truncating to known zeros.
  uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
  if (MinTrailingZeros >= getTypeSizeInBits(Ty))
    return getZero(Ty);

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  addToLoopUseLists(S);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the
// loop does not exceed this limit before incrementing.
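// For example, for an i8 recurrence whose step is known to lie in [1, 4],
// the limit is the constant 124 (SINT_MIN - 4 in wrapping arithmetic) with
// predicate slt: any value slt 124 can be incremented by at most 4 without
// signed overflow.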
static const SCEV *getSignedOverflowLimitForStep(const SCEV *Step,
                                                 ICmpInst::Predicate *Pred,
                                                 ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRangeMax(Step));
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRangeMin(Step));
  }
  return nullptr;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// unsigned overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
static const SCEV *getUnsignedOverflowLimitForStep(const SCEV *Step,
                                                   ICmpInst::Predicate *Pred,
                                                   ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  *Pred = ICmpInst::ICMP_ULT;

  return SE->getConstant(APInt::getMinValue(BitWidth) -
                         SE->getUnsignedRangeMax(Step));
}

namespace {

struct ExtendOpTraitsBase {
  typedef const SCEV *(ScalarEvolution::*GetExtendExprTy)(const SCEV *, Type *,
                                                          unsigned);
};

// Used to make code generic over signed and unsigned overflow.
template <typename ExtendOp> struct ExtendOpTraits {
  // Members present:
  //
  // static const SCEV::NoWrapFlags WrapType;
  //
  // static const ExtendOpTraitsBase::GetExtendExprTy GetExtendExpr;
  //
  // static const SCEV *getOverflowLimitForStep(const SCEV *Step,
  //                                            ICmpInst::Predicate *Pred,
  //                                            ScalarEvolution *SE);
};

template <>
struct ExtendOpTraits<SCEVSignExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNSW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getSignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVSignExtendExpr>::GetExtendExpr = &ScalarEvolution::getSignExtendExpr;

template <>
struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
  static const SCEV::NoWrapFlags WrapType = SCEV::FlagNUW;

  static const GetExtendExprTy GetExtendExpr;

  static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                             ICmpInst::Predicate *Pred,
                                             ScalarEvolution *SE) {
    return getUnsignedOverflowLimitForStep(Step, Pred, SE);
  }
};

const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
    SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;

} // end anonymous namespace

// The recurrence AR has been shown to have no signed/unsigned wrap or something
// close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
// easily prove NSW/NUW for its preincrement or postincrement sibling. This
// allows normalizing a sign/zero extended AddRec as such:
//   {sext/zext(Step + Start),+,Step} => {(Step + sext/zext(Start)),+,Step}
// As a result, the expression "Step + sext/zext(PreIncAR)" is congruent with
// "sext/zext(PostIncAR)".
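// For instance, for AR == {%x + 4,+,4} the pre-start expression is %x, and
// when the preconditions checked below hold, zext({%x + 4,+,4}) can be
// computed as {4 + zext(%x),+,4} rather than being left as an opaque zext of
// the whole recurrence.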
template <typename ExtendOpTy>
static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE, unsigned Depth) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const Loop *L = AR->getLoop();
  const SCEV *Start = AR->getStart();
  const SCEV *Step = AR->getStepRecurrence(*SE);

  // Check for a simple looking step prior to loop entry.
  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
  if (!SA)
    return nullptr;

  // Create an AddExpr for "PreStart" after subtracting Step. Full SCEV
  // subtraction is expensive. For this purpose, perform a quick and dirty
  // difference, by checking for Step in the operand list.
  SmallVector<const SCEV *, 4> DiffOps;
  for (const SCEV *Op : SA->operands())
    if (Op != Step)
      DiffOps.push_back(Op);

  if (DiffOps.size() == SA->getNumOperands())
    return nullptr;

  // Try to prove `WrapType` (SCEV::FlagNSW or SCEV::FlagNUW) on `PreStart` +
  // `Step`:

  // 1. NSW/NUW flags on the step increment.
  auto PreStartFlags =
      ScalarEvolution::maskFlags(SA->getNoWrapFlags(), SCEV::FlagNUW);
  const SCEV *PreStart = SE->getAddExpr(DiffOps, PreStartFlags);
  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
      SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));

  // "{S,+,X} is <nsw>/<nuw>" and "the backedge is taken at least once" implies
  // "S+X does not sign/unsign-overflow".
  //

  const SCEV *BECount = SE->getBackedgeTakenCount(L);
  if (PreAR && PreAR->getNoWrapFlags(WrapType) &&
      !isa<SCEVCouldNotCompute>(BECount) && SE->isKnownPositive(BECount))
    return PreStart;

  // 2. Direct overflow check on the step operation's expression.
  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
  Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
  const SCEV *OperandExtendedStart =
      SE->getAddExpr((SE->*GetExtendExpr)(PreStart, WideTy, Depth),
                     (SE->*GetExtendExpr)(Step, WideTy, Depth));
  if ((SE->*GetExtendExpr)(Start, WideTy, Depth) == OperandExtendedStart) {
    if (PreAR && AR->getNoWrapFlags(WrapType)) {
      // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
      // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
      // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
      SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
    }
    return PreStart;
  }

  // 3. Loop precondition.
  ICmpInst::Predicate Pred;
  const SCEV *OverflowLimit =
      ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(Step, &Pred, SE);

  if (OverflowLimit &&
      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit))
    return PreStart;

  return nullptr;
}

// Get the normalized zero or sign extended expression for this AddRec's Start.
template <typename ExtendOpTy>
static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
                                        ScalarEvolution *SE,
                                        unsigned Depth) {
  auto GetExtendExpr = ExtendOpTraits<ExtendOpTy>::GetExtendExpr;

  const SCEV *PreStart = getPreStartForExtend<ExtendOpTy>(AR, Ty, SE, Depth);
  if (!PreStart)
    return (SE->*GetExtendExpr)(AR->getStart(), Ty, Depth);

  return SE->getAddExpr((SE->*GetExtendExpr)(AR->getStepRecurrence(*SE), Ty,
                                             Depth),
                        (SE->*GetExtendExpr)(PreStart, Ty, Depth));
}

// Try to prove away overflow by looking at "nearby" add recurrences. A
// motivating example for this rule: if we know `{0,+,4}` is `ult` `-1` and it
// does not itself wrap then we can conclude that `{1,+,4}` is `nuw`.
//
// Formally:
//
//     {S,+,X} == {S-T,+,X} + T
//  => Ext({S,+,X}) == Ext({S-T,+,X} + T)
//
//  If ({S-T,+,X} + T) does not overflow  ... (1)
//
//  RHS == Ext({S-T,+,X} + T) == Ext({S-T,+,X}) + Ext(T)
//
//  If {S-T,+,X} does not overflow  ... (2)
//
//  RHS == Ext({S-T,+,X}) + Ext(T) == {Ext(S-T),+,Ext(X)} + Ext(T)
//      == {Ext(S-T)+Ext(T),+,Ext(X)}
//
//  If (S-T)+T does not overflow  ... (3)
//
//  RHS == {Ext(S-T)+Ext(T),+,Ext(X)} == {Ext(S-T+T),+,Ext(X)}
//      == {Ext(S),+,Ext(X)} == LHS
//
// Thus, if (1), (2) and (3) are true for some T, then
//   Ext({S,+,X}) == {Ext(S),+,Ext(X)}
//
// (3) is implied by (1) -- "(S-T)+T does not overflow" is simply "({S-T,+,X}+T)
// does not overflow" restricted to the 0th iteration. Therefore we only need
// to check for (1) and (2).
//
// In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
// is `Delta` (defined below).
template <typename ExtendOpTy>
bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                const SCEV *Step,
                                                const Loop *L) {
  auto WrapType = ExtendOpTraits<ExtendOpTy>::WrapType;

  // We restrict `Start` to a constant to prevent SCEV from spending too much
  // time here. It is correct (but more expensive) to continue with a
  // non-constant `Start` and do a general SCEV subtraction to compute
  // `PreStart` below.
  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
  if (!StartC)
    return false;

  APInt StartAI = StartC->getAPInt();

  for (unsigned Delta : {-2, -1, 1, 2}) {
    const SCEV *PreStart = getConstant(StartAI - Delta);

    FoldingSetNodeID ID;
    ID.AddInteger(scAddRecExpr);
    ID.AddPointer(PreStart);
    ID.AddPointer(Step);
    ID.AddPointer(L);
    void *IP = nullptr;
    const auto *PreAR =
        static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));

    // Give up if we don't already have the add recurrence we need because
    // actually constructing an add recurrence is relatively expensive.
    if (PreAR && PreAR->getNoWrapFlags(WrapType)) {  // proves (2)
      const SCEV *DeltaS = getConstant(StartC->getType(), Delta);
      ICmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
      const SCEV *Limit = ExtendOpTraits<ExtendOpTy>::getOverflowLimitForStep(
          DeltaS, &Pred, this);
      if (Limit && isKnownPredicate(Pred, PreAR, Limit)) // proves (1)
        return true;
    }
  }

  return false;
}

// Finds an integer D for an expression (C + x + y + ...) such that the top
// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
// the (C + x + y + ...) expression is \p WholeAddExpr.
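// For example, if C = 0b10011 and the non-constant part (x + y + ...) is known
// to have at least 4 trailing zero bits, then D is the low 4 bits of C, i.e.
// D = 0b0011, and the remainder C - D = 0b10000 preserves the trailing zeros
// of the rest of the sum.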
1535  const SCEVConstant *ConstantTerm,
1536  const SCEVAddExpr *WholeAddExpr) {
1537  const APInt &C = ConstantTerm->getAPInt();
1538  const unsigned BitWidth = C.getBitWidth();
1539  // Find number of trailing zeros of (x + y + ...) w/o the C first:
1540  uint32_t TZ = BitWidth;
1541  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
1542  TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
1543  if (TZ) {
1544  // Set D to be as many least significant bits of C as possible while still
1545  // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
1546  return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
1547  }
1548  return APInt(BitWidth, 0);
1549 }
1550 
1551 // Finds an integer D for an affine AddRec expression {C,+,x} such that the top
1552 // level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
1553 // number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
1554 // ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
1556  const APInt &ConstantStart,
1557  const SCEV *Step) {
1558  const unsigned BitWidth = ConstantStart.getBitWidth();
1559  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
1560  if (TZ)
1561  return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
1562  : ConstantStart;
1563  return APInt(BitWidth, 0);
1564 }
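// A minimal standalone sketch (plain C++ over 8-bit values chosen for
// illustration, not the APInt API) of why the split computed above cannot
// wrap: C - D and every other operand are multiples of 2^TZ, so their sum is
// at most 256 - 2^TZ as an unsigned 8-bit value, and adding D < 2^TZ back on
// top of it cannot overflow.
//
//   #include <cassert>
//   #include <cstdint>
//
//   int main() {
//     const uint8_t C = 181;                   // 0b10110101, the constant term
//     const unsigned TZ = 4;                   // trailing zeros shared by x, y, ...
//     const uint8_t D = C & ((1u << TZ) - 1);  // low TZ bits of C, here 5
//     const uint8_t Rest = 240;                // largest 8-bit multiple of 2^TZ
//     assert(uint8_t(Rest + D) >= Rest);       // adding D back does not wrap
//     return 0;
//   }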
1565 
1566 const SCEV *
1567 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1568  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1569  "This is not an extending conversion!");
1570  assert(isSCEVable(Ty) &&
1571  "This is not a conversion to a SCEVable type!");
1572  Ty = getEffectiveSCEVType(Ty);
1573 
1574  // Fold if the operand is constant.
1575  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1576  return getConstant(
1577  cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(), Ty)));
1578 
1579  // zext(zext(x)) --> zext(x)
1580  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1581  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1582 
1583  // Before doing any expensive analysis, check to see if we've already
1584  // computed a SCEV for this Op and Ty.
1585  FoldingSetNodeID ID;
1586  ID.AddInteger(scZeroExtend);
1587  ID.AddPointer(Op);
1588  ID.AddPointer(Ty);
1589  void *IP = nullptr;
1590  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1591  if (Depth > MaxCastDepth) {
1592  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1593  Op, Ty);
1594  UniqueSCEVs.InsertNode(S, IP);
1595  addToLoopUseLists(S);
1596  return S;
1597  }
1598 
1599  // zext(trunc(x)) --> zext(x) or x or trunc(x)
1600  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1601  // It's possible the bits taken off by the truncate were all zero bits. If
1602  // so, we should be able to simplify this further.
1603  const SCEV *X = ST->getOperand();
1604  ConstantRange CR = getUnsignedRange(X);
1605  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1606  unsigned NewBits = getTypeSizeInBits(Ty);
1607  if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
1608  CR.zextOrTrunc(NewBits)))
1609  return getTruncateOrZeroExtend(X, Ty, Depth);
1610  }
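// For instance (illustrative ranges): if the value feeding the truncate is an
// i32 whose unsigned range is known to be [0, 200), then
// (zext (trunc X to i8) to i32) folds back to the original i32 value, because
// truncating to 8 bits and zero-extending again loses nothing in that range.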
1611 
1612  // If the input value is a chrec scev, and we can prove that the value
1613  // did not overflow the old, smaller, value, we can zero extend all of the
1614  // operands (often constants). This allows analysis of something like
1615  // this: for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
1616  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1617  if (AR->isAffine()) {
1618  const SCEV *Start = AR->getStart();
1619  const SCEV *Step = AR->getStepRecurrence(*this);
1620  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1621  const Loop *L = AR->getLoop();
1622 
1623  if (!AR->hasNoUnsignedWrap()) {
1624  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1625  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1626  }
1627 
1628  // If we have special knowledge that this addrec won't overflow,
1629  // we don't need to do any further analysis.
1630  if (AR->hasNoUnsignedWrap())
1631  return getAddRecExpr(
1632  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1633  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1634 
1635  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1636  // Note that this serves two purposes: It filters out loops that are
1637  // simply not analyzable, and it covers the case where this code is
1638  // being called from within backedge-taken count analysis, such that
1639  // attempting to ask for the backedge-taken count would likely result
1640  // in infinite recursion. In the latter case, the analysis code will
1641  // cope with a conservative value, and it will take care to purge
1642  // that value once it has finished.
1643  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1644  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1645  // Manually compute the final value for AR, checking for overflow.
1646 
1647  // Check whether the backedge-taken count can be losslessly casted to
1648  // the addrec's type. The count is always unsigned.
1649  const SCEV *CastedMaxBECount =
1650  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1651  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1652  CastedMaxBECount, MaxBECount->getType(), Depth);
1653  if (MaxBECount == RecastedMaxBECount) {
1654  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1655  // Check whether Start+Step*MaxBECount has no unsigned overflow.
1656  const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step,
1657  SCEV::FlagAnyWrap, Depth + 1);
1658  const SCEV *ZAdd = getZeroExtendExpr(getAddExpr(Start, ZMul,
1659  SCEV::FlagAnyWrap,
1660  Depth + 1),
1661  WideTy, Depth + 1);
1662  const SCEV *WideStart = getZeroExtendExpr(Start, WideTy, Depth + 1);
1663  const SCEV *WideMaxBECount =
1664  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
1665  const SCEV *OperandExtendedAdd =
1666  getAddExpr(WideStart,
1667  getMulExpr(WideMaxBECount,
1668  getZeroExtendExpr(Step, WideTy, Depth + 1),
1669  SCEV::FlagAnyWrap, Depth + 1),
1670  SCEV::FlagAnyWrap, Depth + 1);
1671  if (ZAdd == OperandExtendedAdd) {
1672  // Cache knowledge of AR NUW, which is propagated to this AddRec.
1673  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1674  // Return the expression with the addrec on the outside.
1675  return getAddRecExpr(
1676  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1677  Depth + 1),
1678  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1679  AR->getNoWrapFlags());
1680  }
1681  // Similar to above, only this time treat the step value as signed.
1682  // This covers loops that count down.
1683  OperandExtendedAdd =
1684  getAddExpr(WideStart,
1685  getMulExpr(WideMaxBECount,
1686  getSignExtendExpr(Step, WideTy, Depth + 1),
1687  SCEV::FlagAnyWrap, Depth + 1),
1688  SCEV::FlagAnyWrap, Depth + 1);
1689  if (ZAdd == OperandExtendedAdd) {
1690  // Cache knowledge of AR NW, which is propagated to this AddRec.
1691  // Negative step causes unsigned wrap, but it still can't self-wrap.
1692  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1693  // Return the expression with the addrec on the outside.
1694  return getAddRecExpr(
1695  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1696  Depth + 1),
1697  getSignExtendExpr(Step, Ty, Depth + 1), L,
1698  AR->getNoWrapFlags());
1699  }
1700  }
1701  }
1702 
1703  // Normally, in the cases we can prove no-overflow via a
1704  // backedge guarding condition, we can also compute a backedge
1705  // taken count for the loop. The exceptions are assumptions and
1706  // guards present in the loop -- SCEV is not great at exploiting
1707  // these to compute max backedge taken counts, but can still use
1708  // these to prove lack of overflow. Use this fact to avoid
1709  // doing extra work that may not pay off.
1710  if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
1711  !AC.assumptions().empty()) {
1712 
1713  auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
1714  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1715  if (AR->hasNoUnsignedWrap()) {
1716  // Same as nuw case above - duplicated here to avoid a compile time
1717  // issue. It's not clear that the order of checks matters, but it's
1718  // one of two possible causes for a change which was
1719  // reverted. Be conservative for the moment.
1720  return getAddRecExpr(
1721  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1722  Depth + 1),
1723  getZeroExtendExpr(Step, Ty, Depth + 1), L,
1724  AR->getNoWrapFlags());
1725  }
1726 
1727  // For a negative step, we can extend the operands iff doing so only
1728  // traverses values in the range zext([0,UINT_MAX]).
1729  if (isKnownNegative(Step)) {
1730  const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
1731  getSignedRangeMin(Step));
1732  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
1733  isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
1734  // Cache knowledge of AR NW, which is propagated to this
1735  // AddRec. Negative step causes unsigned wrap, but it
1736  // still can't self-wrap.
1737  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
1738  // Return the expression with the addrec on the outside.
1739  return getAddRecExpr(
1740  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
1741  Depth + 1),
1742  getSignExtendExpr(Step, Ty, Depth + 1), L,
1743  AR->getNoWrapFlags());
1744  }
1745  }
1746  }
1747 
1748  // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
1749  // if D + (C - D + Step * n) could be proven to not unsigned wrap
1750  // where D maximizes the number of trailing zeros of (C - D + Step * n)
1751  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
1752  const APInt &C = SC->getAPInt();
1753  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
1754  if (D != 0) {
1755  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1756  const SCEV *SResidual =
1757  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
1758  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1759  return getAddExpr(SZExtD, SZExtR,
1760  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1761  Depth + 1);
1762  }
1763  }
1764 
1765  if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
1766  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
1767  return getAddRecExpr(
1768  getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
1769  getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
1770  }
1771  }
1772 
1773  // zext(A % B) --> zext(A) % zext(B)
1774  {
1775  const SCEV *LHS;
1776  const SCEV *RHS;
1777  if (matchURem(Op, LHS, RHS))
1778  return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
1779  getZeroExtendExpr(RHS, Ty, Depth + 1));
1780  }
1781 
1782  // zext(A / B) --> zext(A) / zext(B).
1783  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
1784  return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
1785  getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
1786 
1787  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1788  // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
1789  if (SA->hasNoUnsignedWrap()) {
1790  // If the addition does not unsign overflow then we can, by definition,
1791  // commute the zero extension with the addition operation.
1792  SmallVector<const SCEV *, 4> Ops;
1793  for (const auto *Op : SA->operands())
1794  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1795  return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
1796  }
1797 
1798  // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
1799  // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
1800  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1801  //
1802  // Often address arithmetics contain expressions like
1803  // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
1804  // This transformation is useful while proving that such expressions are
1805  // equal or differ by a small constant amount, see LoadStoreVectorizer pass.
1806  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1807  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1808  if (D != 0) {
1809  const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
1810  const SCEV *SResidual =
1811  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1812  const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
1813  return getAddExpr(SZExtD, SZExtR,
1814  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1815  Depth + 1);
1816  }
1817  }
1818  }
1819 
1820  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
1821  // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
1822  if (SM->hasNoUnsignedWrap()) {
1823  // If the multiply does not unsign overflow then we can, by definition,
1824  // commute the zero extension with the multiply operation.
1825  SmallVector<const SCEV *, 4> Ops;
1826  for (const auto *Op : SM->operands())
1827  Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
1828  return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
1829  }
1830 
1831  // zext(2^K * (trunc X to iN)) to iM ->
1832  // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
1833  //
1834  // Proof:
1835  //
1836  // zext(2^K * (trunc X to iN)) to iM
1837  // = zext((trunc X to iN) << K) to iM
1838  // = zext((trunc X to i{N-K}) << K)<nuw> to iM
1839  // (because shl removes the top K bits)
1840  // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
1841  // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
1842  //
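//  For instance (an illustrative instantiation with assumed widths K = 2,
//  N = 8, M = 16):
//    zext(4 * (trunc X to i8)) to i16
//    = (4 * (zext (trunc X to i6) to i16))<nuw>.
//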
1843  if (SM->getNumOperands() == 2)
1844  if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
1845  if (MulLHS->getAPInt().isPowerOf2())
1846  if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
1847  int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
1848  MulLHS->getAPInt().logBase2();
1849  Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
1850  return getMulExpr(
1851  getZeroExtendExpr(MulLHS, Ty),
1852  getZeroExtendExpr(
1853  getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
1854  SCEV::FlagNUW, Depth + 1);
1855  }
1856  }
1857 
1858  // The cast wasn't folded; create an explicit cast node.
1859  // Recompute the insert position, as it may have been invalidated.
1860  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1861  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
1862  Op, Ty);
1863  UniqueSCEVs.InsertNode(S, IP);
1864  addToLoopUseLists(S);
1865  return S;
1866 }
1867 
1868 const SCEV *
1869 ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
1870  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1871  "This is not an extending conversion!");
1872  assert(isSCEVable(Ty) &&
1873  "This is not a conversion to a SCEVable type!");
1874  Ty = getEffectiveSCEVType(Ty);
1875 
1876  // Fold if the operand is constant.
1877  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1878  return getConstant(
1879  cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(), Ty)));
1880 
1881  // sext(sext(x)) --> sext(x)
1882  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1883  return getSignExtendExpr(SS->getOperand(), Ty, Depth + 1);
1884 
1885  // sext(zext(x)) --> zext(x)
1886  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1887  return getZeroExtendExpr(SZ->getOperand(), Ty, Depth + 1);
1888 
1889  // Before doing any expensive analysis, check to see if we've already
1890  // computed a SCEV for this Op and Ty.
1891  FoldingSetNodeID ID;
1892  ID.AddInteger(scSignExtend);
1893  ID.AddPointer(Op);
1894  ID.AddPointer(Ty);
1895  void *IP = nullptr;
1896  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1897  // Limit recursion depth.
1898  if (Depth > MaxCastDepth) {
1899  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1900  Op, Ty);
1901  UniqueSCEVs.InsertNode(S, IP);
1902  addToLoopUseLists(S);
1903  return S;
1904  }
1905 
1906  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1907  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1908  // It's possible the bits taken off by the truncate were all sign bits. If
1909  // so, we should be able to simplify this further.
1910  const SCEV *X = ST->getOperand();
1911  ConstantRange CR = getSignedRange(X);
1912  unsigned TruncBits = getTypeSizeInBits(ST->getType());
1913  unsigned NewBits = getTypeSizeInBits(Ty);
1914  if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1915  CR.sextOrTrunc(NewBits)))
1916  return getTruncateOrSignExtend(X, Ty, Depth);
1917  }
1918 
1919  if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
1920  // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
1921  if (SA->hasNoSignedWrap()) {
1922  // If the addition does not sign overflow then we can, by definition,
1923  // commute the sign extension with the addition operation.
1924  SmallVector<const SCEV *, 4> Ops;
1925  for (const auto *Op : SA->operands())
1926  Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
1927  return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
1928  }
1929 
1930  // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
1931  // if D + (C - D + x + y + ...) could be proven to not signed wrap
1932  // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
1933  //
1934  // For instance, this will bring two seemingly different expressions:
1935  // 1 + sext(5 + 20 * %x + 24 * %y) and
1936  // sext(6 + 20 * %x + 24 * %y)
1937  // to the same form:
1938  // 2 + sext(4 + 20 * %x + 24 * %y)
1939  if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
1940  const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
1941  if (D != 0) {
1942  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
1943  const SCEV *SResidual =
1944  getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth + 1);
1945  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
1946  return getAddExpr(SSExtD, SSExtR,
1947  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
1948  Depth + 1);
1949  }
1950  }
1951  }
1952  // If the input value is a chrec scev, and we can prove that the value
1953  // did not overflow the old, smaller, value, we can sign extend all of the
1954  // operands (often constants). This allows analysis of something like
1955  // this: for (signed char X = 0; X < 100; ++X) { int Y = X; }
1956  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1957  if (AR->isAffine()) {
1958  const SCEV *Start = AR->getStart();
1959  const SCEV *Step = AR->getStepRecurrence(*this);
1960  unsigned BitWidth = getTypeSizeInBits(AR->getType());
1961  const Loop *L = AR->getLoop();
1962 
1963  if (!AR->hasNoSignedWrap()) {
1964  auto NewFlags = proveNoWrapViaConstantRanges(AR);
1965  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
1966  }
1967 
1968  // If we have special knowledge that this addrec won't overflow,
1969  // we don't need to do any further analysis.
1970  if (AR->hasNoSignedWrap())
1971  return getAddRecExpr(
1972  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
1973  getSignExtendExpr(Step, Ty, Depth + 1), L, SCEV::FlagNSW);
1974 
1975  // Check whether the backedge-taken count is SCEVCouldNotCompute.
1976  // Note that this serves two purposes: It filters out loops that are
1977  // simply not analyzable, and it covers the case where this code is
1978  // being called from within backedge-taken count analysis, such that
1979  // attempting to ask for the backedge-taken count would likely result
1980  // in infinite recursion. In the latter case, the analysis code will
1981  // cope with a conservative value, and it will take care to purge
1982  // that value once it has finished.
1983  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
1984  if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1985  // Manually compute the final value for AR, checking for
1986  // overflow.
1987 
1988  // Check whether the backedge-taken count can be losslessly casted to
1989  // the addrec's type. The count is always unsigned.
1990  const SCEV *CastedMaxBECount =
1991  getTruncateOrZeroExtend(MaxBECount, Start->getType(), Depth);
1992  const SCEV *RecastedMaxBECount = getTruncateOrZeroExtend(
1993  CastedMaxBECount, MaxBECount->getType(), Depth);
1994  if (MaxBECount == RecastedMaxBECount) {
1995  Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1996  // Check whether Start+Step*MaxBECount has no signed overflow.
1997  const SCEV *SMul = getMulExpr(CastedMaxBECount, Step,
1998  SCEV::FlagAnyWrap, Depth + 1);
1999  const SCEV *SAdd = getSignExtendExpr(getAddExpr(Start, SMul,
2000  SCEV::FlagAnyWrap,
2001  Depth + 1),
2002  WideTy, Depth + 1);
2003  const SCEV *WideStart = getSignExtendExpr(Start, WideTy, Depth + 1);
2004  const SCEV *WideMaxBECount =
2005  getZeroExtendExpr(CastedMaxBECount, WideTy, Depth + 1);
2006  const SCEV *OperandExtendedAdd =
2007  getAddExpr(WideStart,
2008  getMulExpr(WideMaxBECount,
2009  getSignExtendExpr(Step, WideTy, Depth + 1),
2010  SCEV::FlagAnyWrap, Depth + 1),
2011  SCEV::FlagAnyWrap, Depth + 1);
2012  if (SAdd == OperandExtendedAdd) {
2013  // Cache knowledge of AR NSW, which is propagated to this AddRec.
2014  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2015  // Return the expression with the addrec on the outside.
2016  return getAddRecExpr(
2017  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2018  Depth + 1),
2019  getSignExtendExpr(Step, Ty, Depth + 1), L,
2020  AR->getNoWrapFlags());
2021  }
2022  // Similar to above, only this time treat the step value as unsigned.
2023  // This covers loops that count up with an unsigned step.
2024  OperandExtendedAdd =
2025  getAddExpr(WideStart,
2026  getMulExpr(WideMaxBECount,
2027  getZeroExtendExpr(Step, WideTy, Depth + 1),
2028  SCEV::FlagAnyWrap, Depth + 1),
2029  SCEV::FlagAnyWrap, Depth + 1);
2030  if (SAdd == OperandExtendedAdd) {
2031  // If AR wraps around then
2032  //
2033  // abs(Step) * MaxBECount > unsigned-max(AR->getType())
2034  // => SAdd != OperandExtendedAdd
2035  //
2036  // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
2037  // (SAdd == OperandExtendedAdd => AR is NW)
2038 
2039  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
2040 
2041  // Return the expression with the addrec on the outside.
2042  return getAddRecExpr(
2043  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
2044  Depth + 1),
2045  getZeroExtendExpr(Step, Ty, Depth + 1), L,
2046  AR->getNoWrapFlags());
2047  }
2048  }
2049  }
2050 
2051  auto NewFlags = proveNoSignedWrapViaInduction(AR);
2052  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
2053  if (AR->hasNoSignedWrap()) {
2054  // Same as nsw case above - duplicated here to avoid a compile time
2055  // issue. It's not clear that the order of checks matters, but it's
2056  // one of two possible causes for a change which was
2057  // reverted. Be conservative for the moment.
2058  return getAddRecExpr(
2059  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2060  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2061  }
2062 
2063  // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
2064  // if D + (C - D + Step * n) could be proven to not signed wrap
2065  // where D maximizes the number of trailing zeros of (C - D + Step * n)
2066  if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
2067  const APInt &C = SC->getAPInt();
2068  const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
2069  if (D != 0) {
2070  const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
2071  const SCEV *SResidual =
2072  getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
2073  const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
2074  return getAddExpr(SSExtD, SSExtR,
2075  (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
2076  Depth + 1);
2077  }
2078  }
2079 
2080  if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
2081  setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
2082  return getAddRecExpr(
2083  getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
2084  getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
2085  }
2086  }
2087 
2088  // If the input value is provably positive and we could not simplify
2089  // away the sext build a zext instead.
2090  if (isKnownNonNegative(Op))
2091  return getZeroExtendExpr(Op, Ty, Depth + 1);
2092 
2093  // The cast wasn't folded; create an explicit cast node.
2094  // Recompute the insert position, as it may have been invalidated.
2095  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2096  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
2097  Op, Ty);
2098  UniqueSCEVs.InsertNode(S, IP);
2099  addToLoopUseLists(S);
2100  return S;
2101 }
2102 
2103 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
2104 /// unspecified bits out to the given type.
2105 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
2106  Type *Ty) {
2107  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
2108  "This is not an extending conversion!");
2109  assert(isSCEVable(Ty) &&
2110  "This is not a conversion to a SCEVable type!");
2111  Ty = getEffectiveSCEVType(Ty);
2112 
2113  // Sign-extend negative constants.
2114  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
2115  if (SC->getAPInt().isNegative())
2116  return getSignExtendExpr(Op, Ty);
2117 
2118  // Peel off a truncate cast.
2119  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
2120  const SCEV *NewOp = T->getOperand();
2121  if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
2122  return getAnyExtendExpr(NewOp, Ty);
2123  return getTruncateOrNoop(NewOp, Ty);
2124  }
2125 
2126  // Next try a zext cast. If the cast is folded, use it.
2127  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
2128  if (!isa<SCEVZeroExtendExpr>(ZExt))
2129  return ZExt;
2130 
2131  // Next try a sext cast. If the cast is folded, use it.
2132  const SCEV *SExt = getSignExtendExpr(Op, Ty);
2133  if (!isa<SCEVSignExtendExpr>(SExt))
2134  return SExt;
2135 
2136  // Force the cast to be folded into the operands of an addrec.
2137  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
2138  SmallVector<const SCEV *, 4> Ops;
2139  for (const SCEV *Op : AR->operands())
2140  Ops.push_back(getAnyExtendExpr(Op, Ty));
2141  return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
2142  }
2143 
2144  // If the expression is obviously signed, use the sext cast value.
2145  if (isa<SCEVSMaxExpr>(Op))
2146  return SExt;
2147 
2148  // Absent any other information, use the zext cast value.
2149  return ZExt;
2150 }
2151 
2152 /// Process the given Ops list, which is a list of operands to be added under
2153 /// the given scale, update the given map. This is a helper function for
2154 /// getAddRecExpr. As an example of what it does, given a sequence of operands
2155 /// that would form an add expression like this:
2156 ///
2157 /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r)
2158 ///
2159 /// where A and B are constants, update the map with these values:
2160 ///
2161 /// (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
2162 ///
2163 /// and add 13 + A*B*29 to AccumulatedConstant.
2164 /// This will allow getAddRecExpr to produce this:
2165 ///
2166 /// 13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
2167 ///
2168 /// This form often exposes folding opportunities that are hidden in
2169 /// the original operand list.
2170 ///
2171 /// Return true iff it appears that any interesting folding opportunities
2172 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
2173 /// the common case where no interesting opportunities are present, and
2174 /// is also used as a check to avoid infinite recursion.
2175 static bool
2176 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
2177  SmallVectorImpl<const SCEV *> &NewOps,
2178  APInt &AccumulatedConstant,
2179  const SCEV *const *Ops, size_t NumOperands,
2180  const APInt &Scale,
2181  ScalarEvolution &SE) {
2182  bool Interesting = false;
2183 
2184  // Iterate over the add operands. They are sorted, with constants first.
2185  unsigned i = 0;
2186  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2187  ++i;
2188  // Pull a buried constant out to the outside.
2189  if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
2190  Interesting = true;
2191  AccumulatedConstant += Scale * C->getAPInt();
2192  }
2193 
2194  // Next comes everything else. We're especially interested in multiplies
2195  // here, but they're in the middle, so just visit the rest with one loop.
2196  for (; i != NumOperands; ++i) {
2197  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
2198  if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
2199  APInt NewScale =
2200  Scale * cast<SCEVConstant>(Mul->getOperand(0))->getAPInt();
2201  if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
2202  // A multiplication of a constant with another add; recurse.
2203  const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
2204  Interesting |=
2205  CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2206  Add->op_begin(), Add->getNumOperands(),
2207  NewScale, SE);
2208  } else {
2209  // A multiplication of a constant with some other value. Update
2210  // the map.
2211  SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
2212  const SCEV *Key = SE.getMulExpr(MulOps);
2213  auto Pair = M.insert({Key, NewScale});
2214  if (Pair.second) {
2215  NewOps.push_back(Pair.first->first);
2216  } else {
2217  Pair.first->second += NewScale;
2218  // The map already had an entry for this value, which may indicate
2219  // a folding opportunity.
2220  Interesting = true;
2221  }
2222  }
2223  } else {
2224  // An ordinary operand. Update the map.
2225  std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
2226  M.insert({Ops[i], Scale});
2227  if (Pair.second) {
2228  NewOps.push_back(Pair.first->first);
2229  } else {
2230  Pair.first->second += Scale;
2231  // The map already had an entry for this value, which may indicate
2232  // a folding opportunity.
2233  Interesting = true;
2234  }
2235  }
2236  }
2237 
2238  return Interesting;
2239 }
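// A tiny standalone sketch (plain C++ over named symbols, illustrative values,
// not the SCEV API) of the same bookkeeping applied to
//   m + 13 + (2 * (o + (3 * (q + m + 5)))):
// constants are folded into a single accumulator and each distinct term keeps
// a running scale, which exposes that m occurs with total scale 1 + 2*3 = 7.
//
//   #include <cassert>
//   #include <map>
//   #include <string>
//
//   int main() {
//     std::map<std::string, int> Scale;  // term -> accumulated scale
//     int AccumulatedConstant = 0;
//     AccumulatedConstant += 13;         // top-level constant, scale 1
//     Scale["m"] += 1;                   // top-level m
//     Scale["o"] += 2;                   // from 2 * (o + ...)
//     Scale["q"] += 2 * 3;               // from 2 * 3 * (q + ...)
//     Scale["m"] += 2 * 3;               // buried m -> folding opportunity
//     AccumulatedConstant += 2 * 3 * 5;  // buried constant, scaled by 2*3
//     assert(AccumulatedConstant == 43 && Scale["m"] == 7);
//     return 0;
//   }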
2240 
2241 // We're trying to construct a SCEV of type `Type' with `Ops' as operands and
2242 // `OldFlags' as can't-wrap behavior. Infer a more aggressive set of
2243 // can't-overflow flags for the operation if possible.
2244 static SCEV::NoWrapFlags
2245 StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
2246  const ArrayRef<const SCEV *> Ops,
2247  SCEV::NoWrapFlags Flags) {
2248  using namespace std::placeholders;
2249 
2250  using OBO = OverflowingBinaryOperator;
2251 
2252  bool CanAnalyze =
2253  Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
2254  (void)CanAnalyze;
2255  assert(CanAnalyze && "don't call from other places!");
2256 
2257  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2258  SCEV::NoWrapFlags SignOrUnsignWrap =
2259  ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2260 
2261  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2262  auto IsKnownNonNegative = [&](const SCEV *S) {
2263  return SE->isKnownNonNegative(S);
2264  };
2265 
2266  if (SignOrUnsignWrap == SCEV::FlagNSW && all_of(Ops, IsKnownNonNegative))
2267  Flags =
2268  ScalarEvolution::setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2269 
2270  SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
2271 
2272  if (SignOrUnsignWrap != SignOrUnsignMask &&
2273  (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
2274  isa<SCEVConstant>(Ops[0])) {
2275 
2276  auto Opcode = [&] {
2277  switch (Type) {
2278  case scAddExpr:
2279  return Instruction::Add;
2280  case scMulExpr:
2281  return Instruction::Mul;
2282  default:
2283  llvm_unreachable("Unexpected SCEV op.");
2284  }
2285  }();
2286 
2287  const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
2288 
2289  // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
2290  if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
2291  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2292  Opcode, C, OBO::NoSignedWrap);
2293  if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
2294  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
2295  }
2296 
2297  // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
2298  if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
2299  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
2300  Opcode, C, OBO::NoUnsignedWrap);
2301  if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
2302  Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
2303  }
2304  }
2305 
2306  return Flags;
2307 }
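// For instance (an illustrative i8 case with assumed ranges): for an add of
// the form 100 + A, the guaranteed no-signed-wrap region for the constant 100
// is [-128, 28), so if the signed range of A is known to lie inside it, say
// [0, 20), the addition can be given FlagNSW even though the caller did not
// assert it.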
2308 
2309 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
2310  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
2311 }
2312 
2313 /// Get a canonical add expression, or something simpler if possible.
2314 const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
2315  SCEV::NoWrapFlags OrigFlags,
2316  unsigned Depth) {
2317  assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
2318  "only nuw or nsw allowed");
2319  assert(!Ops.empty() && "Cannot get empty add!");
2320  if (Ops.size() == 1) return Ops[0];
2321 #ifndef NDEBUG
2322  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2323  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2324  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2325  "SCEVAddExpr operand types don't match!");
2326 #endif
2327 
2328  // Sort by complexity, this groups all similar expression types together.
2329  GroupByComplexity(Ops, &LI, DT);
2330 
2331  // If there are any constants, fold them together.
2332  unsigned Idx = 0;
2333  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2334  ++Idx;
2335  assert(Idx < Ops.size());
2336  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2337  // We found two constants, fold them together!
2338  Ops[0] = getConstant(LHSC->getAPInt() + RHSC->getAPInt());
2339  if (Ops.size() == 2) return Ops[0];
2340  Ops.erase(Ops.begin()+1); // Erase the folded element
2341  LHSC = cast<SCEVConstant>(Ops[0]);
2342  }
2343 
2344  // If we are left with a constant zero being added, strip it off.
2345  if (LHSC->getValue()->isZero()) {
2346  Ops.erase(Ops.begin());
2347  --Idx;
2348  }
2349 
2350  if (Ops.size() == 1) return Ops[0];
2351  }
2352 
2353  // Delay expensive flag strengthening until necessary.
2354  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2355  return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
2356  };
2357 
2358  // Limit recursion calls depth.
2359  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2360  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2361 
2362  if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) {
2363  // Don't strengthen flags if we have no new information.
2364  SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
2365  if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
2366  Add->setNoWrapFlags(ComputeFlags(Ops));
2367  return S;
2368  }
2369 
2370  // Okay, check to see if the same value occurs in the operand list more than
2371  // once. If so, merge them together into an multiply expression. Since we
2372  // sorted the list, these values are required to be adjacent.
2373  Type *Ty = Ops[0]->getType();
2374  bool FoundMatch = false;
2375  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
2376  if (Ops[i] == Ops[i+1]) { // X + Y + Y --> X + Y*2
2377  // Scan ahead to count how many equal operands there are.
2378  unsigned Count = 2;
2379  while (i+Count != e && Ops[i+Count] == Ops[i])
2380  ++Count;
2381  // Merge the values into a multiply.
2382  const SCEV *Scale = getConstant(Ty, Count);
2383  const SCEV *Mul = getMulExpr(Scale, Ops[i], SCEV::FlagAnyWrap, Depth + 1);
2384  if (Ops.size() == Count)
2385  return Mul;
2386  Ops[i] = Mul;
2387  Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
2388  --i; e -= Count - 1;
2389  FoundMatch = true;
2390  }
2391  if (FoundMatch)
2392  return getAddExpr(Ops, OrigFlags, Depth + 1);
2393 
2394  // Check for truncates. If all the operands are truncated from the same
2395  // type, see if factoring out the truncate would permit the result to be
2396  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
2397  // if the contents of the resulting outer trunc fold to something simple.
2398  auto FindTruncSrcType = [&]() -> Type * {
2399  // We're ultimately looking to fold an addrec of truncs and muls of only
2400  // constants and truncs, so if we find any other types of SCEV
2401  // as operands of the addrec then we bail and return nullptr here.
2402  // Otherwise, we return the type of the operand of a trunc that we find.
2403  if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
2404  return T->getOperand()->getType();
2405  if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2406  const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
2407  if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
2408  return T->getOperand()->getType();
2409  }
2410  return nullptr;
2411  };
2412  if (auto *SrcType = FindTruncSrcType()) {
2413  SmallVector<const SCEV *, 8> LargeOps;
2414  bool Ok = true;
2415  // Check all the operands to see if they can be represented in the
2416  // source type of the truncate.
2417  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
2418  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
2419  if (T->getOperand()->getType() != SrcType) {
2420  Ok = false;
2421  break;
2422  }
2423  LargeOps.push_back(T->getOperand());
2424  } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
2425  LargeOps.push_back(getAnyExtendExpr(C, SrcType));
2426  } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
2427  SmallVector<const SCEV *, 8> LargeMulOps;
2428  for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
2429  if (const SCEVTruncateExpr *T =
2430  dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
2431  if (T->getOperand()->getType() != SrcType) {
2432  Ok = false;
2433  break;
2434  }
2435  LargeMulOps.push_back(T->getOperand());
2436  } else if (const auto *C = dyn_cast<SCEVConstant>(M->getOperand(j))) {
2437  LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
2438  } else {
2439  Ok = false;
2440  break;
2441  }
2442  }
2443  if (Ok)
2444  LargeOps.push_back(getMulExpr(LargeMulOps, SCEV::FlagAnyWrap, Depth + 1));
2445  } else {
2446  Ok = false;
2447  break;
2448  }
2449  }
2450  if (Ok) {
2451  // Evaluate the expression in the larger type.
2452  const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
2453  // If it folds to something simple, use it. Otherwise, don't.
2454  if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
2455  return getTruncateExpr(Fold, Ty);
2456  }
2457  }
2458 
2459  // Skip past any other cast SCEVs.
2460  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
2461  ++Idx;
2462 
2463  // If there are add operands they would be next.
2464  if (Idx < Ops.size()) {
2465  bool DeletedAdd = false;
2466  while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
2467  if (Ops.size() > AddOpsInlineThreshold ||
2468  Add->getNumOperands() > AddOpsInlineThreshold)
2469  break;
2470  // If we have an add, expand the add operands onto the end of the operands
2471  // list.
2472  Ops.erase(Ops.begin()+Idx);
2473  Ops.append(Add->op_begin(), Add->op_end());
2474  DeletedAdd = true;
2475  }
2476 
2477  // If we deleted at least one add, we added operands to the end of the list,
2478  // and they are not necessarily sorted. Recurse to resort and resimplify
2479  // any operands we just acquired.
2480  if (DeletedAdd)
2481  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2482  }
2483 
2484  // Skip over the add expression until we get to a multiply.
2485  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2486  ++Idx;
2487 
2488  // Check to see if there are any folding opportunities present with
2489  // operands multiplied by constant values.
2490  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
2491  uint64_t BitWidth = getTypeSizeInBits(Ty);
2492  DenseMap<const SCEV *, APInt> M;
2493  SmallVector<const SCEV *, 8> NewOps;
2494  APInt AccumulatedConstant(BitWidth, 0);
2495  if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
2496  Ops.data(), Ops.size(),
2497  APInt(BitWidth, 1), *this)) {
2498  struct APIntCompare {
2499  bool operator()(const APInt &LHS, const APInt &RHS) const {
2500  return LHS.ult(RHS);
2501  }
2502  };
2503 
2504  // Some interesting folding opportunity is present, so it's worthwhile to
2505  // re-generate the operands list. Group the operands by constant scale,
2506  // to avoid multiplying by the same constant scale multiple times.
2507  std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
2508  for (const SCEV *NewOp : NewOps)
2509  MulOpLists[M.find(NewOp)->second].push_back(NewOp);
2510  // Re-generate the operands list.
2511  Ops.clear();
2512  if (AccumulatedConstant != 0)
2513  Ops.push_back(getConstant(AccumulatedConstant));
2514  for (auto &MulOp : MulOpLists)
2515  if (MulOp.first != 0)
2516  Ops.push_back(getMulExpr(
2517  getConstant(MulOp.first),
2518  getAddExpr(MulOp.second, SCEV::FlagAnyWrap, Depth + 1),
2519  SCEV::FlagAnyWrap, Depth + 1));
2520  if (Ops.empty())
2521  return getZero(Ty);
2522  if (Ops.size() == 1)
2523  return Ops[0];
2524  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2525  }
2526  }
2527 
2528  // If we are adding something to a multiply expression, make sure the
2529  // something is not already an operand of the multiply. If so, merge it into
2530  // the multiply.
2531  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
2532  const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
2533  for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
2534  const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
2535  if (isa<SCEVConstant>(MulOpSCEV))
2536  continue;
2537  for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
2538  if (MulOpSCEV == Ops[AddOp]) {
2539  // Fold W + X + (X * Y * Z) --> W + (X * ((Y*Z)+1))
2540  const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
2541  if (Mul->getNumOperands() != 2) {
2542  // If the multiply has more than two operands, we must get the
2543  // Y*Z term.
2544  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2545  Mul->op_begin()+MulOp);
2546  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2547  InnerMul = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2548  }
2549  SmallVector<const SCEV *, 2> TwoOps = {getOne(Ty), InnerMul};
2550  const SCEV *AddOne = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2551  const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV,
2552  SCEV::FlagAnyWrap, Depth + 1);
2553  if (Ops.size() == 2) return OuterMul;
2554  if (AddOp < Idx) {
2555  Ops.erase(Ops.begin()+AddOp);
2556  Ops.erase(Ops.begin()+Idx-1);
2557  } else {
2558  Ops.erase(Ops.begin()+Idx);
2559  Ops.erase(Ops.begin()+AddOp-1);
2560  }
2561  Ops.push_back(OuterMul);
2562  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2563  }
2564 
2565  // Check this multiply against other multiplies being added together.
2566  for (unsigned OtherMulIdx = Idx+1;
2567  OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
2568  ++OtherMulIdx) {
2569  const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
2570  // If MulOp occurs in OtherMul, we can fold the two multiplies
2571  // together.
2572  for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
2573  OMulOp != e; ++OMulOp)
2574  if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
2575  // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
2576  const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
2577  if (Mul->getNumOperands() != 2) {
2578  SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
2579  Mul->op_begin()+MulOp);
2580  MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
2581  InnerMul1 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2582  }
2583  const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
2584  if (OtherMul->getNumOperands() != 2) {
2585  SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
2586  OtherMul->op_begin()+OMulOp);
2587  MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
2588  InnerMul2 = getMulExpr(MulOps, SCEV::FlagAnyWrap, Depth + 1);
2589  }
2590  SmallVector<const SCEV *, 2> TwoOps = {InnerMul1, InnerMul2};
2591  const SCEV *InnerMulSum =
2592  getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2593  const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum,
2594  SCEV::FlagAnyWrap, Depth + 1);
2595  if (Ops.size() == 2) return OuterMul;
2596  Ops.erase(Ops.begin()+Idx);
2597  Ops.erase(Ops.begin()+OtherMulIdx-1);
2598  Ops.push_back(OuterMul);
2599  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2600  }
2601  }
2602  }
2603  }
2604 
2605  // If there are any add recurrences in the operands list, see if any other
2606  // added values are loop invariant. If so, we can fold them into the
2607  // recurrence.
2608  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2609  ++Idx;
2610 
2611  // Scan over all recurrences, trying to fold loop invariants into them.
2612  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2613  // Scan all of the other operands to this add and add them to the vector if
2614  // they are loop invariant w.r.t. the recurrence.
2615  SmallVector<const SCEV *, 8> LIOps;
2616  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2617  const Loop *AddRecLoop = AddRec->getLoop();
2618  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2619  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2620  LIOps.push_back(Ops[i]);
2621  Ops.erase(Ops.begin()+i);
2622  --i; --e;
2623  }
2624 
2625  // If we found some loop invariants, fold them into the recurrence.
2626  if (!LIOps.empty()) {
2627  // Compute nowrap flags for the addition of the loop-invariant ops and
2628  // the addrec. Temporarily push it as an operand for that purpose.
2629  LIOps.push_back(AddRec);
2630  SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
2631  LIOps.pop_back();
2632 
2633  // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
2634  LIOps.push_back(AddRec->getStart());
2635 
2636  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2637  // This follows from the fact that the no-wrap flags on the outer add
2638  // expression are applicable on the 0th iteration, when the add recurrence
2639  // will be equal to its start value.
2640  AddRecOps[0] = getAddExpr(LIOps, Flags, Depth + 1);
2641 
2642  // Build the new addrec. Propagate the NUW and NSW flags if both the
2643  // outer add and the inner addrec are guaranteed to have no overflow.
2644  // Always propagate NW.
2645  Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
2646  const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
2647 
2648  // If all of the other operands were loop invariant, we are done.
2649  if (Ops.size() == 1) return NewRec;
2650 
2651  // Otherwise, add the folded AddRec by the non-invariant parts.
2652  for (unsigned i = 0;; ++i)
2653  if (Ops[i] == AddRec) {
2654  Ops[i] = NewRec;
2655  break;
2656  }
2657  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2658  }
2659 
2660  // Okay, if there weren't any loop invariants to be folded, check to see if
2661  // there are multiple AddRec's with the same loop induction variable being
2662  // added together. If so, we can fold them.
2663  for (unsigned OtherIdx = Idx+1;
2664  OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2665  ++OtherIdx) {
2666  // We expect the AddRecExpr's to be sorted in reverse dominance order,
2667  // so that the 1st found AddRecExpr is dominated by all others.
2668  assert(DT.dominates(
2669  cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()->getHeader(),
2670  AddRec->getLoop()->getHeader()) &&
2671  "AddRecExprs are not sorted in reverse dominance order?");
2672  if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
2673  // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
2674  SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
2675  for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
2676  ++OtherIdx) {
2677  const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
2678  if (OtherAddRec->getLoop() == AddRecLoop) {
2679  for (unsigned i = 0, e = OtherAddRec->getNumOperands();
2680  i != e; ++i) {
2681  if (i >= AddRecOps.size()) {
2682  AddRecOps.append(OtherAddRec->op_begin()+i,
2683  OtherAddRec->op_end());
2684  break;
2685  }
2686  SmallVector<const SCEV *, 2> TwoOps = {
2687  AddRecOps[i], OtherAddRec->getOperand(i)};
2688  AddRecOps[i] = getAddExpr(TwoOps, SCEV::FlagAnyWrap, Depth + 1);
2689  }
2690  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
2691  }
2692  }
2693  // Step size has changed, so we cannot guarantee no self-wraparound.
2694  Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
2695  return getAddExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2696  }
2697  }
2698 
2699  // Otherwise couldn't fold anything into this recurrence. Move onto the
2700  // next one.
2701  }
2702 
2703  // Okay, it looks like we really DO need an add expr. Check to see if we
2704  // already have one, otherwise create a new one.
2705  return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
2706 }
2707 
2708 const SCEV *
2709 ScalarEvolution::getOrCreateAddExpr(ArrayRef<const SCEV *> Ops,
2710  SCEV::NoWrapFlags Flags) {
2711  FoldingSetNodeID ID;
2712  ID.AddInteger(scAddExpr);
2713  for (const SCEV *Op : Ops)
2714  ID.AddPointer(Op);
2715  void *IP = nullptr;
2716  SCEVAddExpr *S =
2717  static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2718  if (!S) {
2719  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2720  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2721  S = new (SCEVAllocator)
2722  SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
2723  UniqueSCEVs.InsertNode(S, IP);
2724  addToLoopUseLists(S);
2725  }
2726  S->setNoWrapFlags(Flags);
2727  return S;
2728 }
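// A minimal standalone sketch (plain C++ with hypothetical types, not the
// FoldingSet API) of the profile-and-intern pattern used above: identical
// operand lists always yield the same node, so nodes can later be compared by
// pointer.
//
//   #include <cassert>
//   #include <map>
//   #include <memory>
//   #include <vector>
//
//   struct Node { std::vector<int> Ops; };
//
//   static Node *getOrCreate(std::map<std::vector<int>, std::unique_ptr<Node>> &Uniq,
//                            const std::vector<int> &Ops) {
//     std::unique_ptr<Node> &Slot = Uniq[Ops];    // find or insert the slot
//     if (!Slot)
//       Slot = std::make_unique<Node>(Node{Ops}); // create only on first use
//     return Slot.get();
//   }
//
//   int main() {
//     std::map<std::vector<int>, std::unique_ptr<Node>> Uniq;
//     assert(getOrCreate(Uniq, {1, 2}) == getOrCreate(Uniq, {1, 2}));
//     return 0;
//   }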
2729 
2730 const SCEV *
2731 ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
2732  const Loop *L, SCEV::NoWrapFlags Flags) {
2733  FoldingSetNodeID ID;
2734  ID.AddInteger(scAddRecExpr);
2735  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2736  ID.AddPointer(Ops[i]);
2737  ID.AddPointer(L);
2738  void *IP = nullptr;
2739  SCEVAddRecExpr *S =
2740  static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2741  if (!S) {
2742  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2743  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2744  S = new (SCEVAllocator)
2745  SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
2746  UniqueSCEVs.InsertNode(S, IP);
2747  addToLoopUseLists(S);
2748  }
2749  setNoWrapFlags(S, Flags);
2750  return S;
2751 }
2752 
2753 const SCEV *
2754 ScalarEvolution::getOrCreateMulExpr(ArrayRef<const SCEV *> Ops,
2755  SCEV::NoWrapFlags Flags) {
2756  FoldingSetNodeID ID;
2757  ID.AddInteger(scMulExpr);
2758  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2759  ID.AddPointer(Ops[i]);
2760  void *IP = nullptr;
2761  SCEVMulExpr *S =
2762  static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2763  if (!S) {
2764  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2765  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2766  S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2767  O, Ops.size());
2768  UniqueSCEVs.InsertNode(S, IP);
2769  addToLoopUseLists(S);
2770  }
2771  S->setNoWrapFlags(Flags);
2772  return S;
2773 }
2774 
2775 static uint64_t umul_ov(uint64_t i, uint64_t j, bool &Overflow) {
2776  uint64_t k = i*j;
2777  if (j > 1 && k / j != i) Overflow = true;
2778  return k;
2779 }
2780 
2781 /// Compute the result of "n choose k", the binomial coefficient. If an
2782 /// intermediate computation overflows, Overflow will be set and the return will
2783 /// be garbage. Overflow is not cleared on absence of overflow.
2784 static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
2785  // We use the multiplicative formula:
2786  // n(n-1)(n-2)...(n-(k-1)) / k(k-1)(k-2)...1 .
2787  // At each iteration, we take the n-th term of the numerator and divide by the
2788  // (k-n)th term of the denominator. This division will always produce an
2789  // integral result, and helps reduce the chance of overflow in the
2790  // intermediate computations. However, we can still overflow even when the
2791  // final result would fit.
2792 
2793  if (n == 0 || n == k) return 1;
2794  if (k > n) return 0;
2795 
2796  if (k > n/2)
2797  k = n-k;
2798 
2799  uint64_t r = 1;
2800  for (uint64_t i = 1; i <= k; ++i) {
2801  r = umul_ov(r, n-(i-1), Overflow);
2802  r /= i;
2803  }
2804  return r;
2805 }
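// For example, Choose(6, 2, Overflow): k <= n/2 leaves k = 2, then
// r = (1 * 6) / 1 = 6 and r = (6 * 5) / 2 = 15, matching C(6, 2) = 15; every
// intermediate division in this scheme is exact.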
2806 
2807 /// Determine if any of the operands in this SCEV are a constant or if
2808 /// any of the add or multiply expressions in this SCEV contain a constant.
2809 static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
2810  struct FindConstantInAddMulChain {
2811  bool FoundConstant = false;
2812 
2813  bool follow(const SCEV *S) {
2814  FoundConstant |= isa<SCEVConstant>(S);
2815  return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
2816  }
2817 
2818  bool isDone() const {
2819  return FoundConstant;
2820  }
2821  };
2822 
2823  FindConstantInAddMulChain F;
2824  SCEVTraversal<FindConstantInAddMulChain> ST(F);
2825  ST.visitAll(StartExpr);
2826  return F.FoundConstant;
2827 }
2828 
2829 /// Get a canonical multiply expression, or something simpler if possible.
2830 const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
2831  SCEV::NoWrapFlags OrigFlags,
2832  unsigned Depth) {
2833  assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
2834  "only nuw or nsw allowed");
2835  assert(!Ops.empty() && "Cannot get empty mul!");
2836  if (Ops.size() == 1) return Ops[0];
2837 #ifndef NDEBUG
2838  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2839  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2840  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2841  "SCEVMulExpr operand types don't match!");
2842 #endif
2843 
2844  // Sort by complexity, this groups all similar expression types together.
2845  GroupByComplexity(Ops, &LI, DT);
2846 
2847  // If there are any constants, fold them together.
2848  unsigned Idx = 0;
2849  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2850  ++Idx;
2851  assert(Idx < Ops.size());
2852  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2853  // We found two constants, fold them together!
2854  Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
2855  if (Ops.size() == 2) return Ops[0];
2856  Ops.erase(Ops.begin()+1); // Erase the folded element
2857  LHSC = cast<SCEVConstant>(Ops[0]);
2858  }
2859 
2860  // If we have a multiply of zero, it will always be zero.
2861  if (LHSC->getValue()->isZero())
2862  return LHSC;
2863 
2864  // If we are left with a constant one being multiplied, strip it off.
2865  if (LHSC->getValue()->isOne()) {
2866  Ops.erase(Ops.begin());
2867  --Idx;
2868  }
2869 
2870  if (Ops.size() == 1)
2871  return Ops[0];
2872  }
2873 
2874  // Delay expensive flag strengthening until necessary.
2875  auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
2876  return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
2877  };
2878 
2879  // Limit recursion calls depth.
2880  if (Depth > MaxArithDepth || hasHugeExpression(Ops))
2881  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
2882 
2883  if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) {
2884  // Don't strengthen flags if we have no new information.
2885  SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
2886  if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
2887  Mul->setNoWrapFlags(ComputeFlags(Ops));
2888  return S;
2889  }
2890 
2891  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2892  if (Ops.size() == 2) {
2893  // C1*(C2+V) -> C1*C2 + C1*V
2894  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
2895  // If any of Add's ops are Adds or Muls with a constant, apply this
2896  // transformation as well.
2897  //
2898  // TODO: There are some cases where this transformation is not
2899  // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
2900  // this transformation should be narrowed down.
2901  if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
2902  return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
2903  SCEV::FlagAnyWrap, Depth + 1),
2904  getMulExpr(LHSC, Add->getOperand(1),
2905  SCEV::FlagAnyWrap, Depth + 1),
2906  SCEV::FlagAnyWrap, Depth + 1);
2907 
2908  if (Ops[0]->isAllOnesValue()) {
2909  // If we have a mul by -1 of an add, try distributing the -1 among the
2910  // add operands.
2911  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
2912  SmallVector<const SCEV *, 4> NewOps;
2913  bool AnyFolded = false;
2914  for (const SCEV *AddOp : Add->operands()) {
2915  const SCEV *Mul = getMulExpr(Ops[0], AddOp, SCEV::FlagAnyWrap,
2916  Depth + 1);
2917  if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
2918  NewOps.push_back(Mul);
2919  }
2920  if (AnyFolded)
2921  return getAddExpr(NewOps, SCEV::FlagAnyWrap, Depth + 1);
2922  } else if (const auto *AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
2923  // Negation preserves a recurrence's no self-wrap property.
2924  SmallVector<const SCEV *, 4> Operands;
2925  for (const SCEV *AddRecOp : AddRec->operands())
2926  Operands.push_back(getMulExpr(Ops[0], AddRecOp, SCEV::FlagAnyWrap,
2927  Depth + 1));
2928 
2929  return getAddRecExpr(Operands, AddRec->getLoop(),
2930  AddRec->getNoWrapFlags(SCEV::FlagNW));
2931  }
2932  }
2933  }
2934  }
2935 
2936  // Skip over the add expression until we get to a multiply.
2937  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
2938  ++Idx;
2939 
2940  // If there are mul operands inline them all into this expression.
2941  if (Idx < Ops.size()) {
2942  bool DeletedMul = false;
2943  while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
2944  if (Ops.size() > MulOpsInlineThreshold)
2945  break;
2946  // If we have an mul, expand the mul operands onto the end of the
2947  // operands list.
2948  Ops.erase(Ops.begin()+Idx);
2949  Ops.append(Mul->op_begin(), Mul->op_end());
2950  DeletedMul = true;
2951  }
2952 
2953  // If we deleted at least one mul, we added operands to the end of the
2954  // list, and they are not necessarily sorted. Recurse to resort and
2955  // resimplify any operands we just acquired.
2956  if (DeletedMul)
2957  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
2958  }
2959 
2960  // If there are any add recurrences in the operands list, see if any other
2961  // added values are loop invariant. If so, we can fold them into the
2962  // recurrence.
2963  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
2964  ++Idx;
2965 
2966  // Scan over all recurrences, trying to fold loop invariants into them.
2967  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
2968  // Scan all of the other operands to this mul and add them to the vector
2969  // if they are loop invariant w.r.t. the recurrence.
 2970  SmallVector<const SCEV *, 8> LIOps;
 2971  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
2972  const Loop *AddRecLoop = AddRec->getLoop();
2973  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2974  if (isAvailableAtLoopEntry(Ops[i], AddRecLoop)) {
2975  LIOps.push_back(Ops[i]);
2976  Ops.erase(Ops.begin()+i);
2977  --i; --e;
2978  }
2979 
2980  // If we found some loop invariants, fold them into the recurrence.
2981  if (!LIOps.empty()) {
2982  // NLI * LI * {Start,+,Step} --> NLI * {LI*Start,+,LI*Step}
 2983  SmallVector<const SCEV *, 4> NewOps;
 2984  NewOps.reserve(AddRec->getNumOperands());
2985  const SCEV *Scale = getMulExpr(LIOps, SCEV::FlagAnyWrap, Depth + 1);
2986  for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
2987  NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i),
2988  SCEV::FlagAnyWrap, Depth + 1));
2989 
2990  // Build the new addrec. Propagate the NUW and NSW flags if both the
2991  // outer mul and the inner addrec are guaranteed to have no overflow.
2992  //
 2993  // The no-self-wrap flag cannot be guaranteed after changing the step
 2994  // size, but it will be inferred if either NUW or NSW is true.
2995  SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
2996  const SCEV *NewRec = getAddRecExpr(
2997  NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
2998 
2999  // If all of the other operands were loop invariant, we are done.
3000  if (Ops.size() == 1) return NewRec;
3001 
3002  // Otherwise, multiply the folded AddRec by the non-invariant parts.
3003  for (unsigned i = 0;; ++i)
3004  if (Ops[i] == AddRec) {
3005  Ops[i] = NewRec;
3006  break;
3007  }
3008  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3009  }
3010 
3011  // Okay, if there weren't any loop invariants to be folded, check to see
3012  // if there are multiple AddRec's with the same loop induction variable
3013  // being multiplied together. If so, we can fold them.
3014 
3015  // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L>
3016  // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [
3017  // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z
3018  // ]]],+,...up to x=2n}.
3019  // Note that the arguments to choose() are always integers with values
3020  // known at compile time, never SCEV objects.
3021  //
3022  // The implementation avoids pointless extra computations when the two
3023  // addrec's are of different length (mathematically, it's equivalent to
3024  // an infinite stream of zeros on the right).
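 // Illustrative sketch, not part of the original source: for two affine
 // addrecs the formula above reduces to
 //   {A,+,B}<L> * {C,+,D}<L> = {A*C,+,A*D + B*C + B*D,+,2*B*D}<L>,
 // since {X,+,Y,+,Z} evaluates to X + Y*i + Z*i*(i-1)/2 at iteration i.
 // Hypothetical spot-check with A=2, B=3, C=5, D=7 at i=4:
 static_assert((2 + 3 * 4) * (5 + 7 * 4) ==
                   2 * 5 + (2 * 7 + 3 * 5 + 3 * 7) * 4 + (2 * 3 * 7) * 4 * 3 / 2,
               "affine addrec product matches {A*C,+,A*D+B*C+B*D,+,2*B*D}");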
3025  bool OpsModified = false;
3026  for (unsigned OtherIdx = Idx+1;
3027  OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
3028  ++OtherIdx) {
3029  const SCEVAddRecExpr *OtherAddRec =
3030  dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]);
3031  if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop)
3032  continue;
3033 
3034  // Limit max number of arguments to avoid creation of unreasonably big
3035  // SCEVAddRecs with very complex operands.
3036  if (AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1 >
3037  MaxAddRecSize || hasHugeExpression({AddRec, OtherAddRec}))
3038  continue;
3039 
3040  bool Overflow = false;
3041  Type *Ty = AddRec->getType();
3042  bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64;
3043  SmallVector<const SCEV*, 7> AddRecOps;
3044  for (int x = 0, xe = AddRec->getNumOperands() +
3045  OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
 3046  SmallVector<const SCEV *, 7> SumOps;
 3047  for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
3048  uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
3049  for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
3050  ze = std::min(x+1, (int)OtherAddRec->getNumOperands());
3051  z < ze && !Overflow; ++z) {
3052  uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow);
3053  uint64_t Coeff;
3054  if (LargerThan64Bits)
3055  Coeff = umul_ov(Coeff1, Coeff2, Overflow);
3056  else
3057  Coeff = Coeff1*Coeff2;
3058  const SCEV *CoeffTerm = getConstant(Ty, Coeff);
3059  const SCEV *Term1 = AddRec->getOperand(y-z);
3060  const SCEV *Term2 = OtherAddRec->getOperand(z);
3061  SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
3062  SCEV::FlagAnyWrap, Depth + 1));
3063  }
3064  }
3065  if (SumOps.empty())
3066  SumOps.push_back(getZero(Ty));
3067  AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
3068  }
3069  if (!Overflow) {
3070  const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRecLoop,
 3071  SCEV::FlagAnyWrap);
 3072  if (Ops.size() == 2) return NewAddRec;
3073  Ops[Idx] = NewAddRec;
3074  Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
3075  OpsModified = true;
3076  AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec);
3077  if (!AddRec)
3078  break;
3079  }
3080  }
3081  if (OpsModified)
3082  return getMulExpr(Ops, SCEV::FlagAnyWrap, Depth + 1);
3083 
3084  // Otherwise couldn't fold anything into this recurrence. Move onto the
3085  // next one.
3086  }
3087 
 3088  // Okay, it looks like we really DO need a mul expr. Check to see if we
3089  // already have one, otherwise create a new one.
3090  return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
3091 }
3092 
3093 /// Represents an unsigned remainder expression based on unsigned division.
 3094 const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
 3095  const SCEV *RHS) {
 3096  assert(getEffectiveSCEVType(LHS->getType()) ==
 3097  getEffectiveSCEVType(RHS->getType()) &&
 3098  "SCEVURemExpr operand types don't match!");
3099 
3100  // Short-circuit easy cases
3101  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3102  // If constant is one, the result is trivial
3103  if (RHSC->getValue()->isOne())
3104  return getZero(LHS->getType()); // X urem 1 --> 0
3105 
3106  // If constant is a power of two, fold into a zext(trunc(LHS)).
3107  if (RHSC->getAPInt().isPowerOf2()) {
3108  Type *FullTy = LHS->getType();
3109  Type *TruncTy =
3110  IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
3111  return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
3112  }
3113  }
3114 
 3115  // Fall back to the identity %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
3116  const SCEV *UDiv = getUDivExpr(LHS, RHS);
3117  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
3118  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
3119 }
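 // Illustrative sketch, not part of the original source: both folds above are
 // ordinary unsigned arithmetic -- a power-of-two divisor keeps only the low
 // bits, and in general x urem y == x - (x udiv y)*y. Hypothetical values:
 static_assert((37u % 8u) == (37u & 7u), "urem by 2^k keeps the low k bits");
 static_assert((37u % 5u) == 37u - (37u / 5u) * 5u,
               "x urem y == x - (x udiv y)*y");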
3120 
3121 /// Get a canonical unsigned division expression, or something simpler if
3122 /// possible.
 3123 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
 3124  const SCEV *RHS) {
 3125  assert(getEffectiveSCEVType(LHS->getType()) ==
 3126  getEffectiveSCEVType(RHS->getType()) &&
 3127  "SCEVUDivExpr operand types don't match!");
3128 
 3129  FoldingSetNodeID ID;
 3130  ID.AddInteger(scUDivExpr);
3131  ID.AddPointer(LHS);
3132  ID.AddPointer(RHS);
3133  void *IP = nullptr;
3134  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3135  return S;
3136 
3137  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
3138  if (RHSC->getValue()->isOne())
3139  return LHS; // X udiv 1 --> x
3140  // If the denominator is zero, the result of the udiv is undefined. Don't
3141  // try to analyze it, because the resolution chosen here may differ from
3142  // the resolution chosen in other parts of the compiler.
3143  if (!RHSC->getValue()->isZero()) {
 3144  // Determine whether the division can be folded into the operands of
 3145  // the LHS expression.
3146  // TODO: Generalize this to non-constants by using known-bits information.
3147  Type *Ty = LHS->getType();
3148  unsigned LZ = RHSC->getAPInt().countLeadingZeros();
3149  unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
3150  // For non-power-of-two values, effectively round the value up to the
3151  // nearest power of two.
3152  if (!RHSC->getAPInt().isPowerOf2())
3153  ++MaxShiftAmt;
3154  IntegerType *ExtTy =
3155  IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
3156  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
3157  if (const SCEVConstant *Step =
3158  dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this))) {
3159  // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
3160  const APInt &StepInt = Step->getAPInt();
3161  const APInt &DivInt = RHSC->getAPInt();
3162  if (!StepInt.urem(DivInt) &&
3163  getZeroExtendExpr(AR, ExtTy) ==
3164  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3165  getZeroExtendExpr(Step, ExtTy),
3166  AR->getLoop(), SCEV::FlagAnyWrap)) {
 3167  SmallVector<const SCEV *, 4> Operands;
 3168  for (const SCEV *Op : AR->operands())
3169  Operands.push_back(getUDivExpr(Op, RHS));
3170  return getAddRecExpr(Operands, AR->getLoop(), SCEV::FlagNW);
3171  }
3172  /// Get a canonical UDivExpr for a recurrence.
3173  /// {X,+,N}/C => {Y,+,N}/C where Y=X-(X%N). Safe when C%N=0.
3174  // We can currently only fold X%N if X is constant.
3175  const SCEVConstant *StartC = dyn_cast<SCEVConstant>(AR->getStart());
3176  if (StartC && !DivInt.urem(StepInt) &&
3177  getZeroExtendExpr(AR, ExtTy) ==
3178  getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
3179  getZeroExtendExpr(Step, ExtTy),
3180  AR->getLoop(), SCEV::FlagAnyWrap)) {
3181  const APInt &StartInt = StartC->getAPInt();
3182  const APInt &StartRem = StartInt.urem(StepInt);
3183  if (StartRem != 0) {
3184  const SCEV *NewLHS =
3185  getAddRecExpr(getConstant(StartInt - StartRem), Step,
3186  AR->getLoop(), SCEV::FlagNW);
3187  if (LHS != NewLHS) {
3188  LHS = NewLHS;
3189 
3190  // Reset the ID to include the new LHS, and check if it is
3191  // already cached.
3192  ID.clear();
3193  ID.AddInteger(scUDivExpr);
3194  ID.AddPointer(LHS);
3195  ID.AddPointer(RHS);
3196  IP = nullptr;
3197  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
3198  return S;
3199  }
3200  }
3201  }
3202  }
3203  // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
3204  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
 3205  SmallVector<const SCEV *, 4> Operands;
 3206  for (const SCEV *Op : M->operands())
3207  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3208  if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
3209  // Find an operand that's safely divisible.
3210  for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
3211  const SCEV *Op = M->getOperand(i);
3212  const SCEV *Div = getUDivExpr(Op, RHSC);
3213  if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
3214  Operands = SmallVector<const SCEV *, 4>(M->operands());
3215  Operands[i] = Div;
3216  return getMulExpr(Operands);
3217  }
3218  }
3219  }
3220 
3221  // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
3222  if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
3223  if (auto *DivisorConstant =
3224  dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
3225  bool Overflow = false;
3226  APInt NewRHS =
3227  DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
3228  if (Overflow) {
3229  return getConstant(RHSC->getType(), 0, false);
3230  }
3231  return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
3232  }
3233  }
3234 
3235  // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
3236  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
 3237  SmallVector<const SCEV *, 4> Operands;
 3238  for (const SCEV *Op : A->operands())
3239  Operands.push_back(getZeroExtendExpr(Op, ExtTy));
3240  if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
3241  Operands.clear();
3242  for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
3243  const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
3244  if (isa<SCEVUDivExpr>(Op) ||
3245  getMulExpr(Op, RHS) != A->getOperand(i))
3246  break;
3247  Operands.push_back(Op);
3248  }
3249  if (Operands.size() == A->getNumOperands())
3250  return getAddExpr(Operands);
3251  }
3252  }
3253 
3254  // Fold if both operands are constant.
3255  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
3256  Constant *LHSCV = LHSC->getValue();
3257  Constant *RHSCV = RHSC->getValue();
3258  return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
3259  RHSCV)));
3260  }
3261  }
3262  }
3263 
3264  // The Insertion Point (IP) might be invalid by now (due to UniqueSCEVs
3265  // changes). Make sure we get a new one.
3266  IP = nullptr;
3267  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
3268  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
3269  LHS, RHS);
3270  UniqueSCEVs.InsertNode(S, IP);
3271  addToLoopUseLists(S);
3272  return S;
3273 }
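 // Illustrative sketch, not part of the original source: two of the folds
 // above, spelled out on plain unsigned constants. (A /u B) /u C == A /u (B*C)
 // as long as B*C does not overflow, and (A*B) /u C == A*(B /u C) when C
 // divides B exactly. Values are hypothetical.
 static_assert((100u / 4u) / 5u == 100u / (4u * 5u), "(A/B)/C == A/(B*C)");
 static_assert((7u * 8u) / 4u == 7u * (8u / 4u), "(A*B)/C == A*(B/C) when C | B");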
3274 
3275 static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) {
3276  APInt A = C1->getAPInt().abs();
3277  APInt B = C2->getAPInt().abs();
3278  uint32_t ABW = A.getBitWidth();
3279  uint32_t BBW = B.getBitWidth();
3280 
3281  if (ABW > BBW)
3282  B = B.zext(ABW);
3283  else if (ABW < BBW)
3284  A = A.zext(BBW);
3285 
 3286  return APIntOps::GreatestCommonDivisor(std::move(A), std::move(B));
 3287 }
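 // Illustrative sketch, not part of the original source:
 // APIntOps::GreatestCommonDivisor requires both operands to have the same bit
 // width, which is why gcd() zero-extends the narrower one first. Hypothetical
 // spot-check (unused illustrative helper):
 static bool checkAPIntGcdExample() {
   APInt A(32, 12), B(32, 18);
   return APIntOps::GreatestCommonDivisor(A, B) == 6; // gcd(12, 18) == 6
 }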
3288 
3289 /// Get a canonical unsigned division expression, or something simpler if
3290 /// possible. There is no representation for an exact udiv in SCEV IR, but we
3291 /// can attempt to remove factors from the LHS and RHS. We can't do this when
3292 /// it's not exact because the udiv may be clearing bits.
 3293 const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
 3294  const SCEV *RHS) {
3295  // TODO: we could try to find factors in all sorts of things, but for now we
3296  // just deal with u/exact (multiply, constant). See SCEVDivision towards the
3297  // end of this file for inspiration.
3298 
3299  const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS);
3300  if (!Mul || !Mul->hasNoUnsignedWrap())
3301  return getUDivExpr(LHS, RHS);
3302 
3303  if (const SCEVConstant *RHSCst = dyn_cast<SCEVConstant>(RHS)) {
3304  // If the mulexpr multiplies by a constant, then that constant must be the
3305  // first element of the mulexpr.
3306  if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
3307  if (LHSCst == RHSCst) {
 3308  SmallVector<const SCEV *, 2> Operands(Mul->op_begin() + 1, Mul->op_end());
 3309  return getMulExpr(Operands);
3310  }
3311 
 3312  // We can't just assume that LHSCst divides RHSCst cleanly; it could be
3313  // that there's a factor provided by one of the other terms. We need to
3314  // check.
3315  APInt Factor = gcd(LHSCst, RHSCst);
3316  if (!Factor.isIntN(1)) {
3317  LHSCst =
3318  cast<SCEVConstant>(getConstant(LHSCst->getAPInt().udiv(Factor)));
3319  RHSCst =
3320  cast<SCEVConstant>(getConstant(RHSCst->getAPInt().udiv(Factor)));
 3321  SmallVector<const SCEV *, 2> Operands;
 3322  Operands.push_back(LHSCst);
3323  Operands.append(Mul->op_begin() + 1, Mul->op_end());
3324  LHS = getMulExpr(Operands);
3325  RHS = RHSCst;
3326  Mul = dyn_cast<SCEVMulExpr>(LHS);
3327  if (!Mul)
3328  return getUDivExactExpr(LHS, RHS);
3329  }
3330  }
3331  }
3332 
3333  for (int i = 0, e = Mul->getNumOperands(); i != e; ++i) {
3334  if (Mul->getOperand(i) == RHS) {
 3335  SmallVector<const SCEV *, 2> Operands;
 3336  Operands.append(Mul->op_begin(), Mul->op_begin() + i);
3337  Operands.append(Mul->op_begin() + i + 1, Mul->op_end());
3338  return getMulExpr(Operands);
3339  }
3340  }
3341 
3342  return getUDivExpr(LHS, RHS);
3343 }
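 // Illustrative sketch, not part of the original source: because an exact
 // division discards no low bits, common factors can be cancelled directly,
 // e.g. (8*x) /u 4 == 2*x whenever the division is known to be exact.
 // Hypothetical spot-check with x == 11:
 static_assert((8u * 11u) / 4u == 2u * 11u, "(8*x) /u 4 == 2*x when exact");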
3344 
3345 /// Get an add recurrence expression for the specified loop. Simplify the
3346 /// expression as much as possible.
3347 const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
3348  const Loop *L,
3349  SCEV::NoWrapFlags Flags) {
 3350  SmallVector<const SCEV *, 4> Operands;
 3351  Operands.push_back(Start);
3352  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
3353  if (StepChrec->getLoop() == L) {
3354  Operands.append(StepChrec->op_begin(), StepChrec->op_end());
3355  return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
3356  }
3357 
3358  Operands.push_back(Step);
3359  return getAddRecExpr(Operands, L, Flags);
3360 }
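 // Illustrative sketch, not part of the original source: when Step is itself an
 // addrec on the same loop, {X,+,{Y,+,Z}<L>}<L> flattens to {X,+,Y,+,Z}<L>,
 // because summing the step Y + Z*j over j = 0..i-1 yields Y*i + Z*i*(i-1)/2.
 // Hypothetical helper that double-checks the identity numerically:
 static bool checkNestedStepFlattening(long X, long Y, long Z, long NumIters) {
   long ByDefinition = X;
   for (long j = 0; j < NumIters; ++j)
     ByDefinition += Y + Z * j; // accumulate the (itself recurrent) step
   long Flattened = X + Y * NumIters + Z * NumIters * (NumIters - 1) / 2;
   return ByDefinition == Flattened;
 }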
3361 
3362 /// Get an add recurrence expression for the specified loop. Simplify the
3363 /// expression as much as possible.
3364 const SCEV *
 3365 ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
 3366  const Loop *L, SCEV::NoWrapFlags Flags) {
3367  if (Operands.size() == 1) return Operands[0];
3368 #ifndef NDEBUG
3369  Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
3370  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
 3371  assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
 3372  "SCEVAddRecExpr operand types don't match!");
 3373  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
 3374  assert(isLoopInvariant(Operands[i], L) &&
 3375  "SCEVAddRecExpr operand is not loop-invariant!");
3376 #endif
3377 
3378  if (Operands.back()->isZero()) {
3379  Operands.pop_back();
3380  return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0} --> X
3381  }
3382 
 3383  // It's tempting to call getConstantMaxBackedgeTakenCount here and
3384  // use that information to infer NUW and NSW flags. However, computing a
3385  // BE count requires calling getAddRecExpr, so we may not yet have a
3386  // meaningful BE count at this point (and if we don't, we'd be stuck
3387  // with a SCEVCouldNotCompute as the cached BE count).
3388 
3389  Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags);
3390 
 3391  // Canonicalize nested AddRecs by nesting them in order of loop depth.
3392  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
3393  const Loop *NestedLoop = NestedAR->getLoop();
3394  if (L->contains(NestedLoop)
3395  ? (L->getLoopDepth() < NestedLoop->getLoopDepth())
3396  : (!NestedLoop->contains(L) &&
3397  DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
3398  SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
3399  Operands[0] = NestedAR->getStart();
3400  // AddRecs require their operands be loop-invariant with respect to their
3401  // loops. Don't perform this transformation if it would break this
3402  // requirement.
3403  bool AllInvariant = all_of(
3404  Operands, [&](const SCEV *Op) { return isLoopInvariant(Op, L); });
3405 
3406  if (AllInvariant) {
3407  // Create a recurrence for the outer loop with the same step size.
3408  //
3409  // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
3410  // inner recurrence has the same property.
3411  SCEV::NoWrapFlags OuterFlags =
3412  maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
3413 
3414  NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
3415  AllInvariant = all_of(NestedOperands, [&](const SCEV *Op) {
3416  return isLoopInvariant(Op, NestedLoop);
3417  });
3418 
3419  if (AllInvariant) {
3420  // Ok, both add recurrences are valid after the transformation.
3421  //
3422  // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
3423  // the outer recurrence has the same property.
3424  SCEV::NoWrapFlags InnerFlags =
3425  maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
3426  return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
3427  }
3428  }
3429  // Reset Operands to its original state.
3430  Operands[0] = NestedAR;
3431  }
3432  }
3433 
3434  // Okay, it looks like we really DO need an addrec expr. Check to see if we
3435  // already have one, otherwise create a new one.
3436  return getOrCreateAddRecExpr(Operands, L, Flags);
3437 }
3438 
3439 const SCEV *
 3440 ScalarEvolution::getGEPExpr(GEPOperator *GEP,
 3441  const SmallVectorImpl<const SCEV *> &IndexExprs) {
3442  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());
3443  // getSCEV(Base)->getType() has the same address space as Base->getType()
3444  // because SCEV::getType() preserves the address space.
3445  Type *IntIdxTy = getEffectiveSCEVType(BaseExpr->getType());
3446  // FIXME(PR23527): Don't blindly transfer the inbounds flag from the GEP
3447  // instruction to its SCEV, because the Instruction may be guarded by control
3448  // flow and the no-overflow bits may not be valid for the expression in any
3449  // context. This can be fixed similarly to how these flags are handled for
3450  // adds.
3451  SCEV::NoWrapFlags OffsetWrap =
3452  GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
3453 
3454  Type *CurTy = GEP->getType();
3455  bool FirstIter = true;
 3456  SmallVector<const SCEV *, 4> Offsets;
 3457  for (const SCEV *IndexExpr : IndexExprs) {
3458  // Compute the (potentially symbolic) offset in bytes for this index.
3459  if (StructType *STy = dyn_cast<StructType>(CurTy)) {
3460  // For a struct, add the member offset.
3461  ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
3462  unsigned FieldNo = Index->getZExtValue();
3463  const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
3464  Offsets.push_back(FieldOffset);
3465 
3466  // Update CurTy to the type of the field at Index.
3467  CurTy = STy->getTypeAtIndex(Index);
3468  } else {
3469  // Update CurTy to its element type.
3470  if (FirstIter) {
3471  assert(isa<PointerType>(CurTy) &&
3472  "The first index of a GEP indexes a pointer");
3473  CurTy = GEP->getSourceElementType();
3474  FirstIter = false;
3475  } else {
3476  CurTy = GetElementPtrInst::getTypeAtIndex(CurTy, (uint64_t)0);
3477  }
3478  // For an array, add the element offset, explicitly scaled.
3479  const SCEV *ElementSize = getSizeOfExpr(IntIdxTy, CurTy);
3480  // Getelementptr indices are signed.
3481  IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
3482 
3483  // Multiply the index by the element size to compute the element offset.
3484  const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
3485  Offsets.push_back(LocalOffset);
3486  }
3487  }
3488 
3489  // Handle degenerate case of GEP without offsets.
3490  if (Offsets.empty())
3491  return BaseExpr;
3492 
3493  // Add the offsets together, assuming nsw if inbounds.
3494  const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
3495  // Add the base address and the offset. We cannot use the nsw flag, as the
3496  // base address is unsigned. However, if we know that the offset is
3497  // non-negative, we can use nuw.
3498  SCEV::NoWrapFlags BaseWrap = GEP->isInBounds() && isKnownNonNegative(Offset)
 3499  ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
 3500  return getAddExpr(BaseExpr, Offset, BaseWrap);
3501 }
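 // Illustrative sketch, not part of the original source: the offset SCEV built
 // above mirrors ordinary address arithmetic -- each array index contributes
 // Index * sizeof(Element) bytes and each struct index contributes a constant
 // field offset from the struct layout. Hypothetical helper for the array case:
 static unsigned long long arrayElementByteOffset(unsigned long long Index,
                                                  unsigned long long ElementSize) {
   return Index * ElementSize; // matches getMulExpr(IndexExpr, ElementSize, ...)
 }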
3502 
3503 std::tuple<SCEV *, FoldingSetNodeID, void *>
3504 ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
3505  ArrayRef<const SCEV *> Ops) {
 3506  FoldingSetNodeID ID;
 3507  void *IP = nullptr;
3508  ID.AddInteger(SCEVType);
3509  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
3510  ID.AddPointer(Ops[i]);
3511  return std::tuple<SCEV *, FoldingSetNodeID, void *>(
3512  UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP);
3513 }
3514 
3515 const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
 3516  SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
 3517  return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
3518 }
3519 
 3520 const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
 3521  SmallVectorImpl<const SCEV *> &Ops) {
 3522  assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
3523  if (Ops.size() == 1) return Ops[0];
3524 #ifndef NDEBUG
3525  Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
3526  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
3527  assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
3528  "Operand types don't match!");
3529 #endif
3530 
3531  bool IsSigned = Kind == scSMaxExpr || Kind == scSMinExpr;
3532  bool IsMax = Kind == scSMaxExpr || Kind == scUMaxExpr;
3533 
3534  // Sort by complexity, this groups all similar expression types together.
3535  GroupByComplexity(Ops, &LI, DT);
3536 
3537  // Check if we have created the same expression before.
3538  if (const SCEV *S = std::get<0>(findExistingSCEVInCache(Kind, Ops))) {
3539  return S;
3540  }
3541 
3542  // If there are any constants, fold them together.
3543  unsigned Idx = 0;
3544  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
3545  ++Idx;
3546  assert(Idx < Ops.size());
3547  auto FoldOp = [&](const APInt &LHS, const APInt &RHS) {
3548  if (Kind == scSMaxExpr)
3549  return APIntOps::smax(LHS, RHS);
3550  else if (Kind == scSMinExpr)
3551  return APIntOps::smin(LHS, RHS);
3552  else if (Kind == scUMaxExpr)
3553  return APIntOps::umax(LHS, RHS);
3554  else if (Kind == scUMinExpr)
3555  return APIntOps::umin(LHS, RHS);
3556  llvm_unreachable("Unknown SCEV min/max opcode");
3557  };
3558 
3559  while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
3560  // We found two constants, fold them together!
3561  ConstantInt *Fold = ConstantInt::get(
3562  getContext(), FoldOp(LHSC->getAPInt(), RHSC->getAPInt()));
3563  Ops[0] = getConstant(Fold);
3564  Ops.erase(Ops.begin()+1); // Erase the folded element
3565  if (Ops.size() == 1) return Ops[0];
3566  LHSC = cast<SCEVConstant>(Ops[0]);
3567  }
3568 
3569  bool IsMinV = LHSC->getValue()->isMinValue(IsSigned);
3570  bool IsMaxV = LHSC->getValue()->isMaxValue(IsSigned);
3571 
3572  if (IsMax ? IsMinV : IsMaxV) {
3573  // If we are left with a constant minimum(/maximum)-int, strip it off.
3574  Ops.erase(Ops.begin());
3575  --Idx;
3576  } else if (IsMax ? IsMaxV : IsMinV) {
3577  // If we have a max(/min) with a constant maximum(/minimum)-int,
3578  // it will always be the extremum.
3579  return LHSC;
3580  }
3581 
3582  if (Ops.size() == 1) return Ops[0];
3583  }
3584 
3585  // Find the first operation of the same kind
3586  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < Kind)
3587  ++Idx;
3588 
3589  // Check to see if one of the operands is of the same kind. If so, expand its
3590  // operands onto our operand list, and recurse to simplify.
3591  if (Idx < Ops.size()) {
3592  bool DeletedAny = false;
3593  while (Ops[Idx]->getSCEVType() == Kind) {
3594  const SCEVMinMaxExpr *SMME = cast<SCEVMinMaxExpr>(Ops[Idx]);
3595  Ops.erase(Ops.begin()+Idx);
3596  Ops.append(SMME->op_begin(), SMME->op_end());
3597  DeletedAny = true;
3598  }
3599 
3600  if (DeletedAny)
3601  return getMinMaxExpr(Kind, Ops);
3602  }
3603 
3604  // Okay, check to see if the same value occurs in the operand list twice. If
3605  // so, delete one. Since we sorted the list, these values are required to
3606  // be adjacent.
3607  llvm::CmpInst::Predicate GEPred =
 3608  IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
 3609  llvm::CmpInst::Predicate LEPred =
 3610  IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
 3611  llvm::CmpInst::Predicate FirstPred = IsMax ? GEPred : LEPred;
3612  llvm::CmpInst::Predicate SecondPred = IsMax ? LEPred : GEPred;
3613  for (unsigned i = 0, e = Ops.size() - 1; i != e; ++i) {
3614  if (Ops[i] == Ops[i + 1] ||
3615  isKnownViaNonRecursiveReasoning(FirstPred, Ops[i], Ops[i + 1])) {
3616  // X op Y op Y --> X op Y
3617  // X op Y --> X, if we know X, Y are ordered appropriately
3618  Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
3619  --i;
3620  --e;
3621  } else if (isKnownViaNonRecursiveReasoning(SecondPred, Ops[i],
3622  Ops[i + 1])) {
3623  // X op Y --> Y, if we know X, Y are ordered appropriately
3624  Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
3625  --i;
3626  --e;
3627  }
3628  }
3629 
3630  if (Ops.size() == 1) return Ops[0];
3631 
3632  assert(!Ops.empty() && "Reduced smax down to nothing!");
3633 
3634  // Okay, it looks like we really DO need an expr. Check to see if we
3635  // already have one, otherwise create a new one.
3636  const SCEV *ExistingSCEV;
 3637  FoldingSetNodeID ID;
 3638  void *IP;
3639  std::tie(ExistingSCEV, ID, IP) = findExistingSCEVInCache(Kind, Ops);
3640  if (ExistingSCEV)
3641  return ExistingSCEV;
3642  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
3643  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
3644  SCEV *S = new (SCEVAllocator)
3645  SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
3646 
3647  UniqueSCEVs.InsertNode(S, IP);
3648  addToLoopUseLists(S);
3649  return S;
3650 }
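 // Illustrative sketch, not part of the original source: the rewrites above are
 // the usual min/max identities -- constants fold, the extremum constant either
 // disappears or absorbs everything, nested expressions of the same kind are
 // flattened, and duplicated or provably ordered operands collapse.
 // Hypothetical spot-checks on plain ints:
 static_assert((5 > 3 ? 5 : 3) == 5, "smax of two constants folds");
 static_assert(((-2147483647 - 1) > 5 ? (-2147483647 - 1) : 5) == 5,
               "a minimum-int operand never changes an smax");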
3651 
3652 const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3653  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3654  return getSMaxExpr(Ops);
3655 }
3656 
 3657 const SCEV *ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3658  return getMinMaxExpr(scSMaxExpr, Ops);
3659 }
3660 
3661 const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) {
3662  SmallVector<const SCEV *, 2> Ops = {LHS, RHS};
3663  return getUMaxExpr(Ops);
3664 }
3665 
 3666 const SCEV *ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3667  return getMinMaxExpr(scUMaxExpr, Ops);
3668 }
3669 
 3670 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
 3671  const SCEV *RHS) {
3672  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3673  return getSMinExpr(Ops);
3674 }
3675 
 3676 const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3677  return getMinMaxExpr(scSMinExpr, Ops);
3678 }
3679 
 3680 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
 3681  const SCEV *RHS) {
3682  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
3683  return getUMinExpr(Ops);
3684 }
3685 
 3686 const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
 3687  return getMinMaxExpr(scUMinExpr, Ops);
3688 }
3689 
3690 const SCEV *
 3691 ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
 3692  ScalableVectorType *ScalableTy) {
3693  Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
3694  Constant *One = ConstantInt::get(IntTy, 1);
3695  Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
 3696  // Note that the expression we created is the final expression; we don't
 3697  // want to simplify it any further. Also, if we call a normal getSCEV(),
3698  // we'll end up in an endless recursion. So just create an SCEVUnknown.
3699  return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
3700 }
3701 
3702 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
3703  if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
3704  return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
3705  // We can bypass creating a target-independent constant expression and then
3706  // folding it back into a ConstantInt. This is just a compile-time
3707  // optimization.
3708  return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
3709 }
3710 
 3711 const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
 3712  if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
3713  return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
3714  // We can bypass creating a target-independent constant expression and then
3715  // folding it back into a ConstantInt. This is just a compile-time
3716  // optimization.
3717  return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
3718 }
3719 
 3720 const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
 3721  StructType *STy,
3722  unsigned FieldNo) {
3723  // We can bypass creating a target-independent constant expression and then
3724  // folding it back into a ConstantInt. This is just a compile-time
3725  // optimization.
3726  return getConstant(
3727  IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
3728 }
3729 
 3730 const SCEV *ScalarEvolution::getUnknown(Value *V) {
 3731  // Don't attempt to do anything other than create a SCEVUnknown object
3732  // here. createSCEV only calls getUnknown after checking for all other
3733  // interesting possibilities, and any other code that calls getUnknown
3734  // is doing so in order to hide a value from SCEV canonicalization.
3735 
 3736  FoldingSetNodeID ID;
 3737  ID.AddInteger(scUnknown);
3738  ID.AddPointer(V);
3739  void *IP = nullptr;
3740  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
3741  assert(cast<SCEVUnknown>(S)->getValue() == V &&
3742  "Stale SCEVUnknown in uniquing map!");
3743  return S;
3744  }
3745  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
3746  FirstUnknown);
3747  FirstUnknown = cast<SCEVUnknown>(S);
3748  UniqueSCEVs.InsertNode(S, IP);
3749  return S;
3750 }
3751 
3752 //===----------------------------------------------------------------------===//
3753 // Basic SCEV Analysis and PHI Idiom Recognition Code
3754 //
3755 
3756 /// Test if values of the given type are analyzable within the SCEV
3757 /// framework. This primarily includes integer types, and it can optionally
3758 /// include pointer types if the ScalarEvolution class has access to
3759 /// target-specific information.
 3760 bool ScalarEvolution::isSCEVable(Type *Ty) const {
 3761  // Integers and pointers are always SCEVable.
3762  return Ty->isIntOrPtrTy();
3763 }
3764 
3765 /// Return the size in bits of the specified type, for which isSCEVable must
3766 /// return true.
 3767 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
 3768  assert(isSCEVable(Ty) && "Type is not SCEVable!");
 3769  if (Ty->isPointerTy())
 3770  return getDataLayout().getIndexTypeSizeInBits(Ty);
 3771  return getDataLayout().getTypeSizeInBits(Ty);
3772 }
3773 
3774 /// Return a type with the same bitwidth as the given type and which represents
3775 /// how SCEV will treat the given type, for which isSCEVable must return
3776 /// true. For pointer types, this is the pointer index sized integer type.
 3777 Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const {
 3778  assert(isSCEVable(Ty) && "Type is not SCEVable!");
3779 
3780  if (Ty->isIntegerTy())
3781  return Ty;
3782 
 3783  // The only other supported type is pointer.
3784  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
3785  return getDataLayout().getIndexType(Ty);
3786 }
3787 
 3788 Type *ScalarEvolution::getWiderType(Type *T1, Type *T2) const {
 3789  return getTypeSizeInBits(T1) >= getTypeSizeInBits(T2) ? T1 : T2;
3790 }
3791 
 3792 const SCEV *ScalarEvolution::getCouldNotCompute() {
 3793  return CouldNotCompute.get();
3794 }
3795 
3796 bool ScalarEvolution::checkValidity(const SCEV *S) const {
3797  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) {
3798  auto *SU = dyn_cast<SCEVUnknown>(S);
3799  return SU && SU->getValue() == nullptr;
3800  });
3801 
3802  return !ContainsNulls;
3803 }
3804 
 3805 bool ScalarEvolution::containsAddRecurrence(const SCEV *S) {
 3806  HasRecMapType::iterator I = HasRecMap.find(S);
3807  if (I != HasRecMap.end())
3808  return I->second;
3809 
3810  bool FoundAddRec =
3811  SCEVExprContains(S, [](const SCEV *S) { return isa<SCEVAddRecExpr>(S); });
3812  HasRecMap.insert({S, FoundAddRec});
3813  return FoundAddRec;
3814 }
3815 
3816 /// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}.
3817 /// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an
3818 /// offset I, then return {S', I}, else return {\p S, nullptr}.
3819 static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) {
3820  const auto *Add = dyn_cast<SCEVAddExpr>(S);
3821  if (!Add)
3822  return {S, nullptr};
3823 
3824  if (Add->getNumOperands() != 2)
3825  return {S, nullptr};
3826 
3827  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0));
3828  if (!ConstOp)
3829  return {S, nullptr};
3830 
3831  return {Add->getOperand(1), ConstOp->getValue()};
3832 }
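 // Illustrative sketch, not part of the original source: a hypothetical caller
 // of splitAddExpr, showing how (C + X) is peeled into the pair {X, C} while
 // any other shape comes back unchanged as {S, nullptr}.
 static bool hasConstantAddOffset(const SCEV *S) {
   return splitAddExpr(S).second != nullptr; // true iff S is (C + X)
 }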
3833 
3834 /// Return the ValueOffsetPair set for \p S. \p S can be represented
3835 /// by the value and offset from any ValueOffsetPair in the set.
 3836 SetVector<ScalarEvolution::ValueOffsetPair> *
 3837 ScalarEvolution::getSCEVValues(const SCEV *S) {
3838  ExprValueMapType::iterator SI = ExprValueMap.find_as(S);
3839  if (SI == ExprValueMap.end())
3840  return nullptr;
3841 #ifndef NDEBUG
3842  if (VerifySCEVMap) {
3843  // Check there is no dangling Value in the set returned.
3844  for (const auto &VE : SI->second)
3845  assert(ValueExprMap.count(VE.first));
3846  }
3847 #endif
3848  return &SI->second;
3849 }
3850 
3851 /// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V)
3852 /// cannot be used separately. eraseValueFromMap should be used to remove
3853 /// V from ValueExprMap and ExprValueMap at the same time.
 3854 void ScalarEvolution::eraseValueFromMap(Value *V) {
 3855  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3856  if (I != ValueExprMap.end()) {
3857  const SCEV *S = I->second;
3858  // Remove {V, 0} from the set of ExprValueMap[S]
3859  if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S))
3860  SV->remove({V, nullptr});
3861 
3862  // Remove {V, Offset} from the set of ExprValueMap[Stripped]
3863  const SCEV *Stripped;
 3864  ConstantInt *Offset;
 3865  std::tie(Stripped, Offset) = splitAddExpr(S);
3866  if (Offset != nullptr) {
3867  if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped))
3868  SV->remove({V, Offset});
3869  }
3870  ValueExprMap.erase(V);
3871  }
3872 }
3873 
3874 /// Check whether value has nuw/nsw/exact set but SCEV does not.
 3875 /// TODO: Ideally we would check for poison recursively, but this is
 3876 /// better than nothing.
3877 static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
3878  if (auto *I = dyn_cast<Instruction>(V)) {
3879  if (isa<OverflowingBinaryOperator>(I)) {
3880  if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
3881  if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
3882  return true;
3883  if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
3884  return true;
3885  }
3886  } else if (isa<PossiblyExactOperator>(I) && I->isExact())
3887  return true;
3888  }
3889  return false;
3890 }
3891 
3892 /// Return an existing SCEV if it exists, otherwise analyze the expression and
3893 /// create a new one.
 3894 const SCEV *ScalarEvolution::getSCEV(Value *V) {
 3895  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3896 
3897  const SCEV *S = getExistingSCEV(V);
3898  if (S == nullptr) {
3899  S = createSCEV(V);
3900  // During PHI resolution, it is possible to create two SCEVs for the same
 3901  // V, so we need to double-check whether V->S has been inserted into
 3902  // ValueExprMap before inserting S->{V, 0} into ExprValueMap.
3903  std::pair<ValueExprMapType::iterator, bool> Pair =
3904  ValueExprMap.insert({SCEVCallbackVH(V, this), S});
3905  if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
3906  ExprValueMap[S].insert({V, nullptr});
3907 
3908  // If S == Stripped + Offset, add Stripped -> {V, Offset} into
3909  // ExprValueMap.
3910  const SCEV *Stripped = S;
3911  ConstantInt *Offset = nullptr;
3912  std::tie(Stripped, Offset) = splitAddExpr(S);
3913  // If stripped is SCEVUnknown, don't bother to save
3914  // Stripped -> {V, offset}. It doesn't simplify and sometimes even
3915  // increase the complexity of the expansion code.
3916  // If V is GetElementPtrInst, don't save Stripped -> {V, offset}
3917  // because it may generate add/sub instead of GEP in SCEV expansion.
3918  if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) &&
3919  !isa<GetElementPtrInst>(V))
3920  ExprValueMap[Stripped].insert({V, Offset});
3921  }
3922  }
3923  return S;
3924 }
3925 
3926 const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
3927  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
3928 
3929  ValueExprMapType::iterator I = ValueExprMap.find_as(V);
3930  if (I != ValueExprMap.end()) {
3931  const SCEV *S = I->second;
3932  if (checkValidity(S))
3933  return S;
3934  eraseValueFromMap(V);
3935  forgetMemoizedResults(S);
3936  }
3937  return nullptr;
3938 }
3939 
3940 /// Return a SCEV corresponding to -V = -1*V
 3941 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
 3942  SCEV::NoWrapFlags Flags) {
3943  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3944  return getConstant(
3945  cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
3946 
3947  Type *Ty = V->getType();
3948  Ty = getEffectiveSCEVType(Ty);
3949  return getMulExpr(V, getMinusOne(Ty), Flags);
3950 }
3951 
3952 /// If Expr computes ~A, return A else return nullptr
3953 static const SCEV *MatchNotExpr(const SCEV *Expr) {
3954  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
3955  if (!Add || Add->getNumOperands() != 2 ||
3956  !Add->getOperand(0)->isAllOnesValue())
3957  return nullptr;
3958 
3959  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
3960  if (!AddRHS || AddRHS->getNumOperands() != 2 ||
3961  !AddRHS->getOperand(0)->isAllOnesValue())
3962  return nullptr;
3963 
3964  return AddRHS->getOperand(1);
3965 }
3966 
3967 /// Return a SCEV corresponding to ~V = -1-V
 3968 const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
 3969  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
3970  return getConstant(
3971  cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
3972 
3973  // Fold ~(u|s)(min|max)(~x, ~y) to (u|s)(max|min)(x, y)
3974  if (const SCEVMinMaxExpr *MME = dyn_cast<SCEVMinMaxExpr>(V)) {
3975  auto MatchMinMaxNegation = [&](const SCEVMinMaxExpr *MME) {
3976  SmallVector<const SCEV *, 2> MatchedOperands;
3977  for (const SCEV *Operand : MME->operands()) {
3978  const SCEV *Matched = MatchNotExpr(Operand);
3979  if (!Matched)
3980  return (const SCEV *)nullptr;
3981  MatchedOperands.push_back(Matched);
3982  }
3983  return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
3984  MatchedOperands);
3985  };
3986  if (const SCEV *Replaced = MatchMinMaxNegation(MME))
3987  return Replaced;
3988  }
3989 
3990  Type *Ty = V->getType();
3991  Ty = getEffectiveSCEVType(Ty);
3992  return getMinusSCEV(getMinusOne(Ty), V);
3993 }
3994 
3995 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
3996  SCEV::NoWrapFlags Flags,
3997  unsigned Depth) {
3998  // Fast path: X - X --> 0.
3999  if (LHS == RHS)
4000  return getZero(LHS->getType());
4001 
4002  // We represent LHS - RHS as LHS + (-1)*RHS. This transformation
4003  // makes it so that we cannot make much use of NUW.
4004  auto AddFlags = SCEV::FlagAnyWrap;
4005  const bool RHSIsNotMinSigned =
 4006  !getSignedRangeMin(RHS).isMinSignedValue();
 4007  if (maskFlags(Flags, SCEV::FlagNSW) == SCEV::FlagNSW) {
4008  // Let M be the minimum representable signed value. Then (-1)*RHS
4009  // signed-wraps if and only if RHS is M. That can happen even for
4010  // a NSW subtraction because e.g. (-1)*M signed-wraps even though
4011  // -1 - M does not. So to transfer NSW from LHS - RHS to LHS +
4012  // (-1)*RHS, we need to prove that RHS != M.
4013  //
4014  // If LHS is non-negative and we know that LHS - RHS does not
4015  // signed-wrap, then RHS cannot be M. So we can rule out signed-wrap
4016  // either by proving that RHS > M or that LHS >= 0.
4017  if (RHSIsNotMinSigned || isKnownNonNegative(LHS)) {
4018  AddFlags = SCEV::FlagNSW;
4019  }
4020  }
4021 
4022  // FIXME: Find a correct way to transfer NSW to (-1)*M when LHS -
4023  // RHS is NSW and LHS >= 0.
4024  //
4025  // The difficulty here is that the NSW flag may have been proven
4026  // relative to a loop that is to be found in a recurrence in LHS and
4027  // not in RHS. Applying NSW to (-1)*M may then let the NSW have a
4028  // larger scope than intended.
4029  auto NegFlags = RHSIsNotMinSigned ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
4030 
4031  return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags, Depth);
4032 }
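 // Illustrative sketch, not part of the original source: the reason NSW cannot
 // be transferred blindly is that (-1)*M wraps when M is the minimum signed
 // value, e.g. for i8 the negation of -128 is -128 again in two's complement.
 // Hypothetical spot-check, done in unsigned arithmetic to avoid UB:
 static_assert(static_cast<unsigned char>(0u - 0x80u) == 0x80u,
               "negating the minimum signed i8 value wraps back to itself");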
4033 
 4034 const SCEV *ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty,
 4035  unsigned Depth) {
4036  Type *SrcTy = V->getType();
4037  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4038  "Cannot truncate or zero extend with non-integer arguments!");
4039  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4040  return V; // No conversion
4041  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4042  return getTruncateExpr(V, Ty, Depth);
4043  return getZeroExtendExpr(V, Ty, Depth);
4044 }
4045 
 4046 const SCEV *ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty,
 4047  unsigned Depth) {
4048  Type *SrcTy = V->getType();
4049  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4050  "Cannot truncate or zero extend with non-integer arguments!");
4051  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4052  return V; // No conversion
4053  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
4054  return getTruncateExpr(V, Ty, Depth);
4055  return getSignExtendExpr(V, Ty, Depth);
4056 }
4057 
4058 const SCEV *
 4059 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
 4060  Type *SrcTy = V->getType();
4061  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4062  "Cannot noop or zero extend with non-integer arguments!");
4063  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4064  "getNoopOrZeroExtend cannot truncate!");
4065  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4066  return V; // No conversion
4067  return getZeroExtendExpr(V, Ty);
4068 }
4069 
4070 const SCEV *
 4071 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
 4072  Type *SrcTy = V->getType();
4073  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4074  "Cannot noop or sign extend with non-integer arguments!");
4075  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4076  "getNoopOrSignExtend cannot truncate!");
4077  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4078  return V; // No conversion
4079  return getSignExtendExpr(V, Ty);
4080 }
4081 
4082 const SCEV *
 4083 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
 4084  Type *SrcTy = V->getType();
4085  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4086  "Cannot noop or any extend with non-integer arguments!");
4087  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
4088  "getNoopOrAnyExtend cannot truncate!");
4089  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4090  return V; // No conversion
4091  return getAnyExtendExpr(V, Ty);
4092 }
4093 
4094 const SCEV *
 4095 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
 4096  Type *SrcTy = V->getType();
4097  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
4098  "Cannot truncate or noop with non-integer arguments!");
4099  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
4100  "getTruncateOrNoop cannot extend!");
4101  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
4102  return V; // No conversion
4103  return getTruncateExpr(V, Ty);
4104 }
4105 
 4106 const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
 4107  const SCEV *RHS) {
4108  const SCEV *PromotedLHS = LHS;
4109  const SCEV *PromotedRHS = RHS;
4110 
4111  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
4112  PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
4113  else
4114  PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
4115 
4116  return getUMaxExpr(PromotedLHS, PromotedRHS);
4117 }
4118 
 4119 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
 4120  const SCEV *RHS) {
4121  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
4122  return getUMinFromMismatchedTypes(Ops);
4123 }
4124 
 4125 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
 4126  SmallVectorImpl<const SCEV *> &Ops) {
 4127  assert(!Ops.empty() && "At least one operand must be!");
4128  // Trivial case.
4129  if (Ops.size() == 1)
4130  return Ops[0];
4131 
4132  // Find the max type first.
4133  Type *MaxType = nullptr;
4134  for (auto *S : Ops)
4135  if (MaxType)
4136  MaxType = getWiderType(MaxType, S->getType());
4137  else
4138  MaxType = S->getType();
4139  assert(MaxType && "Failed to find maximum type!");
4140 
4141  // Extend all ops to max type.
4142  SmallVector<const SCEV *, 2> PromotedOps;
4143  for (auto *S : Ops)
4144  PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
4145 
4146  // Generate umin.
4147  return getUMinExpr(PromotedOps);
4148 }
4149 
 4150 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
 4151  // A pointer operand may evaluate to a nonpointer expression, such as null.
4152  if (!V->getType()->isPointerTy())
4153  return V;
4154 
4155  while (true) {
4156  if (const SCEVIntegralCastExpr *Cast = dyn_cast<SCEVIntegralCastExpr>(V)) {
4157  V = Cast->getOperand();
4158  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
4159  const SCEV *PtrOp = nullptr;
4160  for (const SCEV *NAryOp : NAry->operands()) {
4161  if (NAryOp->getType()->isPointerTy()) {
4162  // Cannot find the base of an expression with multiple pointer ops.
4163  if (PtrOp)
4164  return V;
4165  PtrOp = NAryOp;
4166  }
4167  }
4168  if (!PtrOp) // All operands were non-pointer.
4169  return V;
4170  V = PtrOp;
4171  } else // Not something we can look further into.
4172  return V;
4173  }
4174 }
4175 
4176 /// Push users of the given Instruction onto the given Worklist.
4177 static void
4179  SmallVectorImpl<Instruction *> &Worklist) {
4180  // Push the def-use children onto the Worklist stack.
4181  for (User *U : I->users())
4182  Worklist.push_back(cast<Instruction>(U));
4183 }
4184 
4185 void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
 4186  SmallVector<Instruction *, 16> Worklist;
 4187  PushDefUseChildren(PN, Worklist);
4188 
 4189  SmallPtrSet<Instruction *, 8> Visited;
 4190  Visited.insert(PN);
4191  while (!Worklist.empty()) {
4192  Instruction *I = Worklist.pop_back_val();
4193  if (!Visited.insert(I).second)
4194  continue;
4195 
4196  auto It = ValueExprMap.find_as(static_cast<Value *>(I));
4197  if (It != ValueExprMap.end()) {
4198  const SCEV *Old = It->second;
4199 
4200  // Short-circuit the def-use traversal if the symbolic name
4201  // ceases to appear in expressions.
4202  if (Old != SymName && !hasOperand(Old, SymName))
4203  continue;
4204 
4205  // SCEVUnknown for a PHI either means that it has an unrecognized
 4206  // structure, it's a PHI that's in the process of being computed
4207  // by createNodeForPHI, or it's a single-value PHI. In the first case,
4208  // additional loop trip count information isn't going to change anything.
4209  // In the second case, createNodeForPHI will perform the necessary
4210  // updates on its own when it gets to that point. In the third, we do
4211  // want to forget the SCEVUnknown.
4212  if (!isa<PHINode>(I) ||
4213  !isa<SCEVUnknown>(Old) ||
4214  (I != PN && Old == SymName)) {
4215  eraseValueFromMap(It->first);
4216  forgetMemoizedResults(Old);
4217  }
4218  }
4219 
4220  PushDefUseChildren(I, Worklist);
4221  }
4222 }
4223 
4224 namespace {
4225 
4226 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
 4227 /// expression if its loop is L. If the loop is not L, then: if
 4228 /// IgnoreOtherLoops is true, use the AddRec itself; otherwise the rewrite
 4229 /// cannot be done.
 4230 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4231 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
4232 public:
4233  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
4234  bool IgnoreOtherLoops = true) {
4235  SCEVInitRewriter Rewriter(L, SE);
4236  const SCEV *Result = Rewriter.visit(S);
4237  if (Rewriter.hasSeenLoopVariantSCEVUnknown())
4238  return SE.getCouldNotCompute();
4239  return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
4240  ? SE.getCouldNotCompute()
4241  : Result;
4242  }
4243 
4244  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4245  if (!SE.isLoopInvariant(Expr, L))
4246  SeenLoopVariantSCEVUnknown = true;
4247  return Expr;
4248  }
4249 
4250  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4251  // Only re-write AddRecExprs for this loop.
4252  if (Expr->getLoop() == L)
4253  return Expr->getStart();
4254  SeenOtherLoops = true;
4255  return Expr;
4256  }
4257 
4258  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4259 
4260  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4261 
4262 private:
4263  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
4264  : SCEVRewriteVisitor(SE), L(L) {}
4265 
4266  const Loop *L;
4267  bool SeenLoopVariantSCEVUnknown = false;
4268  bool SeenOtherLoops = false;
4269 };
4270 
4271 /// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
 4272 /// increment expression if its loop is L. If the loop is not L, then use
 4273 /// the AddRec itself.
 4274 /// If the SCEV contains a loop-variant SCEVUnknown, the rewrite cannot be done.
4275 class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
4276 public:
4277  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
4278  SCEVPostIncRewriter Rewriter(L, SE);
4279  const SCEV *Result = Rewriter.visit(S);
4280  return Rewriter.hasSeenLoopVariantSCEVUnknown()
4281  ? SE.getCouldNotCompute()
4282  : Result;
4283  }
4284 
4285  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4286  if (!SE.isLoopInvariant(Expr, L))
4287  SeenLoopVariantSCEVUnknown = true;
4288  return Expr;
4289  }
4290 
4291  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4292  // Only re-write AddRecExprs for this loop.
4293  if (Expr->getLoop() == L)
4294  return Expr->getPostIncExpr(SE);
4295  SeenOtherLoops = true;
4296  return Expr;
4297  }
4298 
4299  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
4300 
4301  bool hasSeenOtherLoops() { return SeenOtherLoops; }
4302 
4303 private:
4304  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
4305  : SCEVRewriteVisitor(SE), L(L) {}
4306 
4307  const Loop *L;
4308  bool SeenLoopVariantSCEVUnknown = false;
4309  bool SeenOtherLoops = false;
4310 };
4311 
4312 /// This class evaluates the compare condition by matching it against the
4313 /// condition of loop latch. If there is a match we assume a true value
4314 /// for the condition while building SCEV nodes.
4315 class SCEVBackedgeConditionFolder
4316  : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
4317 public:
4318  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4319  ScalarEvolution &SE) {
4320  bool IsPosBECond = false;
4321  Value *BECond = nullptr;
4322  if (BasicBlock *Latch = L->getLoopLatch()) {
4323  BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
4324  if (BI && BI->isConditional()) {
4325  assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
4326  "Both outgoing branches should not target same header!");
4327  BECond = BI->getCondition();
4328  IsPosBECond = BI->getSuccessor(0) == L->getHeader();
4329  } else {
4330  return S;
4331  }
4332  }
4333  SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
4334  return Rewriter.visit(S);
4335  }
4336 
4337  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4338  const SCEV *Result = Expr;
4339  bool InvariantF = SE.isLoopInvariant(Expr, L);
4340 
4341  if (!InvariantF) {
4342  Instruction *I = cast<Instruction>(Expr->getValue());
4343  switch (I->getOpcode()) {
4344  case Instruction::Select: {
4345  SelectInst *SI = cast<SelectInst>(I);
 4346  Optional<const SCEV *> Res =
 4347  compareWithBackedgeCondition(SI->getCondition());
4348  if (Res.hasValue()) {
4349  bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
4350  Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
4351  }
4352  break;
4353  }
4354  default: {
4355  Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
4356  if (Res.hasValue())
4357  Result = Res.getValue();
4358  break;
4359  }
4360  }
4361  }
4362  return Result;
4363  }
4364 
4365 private:
4366  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
4367  bool IsPosBECond, ScalarEvolution &SE)
4368  : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
4369  IsPositiveBECond(IsPosBECond) {}
4370 
4371  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
4372 
4373  const Loop *L;
4374  /// Loop back condition.
4375  Value *BackedgeCond = nullptr;
4376  /// Set to true if loop back is on positive branch condition.
4377  bool IsPositiveBECond;
4378 };
4379 
 4380 Optional<const SCEV *>
 4381 SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
4382 
4383  // If value matches the backedge condition for loop latch,
4384  // then return a constant evolution node based on loopback
4385  // branch taken.
4386  if (BackedgeCond == IC)
4387  return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
4388  : SE.getZero(Type::getInt1Ty(SE.getContext()));
4389  return None;
4390 }
4391 
4392 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
4393 public:
4394  static const SCEV *rewrite(const SCEV *S, const Loop *L,
4395  ScalarEvolution &SE) {
4396  SCEVShiftRewriter Rewriter(L, SE);
4397  const SCEV *Result = Rewriter.visit(S);
4398  return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
4399  }
4400 
4401  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
4402  // Only allow AddRecExprs for this loop.
4403  if (!SE.isLoopInvariant(Expr, L))
4404  Valid = false;
4405  return Expr;
4406  }
4407 
4408  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
4409  if (Expr->getLoop() == L && Expr->isAffine())
4410  return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE));
4411  Valid = false;
4412  return Expr;
4413  }
4414 
4415  bool isValid() { return Valid; }
4416 
4417 private:
4418  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
4419  : SCEVRewriteVisitor(SE), L(L) {}
4420 
4421  const Loop *L;
4422  bool Valid = true;
4423 };
4424 
4425 } // end anonymous namespace
4426 
 4427 SCEV::NoWrapFlags
 4428 ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
4429  if (!AR->isAffine())
4430  return SCEV::FlagAnyWrap;
4431 
4432  using OBO = OverflowingBinaryOperator;
4433 
 4434  SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
 4435 
4436  if (!AR->hasNoSignedWrap()) {
4437  ConstantRange AddRecRange = getSignedRange(AR);
4438  ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this));
4439 
 4440  auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
 4441  Instruction::Add, IncRange, OBO::NoSignedWrap);
 4442  if (NSWRegion.contains(AddRecRange))
 4443  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW);
 4444  }
4445 
4446  if (!AR->hasNoUnsignedWrap()) {
4447  ConstantRange AddRecRange = getUnsignedRange(AR);
4448  ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this));
4449 
 4450  auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
 4451  Instruction::Add, IncRange, OBO::NoUnsignedWrap);
 4452  if (NUWRegion.contains(AddRecRange))
 4453  Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW);
 4454  }
4455 
4456  return Result;
4457 }
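 // Illustrative sketch, not part of the original source:
 // makeGuaranteedNoWrapRegion returns the set of start values that cannot wrap
 // for any increment in IncRange, so containment of the addrec's range in that
 // region is what proves the flag above. Hypothetical spot-check for i8 with an
 // increment known to be 1 or 2 (unused illustrative helper):
 static bool checkNoWrapRegionExample() {
   ConstantRange Inc(APInt(8, 1), APInt(8, 3)); // increments in [1, 3)
   ConstantRange NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
       Instruction::Add, Inc, OverflowingBinaryOperator::NoSignedWrap);
   // Any addrec whose signed range lies in [0, 100) cannot signed-wrap here.
   return NSWRegion.contains(ConstantRange(APInt(8, 0), APInt(8, 100)));
 }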
4458 
 4459 SCEV::NoWrapFlags
 4460 ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
 4461  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
 4462 
4463  if (AR->hasNoSignedWrap())
4464  return Result;
4465 
4466  if (!AR->isAffine())
4467  return Result;
4468 
4469  const SCEV *Step = AR->getStepRecurrence(*this);
4470  const Loop *L = AR->getLoop();
4471 
4472  // Check whether the backedge-taken count is SCEVCouldNotCompute.
4473  // Note that this serves two purposes: It filters out loops that are
4474  // simply not analyzable, and it covers the case where this code is
4475  // being called from within backedge-taken count analysis, such that
4476  // attempting to ask for the backedge-taken count would likely result
 4477  // in infinite recursion. In the latter case, the analysis code will
4478  // cope with a conservative value, and it will take care to purge
4479  // that value once it has finished.
4480  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4481 
4482  // Normally, in the cases we can prove no-overflow via a
4483  // backedge guarding condition, we can also compute a backedge
4484  // taken count for the loop. The exceptions are assumptions and
4485  // guards present in the loop -- SCEV is not great at exploiting
4486  // these to compute max backedge taken counts, but can still use
4487  // these to prove lack of overflow. Use this fact to avoid
4488  // doing extra work that may not pay off.
4489 
4490  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4491  AC.assumptions().empty())
4492  return Result;
4493 
4494  // If the backedge is guarded by a comparison with the pre-inc value the
4495  // addrec is safe. Also, if the entry is guarded by a comparison with the
4496  // start value and the backedge is guarded by a comparison with the post-inc
4497  // value, the addrec is safe.
4498  ICmpInst::Predicate Pred;
4499  const SCEV *OverflowLimit =
4500  getSignedOverflowLimitForStep(Step, &Pred, this);
4501  if (OverflowLimit &&
4502  (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
4503  isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
4504  Result = setFlags(Result, SCEV::FlagNSW);
4505  }
4506  return Result;
4507 }
 4508 SCEV::NoWrapFlags
 4509 ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
 4510  SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
 4511 
4512  if (AR->hasNoUnsignedWrap())
4513  return Result;
4514 
4515  if (!AR->isAffine())
4516  return Result;
4517 
4518  const SCEV *Step = AR->getStepRecurrence(*this);
4519  unsigned BitWidth = getTypeSizeInBits(AR->getType());
4520  const Loop *L = AR->getLoop();
4521 
4522  // Check whether the backedge-taken count is SCEVCouldNotCompute.
4523  // Note that this serves two purposes: It filters out loops that are
4524  // simply not analyzable, and it covers the case where this code is
4525  // being called from within backedge-taken count analysis, such that
4526  // attempting to ask for the backedge-taken count would likely result
4527  // in infinite recursion. In the latter case, the analysis code will
4528  // cope with a conservative value, and it will take care to purge
4529  // that value once it has finished.
4530  const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
4531 
4532  // Normally, in the cases we can prove no-overflow via a
4533  // backedge guarding condition, we can also compute a backedge
4534  // taken count for the loop. The exceptions are assumptions and
4535  // guards present in the loop -- SCEV is not great at exploiting
4536  // these to compute max backedge taken counts, but can still use
4537  // these to prove lack of overflow. Use this fact to avoid
4538  // doing extra work that may not pay off.
4539 
4540  if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
4541  AC.assumptions().empty())
4542  return Result;
4543 
4544  // If the backedge is guarded by a comparison with the pre-inc value the
4545  // addrec is safe. Also, if the entry is guarded by a comparison with the
4546  // start value and the backedge is guarded by a comparison with the post-inc
4547  // value, the addrec is safe.
4548  if (isKnownPositive(Step)) {
4549  const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
4550  getUnsignedRangeMax(Step));
4551  if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
4552  isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
4553  Result = setFlags(Result, SCEV::FlagNUW);
4554  }
4555  }
4556 
4557  return Result;
4558 }
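// Worked example (illustrative numbers, not from the original source): for an
// i8 AddRec whose step has an unsigned maximum of 4, the limit N above is
// 0 - 4 == 252 (mod 256). If the backedge is guarded by "AR <u 252", the
// largest pre-increment value is 251, and 251 + 4 = 255 still fits in 8 bits,
// so the AddRec cannot unsigned-wrap and may be marked NUW.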
4559 
4560 namespace {
4561 
4562 /// Represents an abstract binary operation. This may exist as a
4563 /// normal instruction or constant expression, or may have been
4564 /// derived from an expression tree.
4565 struct BinaryOp {
4566  unsigned Opcode;
4567  Value *LHS;
4568  Value *RHS;
4569  bool IsNSW = false;
4570  bool IsNUW = false;
4571 
4572  /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
4573  /// constant expression.
4574  Operator *Op = nullptr;
4575 
4576  explicit BinaryOp(Operator *Op)
4577  : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
4578  Op(Op) {
4579  if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
4580  IsNSW = OBO->hasNoSignedWrap();
4581  IsNUW = OBO->hasNoUnsignedWrap();
4582  }
4583  }
4584 
4585  explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
4586  bool IsNUW = false)
4587  : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
4588 };
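// Illustrative use (hypothetical values, not part of the original source):
// constructing a BinaryOp from the instruction
//   %r = add nsw i32 %x, %y
// via the Operator constructor yields Opcode == Instruction::Add, LHS == %x,
// RHS == %y, IsNSW == true, IsNUW == false, and Op pointing at the add. The
// second constructor is used when a BinaryOp is synthesized from an expression
// tree rather than a concrete instruction, in which case Op stays nullptr.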
4589 
4590 } // end anonymous namespace
4591 
4592 /// Try to map \p V into a BinaryOp, and return \c None on failure.
4593 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
4594  auto *Op = dyn_cast<Operator>(V);
4595  if (!Op)
4596  return None;
4597 
4598  // Implementation detail: all the cleverness here should happen without
4599  // creating new SCEV expressions -- our caller knows tricks to avoid creating
4600  // SCEV expressions when possible, and we should not break that.
4601 
4602  switch (Op->getOpcode()) {
4603  case Instruction::Add:
4604  case Instruction::Sub:
4605  case Instruction::Mul:
4606  case Instruction::UDiv:
4607  case Instruction::URem:
4608  case Instruction::And:
4609  case Instruction::Or:
4610  case Instruction::AShr:
4611  case Instruction::Shl:
4612  return BinaryOp(Op);
4613 
4614  case Instruction::Xor:
4615  if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
4616  // If the RHS of the xor is a signmask, then this is just an add.
4617  // Instcombine turns add of signmask into xor as a strength reduction step.
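 // For example (illustrative IR): "xor i32 %a, -2147483648" computes the same
 // value as "add i32 %a, -2147483648", because flipping the sign bit and
 // adding the sign mask are the same operation on i32.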
4618  if (RHSC->getValue().isSignMask())
4619  return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1));
4620  return BinaryOp(Op);
4621 
4622  case Instruction::LShr:
4623  // Turn a logical shift right by a constant into an unsigned divide.
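 // For example (illustrative IR): "lshr i32 %x, 3" is treated as
 // "udiv i32 %x, 8", i.e. division by 1 << 3.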
4624  if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) {
4625  uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth();
4626 
4627  // If the shift count is not less than the bitwidth, the result of
4628  // the shift is undefined. Don't try to analyze it, because the
4629  // resolution chosen here may differ from the resolution chosen in
4630  // other parts of the compiler.
4631  if (SA->getValue().ult(BitWidth)) {
4632  Constant *X =
4633  ConstantInt::get(SA->getContext(),
4634  APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
4635  return BinaryOp(Instruction::UDiv, Op->getOperand(0), X);
4636  }
4637  }
4638  return BinaryOp(Op);
4639 
4640  case Instruction::ExtractValue: {
4641  auto *EVI = cast<ExtractValueInst>(Op);
4642  if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0)
4643  break;
4644 
4645  auto *WO = dyn_cast<WithOverflowInst>(EVI->getAggregateOperand());
4646  if (!WO)
4647  break;
4648 
4649  Instruction::BinaryOps BinOp = WO->getBinaryOp();
4650  bool Signed = WO->isSigned();
4651  // TODO: Should add nuw/nsw flags for mul as well.
4652  if (BinOp == Instruction::Mul || !isOverflowIntrinsicNoWrap(WO, DT))
4653  return BinaryOp(BinOp, WO->getLHS(), WO->getRHS());
4654 
4655  // Now that we know that all uses of the arithmetic-result component of
4656  // the overflow intrinsic are guarded by the overflow check, we can pretend
4657  // that the arithmetic is non-overflowing.
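 // Illustrative IR for this case (hypothetical, not from the original source):
 //   %pair = call { i32, i1 } @llvm.sadd.with.overflow.i32(i32 %a, i32 %b)
 //   %sum  = extractvalue { i32, i1 } %pair, 0
 //   %ovf  = extractvalue { i32, i1 } %pair, 1
 //   br i1 %ovf, label %trap, label %cont
 // When isOverflowIntrinsicNoWrap shows that %sum is only used where the
 // overflow bit is known to be false, the add above can be modeled as nsw.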
4658  return BinaryOp(BinOp, WO->getLHS(), WO->getRHS(),
4659  /* IsNSW = */ Signed, /* IsNUW = */ !Signed);
4660  }
4661 
4662  default:
4663  break;
4664  }
4665 
4666  // Recognize the intrinsic loop.decrement.reg; since it has exactly the same
4667  // semantics as a Sub, return a binary sub expression.
4668  if (auto *II = dyn_cast<IntrinsicInst>(V))
4669  if (II->getIntrinsicID() == Intrinsic::loop_decrement_reg)
4670  return BinaryOp(Instruction::Sub, II->getOperand(0), II->getOperand(1));
4671 
4672  return None;
4673 }
4674 
4675 /// Helper function to createAddRecFromPHIWithCasts. We have a phi
4676 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
4677 /// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
4678 /// way. This function checks if \p Op, an operand of this SCEVAddExpr,
4679 /// follows one of the following patterns:
4680 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4681 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
4682 /// If the SCEV expression of \p Op conforms with one of the expected patterns
4683 /// we return the type of the truncation operation, and indicate whether the
4684 /// truncated type should be treated as signed/unsigned by setting
4685 /// \p Signed to true/false, respectively.
4686 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
4687  bool &Signed, ScalarEvolution &SE) {
4688  // The case where Op == SymbolicPHI (that is, with no type conversions on
4689  // the way) is handled by the regular add recurrence creating logic and
4690  // would have already been triggered in createAddRecForPHI. Reaching it here
4691  // means that createAddRecFromPHI had failed for this PHI before (e.g.,
4692  // because one of the other operands of the SCEVAddExpr updating this PHI is
4693  // not invariant).
4694  //
4695  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
4696  // this case predicates that allow us to prove that Op == SymbolicPHI will
4697  // be added.
4698  if (Op == SymbolicPHI)
4699  return nullptr;
4700 
4701  unsigned SourceBits = SE.getTypeSizeInBits(SymbolicPHI->getType());
4702  unsigned NewBits = SE.getTypeSizeInBits(Op->getType());
4703  if (SourceBits != NewBits)
4704  return nullptr;
4705 
4706  const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(Op);
4707  const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(Op);
4708  if (!SExt && !ZExt)
4709  return nullptr;
4710  const SCEVTruncateExpr *Trunc =
4711  SExt ? dyn_cast<SCEVTruncateExpr>(SExt->getOperand())
4712  : dyn_cast<SCEVTruncateExpr>(ZExt->getOperand());
4713  if (!Trunc)
4714  return nullptr;
4715  const SCEV *X = Trunc->getOperand();
4716  if (X != SymbolicPHI)
4717  return nullptr;
4718  Signed = SExt != nullptr;
4719  return Trunc->getType();
4720 }
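// Concrete instance (illustrative, not from the original source): for a phi
// %SymbolicPHI of type i64, the operand
//   (sext i32 (trunc i64 %SymbolicPHI to i32) to i64)
// matches the pattern, so the function returns the truncated type i32 and sets
// Signed to true; the zext form returns i32 with Signed set to false.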
4721 
4722 static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
4723  if (!PN->getType()->isIntegerTy())
4724  return nullptr;
4725  const Loop *L = LI.getLoopFor(PN->getParent());
4726  if (!L || L->getHeader() != PN->getParent())
4727  return nullptr;
4728  return L;
4729 }
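// For example (illustrative IR): a header phi such as
//   %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
// in the header block of %loop passes both checks and yields that loop, while
// a floating-point phi, or an integer phi outside a loop header, yields
// nullptr.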
4730 
4731 // Analyze \p SymbolicPHI, a SCEV expression of a phi node, and check if the
4732 // computation that updates the phi matches the following pattern:
4733 // (SExt/ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) + InvariantAccum
4734 // which corresponds to a phi->trunc->sext/zext->add->phi update chain.
4735 // If so, try to see if it can be rewritten as an AddRecExpr under some
4736 // Predicates. If successful, return them as a pair. Also cache the results
4737 // of the analysis.
4738 //
4739 // Example usage scenario:
4740 // Say the Rewriter is called for the following SCEV:
4741 // 8 * ((sext i32 (trunc i64 %X to i32) to i64) + %Step)