80#define DEBUG_TYPE "loop-flatten"
82STATISTIC(NumFlattened,
"Number of loops flattened");
86 cl::desc(
"Limit on the cost of instructions that can be repeated due to "
92 cl::desc(
"Assume that the product of the two iteration "
93 "trip counts will never overflow"));
97 cl::desc(
"Widen the loop induction variables, if possible, so "
98 "overflow checks won't reject flattening"));
110 Loop *OuterLoop =
nullptr;
111 Loop *InnerLoop =
nullptr;
113 PHINode *InnerInductionPHI =
nullptr;
114 PHINode *OuterInductionPHI =
nullptr;
118 Value *InnerTripCount =
nullptr;
119 Value *OuterTripCount =
nullptr;
136 bool Widened =
false;
139 PHINode *NarrowInnerInductionPHI =
nullptr;
140 PHINode *NarrowOuterInductionPHI =
nullptr;
144 FlattenInfo(
Loop *OL,
Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
146 bool isNarrowInductionPhi(
PHINode *Phi) {
150 return NarrowInnerInductionPHI == Phi || NarrowOuterInductionPHI == Phi;
152 bool isInnerLoopIncrement(
User *U) {
153 return InnerIncrement ==
U;
155 bool isOuterLoopIncrement(
User *U) {
156 return OuterIncrement ==
U;
158 bool isInnerLoopTest(
User *U) {
163 for (
User *U : OuterInductionPHI->
users()) {
164 if (isOuterLoopIncrement(U))
167 auto IsValidOuterPHIUses = [&] (
User *
U) ->
bool {
168 LLVM_DEBUG(
dbgs() <<
"Found use of outer induction variable: ";
U->dump());
169 if (!ValidOuterPHIUses.
count(U)) {
170 LLVM_DEBUG(
dbgs() <<
"Did not match expected pattern, bailing\n");
177 if (
auto *V = dyn_cast<TruncInst>(U)) {
178 for (
auto *K :
V->users()) {
179 if (!IsValidOuterPHIUses(K))
185 if (!IsValidOuterPHIUses(U))
191 bool matchLinearIVUser(
User *U,
Value *InnerTripCount,
193 LLVM_DEBUG(
dbgs() <<
"Checking linear i*M+j expression for: ";
U->dump());
194 Value *MatchedMul =
nullptr;
195 Value *MatchedItCount =
nullptr;
219 return !isInstructionTriviallyDead(cast<Instruction>(U));
227 if (Widened && IsAdd &&
228 (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
230 "Unexpected type mismatch in types after widening");
231 MatchedItCount = isa<SExtInst>(MatchedItCount)
232 ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
233 : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
237 InnerTripCount->
dump());
239 if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
241 ValidOuterPHIUses.
insert(MatchedMul);
246 LLVM_DEBUG(
dbgs() <<
"Did not match expected pattern, bailing\n");
251 Value *SExtInnerTripCount = InnerTripCount;
253 (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
254 SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
256 for (
User *U : InnerInductionPHI->
users()) {
258 if (isInnerLoopIncrement(U)) {
259 LLVM_DEBUG(
dbgs() <<
"Use is inner loop increment, continuing\n");
265 if (isa<TruncInst>(U)) {
268 U = *
U->user_begin();
275 if (isInnerLoopTest(U)) {
280 if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) {
295 IterationInstructions.
insert(Increment);
313 if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
314 LLVM_DEBUG(
dbgs() <<
"Backedge-taken count is not predictable\n");
321 const SCEV *SCEVTripCount =
323 BackedgeTakenCount->
getType(), L);
326 if (SCEVRHS == SCEVTripCount)
330 const SCEV *BackedgeTCExt =
nullptr;
332 const SCEV *SCEVTripCountExt;
338 if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
345 if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
350 IterationInstructions);
361 auto *TripCountInst = dyn_cast<Instruction>(
RHS);
362 if (!TripCountInst) {
366 if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
367 SE->
getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
368 LLVM_DEBUG(
dbgs() <<
"Could not find valid extended trip count\n");
380 LLVM_DEBUG(
dbgs() <<
"Finding components of loop: " << L->getName() <<
"\n");
382 if (!L->isLoopSimplifyForm()) {
389 if (!L->isCanonical(*SE)) {
397 if (L->getExitingBlock() != Latch) {
405 InductionPHI = L->getInductionVariable(*SE);
422 ICmpInst *Compare = L->getLatchCmpInst();
423 if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
424 Compare->hasNUsesOrMore(2)) {
429 IterationInstructions.
insert(BackBranch);
431 IterationInstructions.
insert(Compare);
440 if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) &&
441 !Increment->hasNUses(1)) {
450 Value *
RHS = Compare->getOperand(1);
453 Increment, BackBranch, SE, IsWidened);
472 SafeOuterPHIs.
insert(FI.OuterInductionPHI);
476 for (
PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {
479 if (&InnerPHI == FI.InnerInductionPHI)
481 if (FI.isNarrowInductionPhi(&InnerPHI))
486 assert(InnerPHI.getNumIncomingValues() == 2);
487 Value *PreHeaderValue =
488 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());
490 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());
495 PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
496 if (!OuterPHI || OuterPHI->
getParent() != FI.OuterLoop->getHeader()) {
505 PHINode *LCSSAPHI = dyn_cast<PHINode>(
516 dbgs() <<
"LCSSA PHI incoming value does not match latch value\n");
523 SafeOuterPHIs.
insert(OuterPHI);
524 FI.InnerPHIsToTransform.insert(&InnerPHI);
527 for (
PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
528 if (FI.isNarrowInductionPhi(&OuterPHI))
530 if (!SafeOuterPHIs.
count(&OuterPHI)) {
531 LLVM_DEBUG(
dbgs() <<
"found unsafe PHI in outer loop: "; OuterPHI.dump());
550 for (
auto *
B : FI.OuterLoop->getBlocks()) {
551 if (FI.InnerLoop->contains(
B))
555 if (!isa<PHINode>(&
I) && !
I.isTerminator() &&
557 LLVM_DEBUG(
dbgs() <<
"Cannot flatten because instruction may have "
566 if (IterationInstructions.
count(&
I))
582 RepeatedInstrCost +=
Cost;
586 LLVM_DEBUG(
dbgs() <<
"Cost of instructions that will be repeated: "
587 << RepeatedInstrCost <<
"\n");
591 LLVM_DEBUG(
dbgs() <<
"checkOuterLoopInsts: not profitable, bailing.\n");
612 if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
617 if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
621 dbgs() <<
"Found " << FI.LinearIVUses.size()
622 <<
" value(s) that can be replaced:\n";
623 for (
Value *V : FI.LinearIVUses) {
634 Function *
F = FI.OuterLoop->getHeader()->getParent();
639 return OverflowResult::NeverOverflows;
644 FI.InnerTripCount, FI.OuterTripCount,
DL, AC,
645 FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
646 if (OR != OverflowResult::MayOverflow)
649 for (
Value *V : FI.LinearIVUses) {
650 for (
Value *U : V->users()) {
651 if (
auto *
GEP = dyn_cast<GetElementPtrInst>(U)) {
652 for (
Value *GEPUser : U->users()) {
653 auto *GEPUserInst = cast<Instruction>(GEPUser);
654 if (!isa<LoadInst>(GEPUserInst) &&
655 !(isa<StoreInst>(GEPUserInst) &&
656 GEP == GEPUserInst->getOperand(1)))
665 if (
GEP->isInBounds() &&
666 V->getType()->getIntegerBitWidth() >=
667 DL.getPointerTypeSizeInBits(
GEP->getType())) {
669 dbgs() <<
"use of linear IV would be UB if overflow occurred: ";
671 return OverflowResult::NeverOverflows;
678 return OverflowResult::MayOverflow;
686 FI.InnerInductionPHI, FI.InnerTripCount,
687 FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
690 FI.OuterInductionPHI, FI.OuterTripCount,
691 FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
696 if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
700 if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
709 if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())
731 Function *
F = FI.OuterLoop->getHeader()->getParent();
732 LLVM_DEBUG(
dbgs() <<
"Checks all passed, doing the transformation\n");
736 FI.InnerLoop->getHeader());
738 Remark <<
"Flattened into outer loop";
742 Value *NewTripCount = BinaryOperator::CreateMul(
743 FI.InnerTripCount, FI.OuterTripCount,
"flatten.tripcount",
744 FI.OuterLoop->getLoopPreheader()->getTerminator());
746 NewTripCount->
dump());
750 FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
755 PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
759 cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);
762 BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
763 BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
768 DT->
deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
770 MSSAU->
removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
775 for (
Value *V : FI.LinearIVUses) {
776 Value *OuterValue = FI.OuterInductionPHI;
778 OuterValue =
Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
783 V->replaceAllUsesWith(OuterValue);
791 U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName());
792 LI->
erase(FI.InnerLoop);
809 Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
810 auto &
DL = M->getDataLayout();
811 auto *InnerType = FI.InnerInductionPHI->getType();
812 auto *OuterType = FI.OuterInductionPHI->getType();
813 unsigned MaxLegalSize =
DL.getLargestLegalIntTypeSizeInBits();
814 auto *MaxLegalType =
DL.getLargestLegalIntType(M->getContext());
819 if (InnerType != OuterType ||
820 InnerType->getScalarSizeInBits() >= MaxLegalSize ||
821 MaxLegalType->getScalarSizeInBits() <
822 InnerType->getScalarSizeInBits() * 2) {
829 unsigned ElimExt = 0;
830 unsigned Widened = 0;
845 if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType,
false},
Deleted))
850 FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);
852 if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType,
false},
Deleted))
855 assert(Widened &&
"Widened IV expected");
859 FI.NarrowInnerInductionPHI = FI.InnerInductionPHI;
860 FI.NarrowOuterInductionPHI = FI.OuterInductionPHI;
871 dbgs() <<
"Loop flattening running on outer loop "
872 << FI.OuterLoop->getHeader()->getName() <<
" and inner loop "
873 << FI.InnerLoop->getHeader()->getName() <<
" in "
874 << FI.OuterLoop->getHeader()->getParent()->getName() <<
"\n");
893 if (FI.Widened && !CanFlatten)
906 if (OR == OverflowResult::AlwaysOverflowsHigh ||
907 OR == OverflowResult::AlwaysOverflowsLow) {
908 LLVM_DEBUG(
dbgs() <<
"Multiply would always overflow, so not profitable\n");
910 }
else if (OR == OverflowResult::MayOverflow) {
911 LLVM_DEBUG(
dbgs() <<
"Multiply might overflow, not flattening\n");
915 LLVM_DEBUG(
dbgs() <<
"Multiply cannot overflow, modifying loop in-place\n");
923 bool Changed =
false;
925 std::optional<MemorySSAUpdater> MSSAU;
940 FlattenInfo FI(OuterLoop, InnerLoop);
942 MSSAU ? &*MSSAU :
nullptr);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI)
static bool verifyTripCount(Value *RHS, Loop *L, SmallPtrSetImpl< Instruction * > &IterationInstructions, PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened)
static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI, LPMUpdater *U, MemorySSAUpdater *MSSAU)
static cl::opt< bool > WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true), cl::desc("Widen the loop induction variables, if possible, so " "overflow checks won't reject flattening"))
static bool setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, SmallPtrSetImpl< Instruction * > &IterationInstructions)
static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI, LPMUpdater *U, MemorySSAUpdater *MSSAU)
static bool checkIVUsers(FlattenInfo &FI)
static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI)
static bool findLoopComponents(Loop *L, SmallPtrSetImpl< Instruction * > &IterationInstructions, PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened)
static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, AssumptionCache *AC)
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI)
static cl::opt< unsigned > RepeatedInstructionThreshold("loop-flatten-cost-threshold", cl::Hidden, cl::init(2), cl::desc("Limit on the cost of instructions that can be repeated due to " "loop flattening"))
static cl::opt< bool > AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, cl::init(false), cl::desc("Assume that the product of the two iteration " "trip counts will never overflow"))
static bool checkOuterLoopInsts(FlattenInfo &FI, SmallPtrSetImpl< Instruction * > &IterationInstructions, const TargetTransformInfo *TTI)
This file defines the interface for the loop nest analysis.
This header provides classes for managing a pipeline of passes over loops in LLVM IR.
Module.h This file contains the declarations for the Module class.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Virtual Register Rewriter
A container for analyses that lazily runs them and caches their results.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
@ ICMP_ULT
unsigned less than
This is the shared class of boolean and integer constants.
IntegerType * getType() const
getType - Specialize the getType() method to always return an IntegerType, which reduces the amount o...
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
const APInt & getValue() const
Return the constant as an APInt value reference.
A parsed version of the target data layout string, and methods for querying it.
void deleteEdge(NodeT *From, NodeT *To)
Inform the dominator tree about a CFG edge deletion and update the tree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
This instruction compares its operands according to the predicate given to the constructor.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const BasicBlock * getParent() const
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
SymbolTableList< Instruction >::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &LAM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
void erase(Loop *L)
Update LoopInfo after removing the last backedge from a loop.
This class represents a loop nest and can be used to query its properties.
ArrayRef< Loop * > getLoops() const
Get the loops in the nest.
Represents a single loop in the control flow graph.
An analysis that produces MemorySSA for a function.
void removeEdge(BasicBlock *From, BasicBlock *To)
Update the MemoryPhi in To following an edge deletion between From and To.
void verifyMemorySSA(VerificationLevel=VerificationLevel::Fast) const
Verify that MemorySSA is self consistent (i.e., definitions dominate all uses, uses appear in the right ...
A Module instance is used to store all the information related to an LLVM module.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
Value * hasConstantValue() const
If the specified PHI node always merges together the same value, return the value,...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class uses information about analyzed scalars to rewrite expressions in canonical form.
This class represents an analyzed expression in the program.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may affect Scalar...
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
LLVMContext & getContext() const
All values hold a context through their type.
void dump() const
Support for debugging, callable in GDB: V->dump()
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
CastClass_match< OpTy, Instruction::Trunc > m_Trunc(const OpTy &Op)
Matches Trunc.
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches an Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul, true > m_c_Mul(const LHS &L, const RHS &R)
Matches a Mul with LHS and RHS in either order.
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, bool UseInstrInfo=true)
PHINode * createWideIV(const WideIVInfo &WI, LoopInfo *LI, ScalarEvolution *SE, SCEVExpander &Rewriter, DominatorTree *DT, SmallVectorImpl< WeakTrackingVH > &DeadInsts, unsigned &NumElimExt, unsigned &NumWidened, bool HasGuards, bool UsePostIncrementRanges)
Widen Induction Variables - Extend the width of an IV to cover its widest uses.
bool isGuaranteedToExecuteForEveryIteration(const Instruction *I, const Loop *L)
Return true if this function can prove that the instruction I is executed for every iteration of the ...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool VerifyMemorySSA
Enables verification of MemorySSA.
bool isSafeToSpeculativelyExecute(const Instruction *I, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if the instruction does not have any effects besides calculating the result and does not ...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
TargetTransformInfo & TTI
Collect information about induction variables that are used by sign/zero extend operations.