80#define DEBUG_TYPE "loop-flatten"
82STATISTIC(NumFlattened,
"Number of loops flattened");
86 cl::desc(
"Limit on the cost of instructions that can be repeated due to "
92 cl::desc(
"Assume that the product of the two iteration "
93 "trip counts will never overflow"));
97 cl::desc(
"Widen the loop induction variables, if possible, so "
98 "overflow checks won't reject flattening"));
110 Loop *OuterLoop =
nullptr;
111 Loop *InnerLoop =
nullptr;
113 PHINode *InnerInductionPHI =
nullptr;
114 PHINode *OuterInductionPHI =
nullptr;
118 Value *InnerTripCount =
nullptr;
119 Value *OuterTripCount =
nullptr;
136 bool Widened =
false;
139 PHINode *NarrowInnerInductionPHI =
nullptr;
140 PHINode *NarrowOuterInductionPHI =
nullptr;
144 FlattenInfo(
Loop *OL,
Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
146 bool isNarrowInductionPhi(
PHINode *Phi) {
150 return NarrowInnerInductionPHI ==
Phi || NarrowOuterInductionPHI ==
Phi;
152 bool isInnerLoopIncrement(
User *U) {
153 return InnerIncrement ==
U;
155 bool isOuterLoopIncrement(
User *U) {
156 return OuterIncrement ==
U;
158 bool isInnerLoopTest(
User *U) {
163 for (
User *U : OuterInductionPHI->
users()) {
164 if (isOuterLoopIncrement(U))
167 auto IsValidOuterPHIUses = [&] (
User *
U) ->
bool {
168 LLVM_DEBUG(
dbgs() <<
"Found use of outer induction variable: ";
U->dump());
169 if (!ValidOuterPHIUses.
count(U)) {
170 LLVM_DEBUG(
dbgs() <<
"Did not match expected pattern, bailing\n");
177 if (
auto *V = dyn_cast<TruncInst>(U)) {
178 for (
auto *K :
V->users()) {
179 if (!IsValidOuterPHIUses(K))
185 if (!IsValidOuterPHIUses(U))
191 bool matchLinearIVUser(
User *U,
Value *InnerTripCount,
193 LLVM_DEBUG(
dbgs() <<
"Checking linear i*M+j expression for: ";
U->dump());
194 Value *MatchedMul =
nullptr;
195 Value *MatchedItCount =
nullptr;
219 return !isInstructionTriviallyDead(cast<Instruction>(U));
227 if (Widened && IsAdd &&
228 (isa<SExtInst>(MatchedItCount) || isa<ZExtInst>(MatchedItCount))) {
230 "Unexpected type mismatch in types after widening");
231 MatchedItCount = isa<SExtInst>(MatchedItCount)
232 ? dyn_cast<SExtInst>(MatchedItCount)->getOperand(0)
233 : dyn_cast<ZExtInst>(MatchedItCount)->getOperand(0);
237 InnerTripCount->
dump());
239 if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
241 ValidOuterPHIUses.
insert(MatchedMul);
246 LLVM_DEBUG(
dbgs() <<
"Did not match expected pattern, bailing\n");
251 Value *SExtInnerTripCount = InnerTripCount;
253 (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
254 SExtInnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
256 for (
User *U : InnerInductionPHI->
users()) {
258 if (isInnerLoopIncrement(U)) {
259 LLVM_DEBUG(
dbgs() <<
"Use is inner loop increment, continuing\n");
265 if (isa<TruncInst>(U)) {
268 U = *
U->user_begin();
275 if (isInnerLoopTest(U)) {
280 if (!matchLinearIVUser(U, SExtInnerTripCount, ValidOuterPHIUses)) {
295 IterationInstructions.
insert(Increment);
313 if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
314 LLVM_DEBUG(
dbgs() <<
"Backedge-taken count is not predictable\n");
321 const SCEV *SCEVTripCount =
323 BackedgeTakenCount->
getType(), L);
326 if (SCEVRHS == SCEVTripCount)
330 const SCEV *BackedgeTCExt =
nullptr;
332 const SCEV *SCEVTripCountExt;
338 if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
345 if (SCEVRHS == BackedgeTCExt || SCEVRHS == BackedgeTakenCount) {
350 IterationInstructions);
361 auto *TripCountInst = dyn_cast<Instruction>(
RHS);
362 if (!TripCountInst) {
366 if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) ||
367 SE->
getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) {
368 LLVM_DEBUG(
dbgs() <<
"Could not find valid extended trip count\n");
380 LLVM_DEBUG(
dbgs() <<
"Finding components of loop: " << L->getName() <<
"\n");
382 if (!L->isLoopSimplifyForm()) {
389 if (!L->isCanonical(*SE)) {
397 if (L->getExitingBlock() != Latch) {
405 InductionPHI = L->getInductionVariable(*SE);
422 ICmpInst *Compare = L->getLatchCmpInst();
423 if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) ||
424 Compare->hasNUsesOrMore(2)) {
429 IterationInstructions.
insert(BackBranch);
431 IterationInstructions.
insert(Compare);
440 if ((Compare->getOperand(0) != Increment || !Increment->hasNUses(2)) &&
441 !Increment->hasNUses(1)) {
450 Value *
RHS = Compare->getOperand(1);
453 Increment, BackBranch, SE, IsWidened);
472 SafeOuterPHIs.
insert(FI.OuterInductionPHI);
476 for (
PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) {
479 if (&InnerPHI == FI.InnerInductionPHI)
481 if (FI.isNarrowInductionPhi(&InnerPHI))
486 assert(InnerPHI.getNumIncomingValues() == 2);
487 Value *PreHeaderValue =
488 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader());
490 InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch());
495 PHINode *OuterPHI = dyn_cast<PHINode>(PreHeaderValue);
496 if (!OuterPHI || OuterPHI->
getParent() != FI.OuterLoop->getHeader()) {
505 PHINode *LCSSAPHI = dyn_cast<PHINode>(
516 dbgs() <<
"LCSSA PHI incoming value does not match latch value\n");
523 SafeOuterPHIs.
insert(OuterPHI);
524 FI.InnerPHIsToTransform.insert(&InnerPHI);
527 for (
PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) {
528 if (FI.isNarrowInductionPhi(&OuterPHI))
530 if (!SafeOuterPHIs.
count(&OuterPHI)) {
531 LLVM_DEBUG(
dbgs() <<
"found unsafe PHI in outer loop: "; OuterPHI.dump());
550 for (
auto *
B : FI.OuterLoop->getBlocks()) {
551 if (FI.InnerLoop->contains(
B))
555 if (!isa<PHINode>(&
I) && !
I.isTerminator() &&
557 LLVM_DEBUG(
dbgs() <<
"Cannot flatten because instruction may have "
566 if (IterationInstructions.
count(&
I))
582 RepeatedInstrCost +=
Cost;
586 LLVM_DEBUG(
dbgs() <<
"Cost of instructions that will be repeated: "
587 << RepeatedInstrCost <<
"\n");
591 LLVM_DEBUG(
dbgs() <<
"checkOuterLoopInsts: not profitable, bailing.\n");
612 if (!FI.checkInnerInductionPhiUsers(ValidOuterPHIUses))
617 if (!FI.checkOuterInductionPhiUsers(ValidOuterPHIUses))
621 dbgs() <<
"Found " << FI.LinearIVUses.size()
622 <<
" value(s) that can be replaced:\n";
623 for (
Value *V : FI.LinearIVUses) {
634 Function *
F = FI.OuterLoop->getHeader()->getParent();
639 return OverflowResult::NeverOverflows;
644 FI.InnerTripCount, FI.OuterTripCount,
646 FI.OuterLoop->getLoopPreheader()->getTerminator()));
647 if (OR != OverflowResult::MayOverflow)
650 for (
Value *V : FI.LinearIVUses) {
651 for (
Value *U : V->users()) {
652 if (
auto *
GEP = dyn_cast<GetElementPtrInst>(U)) {
653 for (
Value *GEPUser : U->users()) {
654 auto *GEPUserInst = cast<Instruction>(GEPUser);
655 if (!isa<LoadInst>(GEPUserInst) &&
656 !(isa<StoreInst>(GEPUserInst) &&
657 GEP == GEPUserInst->getOperand(1)))
666 if (
GEP->isInBounds() &&
667 V->getType()->getIntegerBitWidth() >=
668 DL.getPointerTypeSizeInBits(
GEP->getType())) {
670 dbgs() <<
"use of linear IV would be UB if overflow occurred: ";
672 return OverflowResult::NeverOverflows;
679 return OverflowResult::MayOverflow;
687 FI.InnerInductionPHI, FI.InnerTripCount,
688 FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
691 FI.OuterInductionPHI, FI.OuterTripCount,
692 FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
697 if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
701 if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
710 if (FI.InnerInductionPHI->getType() != FI.OuterInductionPHI->getType())
732 Function *
F = FI.OuterLoop->getHeader()->getParent();
733 LLVM_DEBUG(
dbgs() <<
"Checks all passed, doing the transformation\n");
737 FI.InnerLoop->getHeader());
739 Remark <<
"Flattened into outer loop";
743 Value *NewTripCount = BinaryOperator::CreateMul(
744 FI.InnerTripCount, FI.OuterTripCount,
"flatten.tripcount",
745 FI.OuterLoop->getLoopPreheader()->getTerminator());
747 NewTripCount->
dump());
751 FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
756 PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch());
760 cast<User>(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount);
763 BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock();
764 BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
769 DT->
deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
771 MSSAU->
removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
775 IRBuilder<> Builder(FI.OuterInductionPHI->getParent()->getTerminator());
776 for (
Value *V : FI.LinearIVUses) {
777 Value *OuterValue = FI.OuterInductionPHI;
779 OuterValue = Builder.
CreateTrunc(FI.OuterInductionPHI, V->getType(),
784 V->replaceAllUsesWith(OuterValue);
792 U->markLoopAsDeleted(*FI.InnerLoop, FI.InnerLoop->getName());
793 LI->
erase(FI.InnerLoop);
810 Module *M = FI.InnerLoop->getHeader()->getParent()->getParent();
811 auto &
DL = M->getDataLayout();
812 auto *InnerType = FI.InnerInductionPHI->getType();
813 auto *OuterType = FI.OuterInductionPHI->getType();
814 unsigned MaxLegalSize =
DL.getLargestLegalIntTypeSizeInBits();
815 auto *MaxLegalType =
DL.getLargestLegalIntType(M->getContext());
820 if (InnerType != OuterType ||
821 InnerType->getScalarSizeInBits() >= MaxLegalSize ||
822 MaxLegalType->getScalarSizeInBits() <
823 InnerType->getScalarSizeInBits() * 2) {
830 unsigned ElimExt = 0;
831 unsigned Widened = 0;
846 if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType,
false},
Deleted))
851 FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);
853 if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType,
false},
Deleted))
856 assert(Widened &&
"Widened IV expected");
860 FI.NarrowInnerInductionPHI = FI.InnerInductionPHI;
861 FI.NarrowOuterInductionPHI = FI.OuterInductionPHI;
872 dbgs() <<
"Loop flattening running on outer loop "
873 << FI.OuterLoop->getHeader()->getName() <<
" and inner loop "
874 << FI.InnerLoop->getHeader()->getName() <<
" in "
875 << FI.OuterLoop->getHeader()->getParent()->getName() <<
"\n");
894 if (FI.Widened && !CanFlatten)
907 if (OR == OverflowResult::AlwaysOverflowsHigh ||
908 OR == OverflowResult::AlwaysOverflowsLow) {
909 LLVM_DEBUG(
dbgs() <<
"Multiply would always overflow, so not profitable\n");
911 }
else if (OR == OverflowResult::MayOverflow) {
912 LLVM_DEBUG(
dbgs() <<
"Multiply might overflow, not flattening\n");
916 LLVM_DEBUG(
dbgs() <<
"Multiply cannot overflow, modifying loop in-place\n");
924 bool Changed =
false;
926 std::optional<MemorySSAUpdater> MSSAU;
941 FlattenInfo FI(OuterLoop, InnerLoop);
943 MSSAU ? &*MSSAU :
nullptr);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI)
static bool verifyTripCount(Value *RHS, Loop *L, SmallPtrSetImpl< Instruction * > &IterationInstructions, PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened)
static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI, LPMUpdater *U, MemorySSAUpdater *MSSAU)
static cl::opt< bool > WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true), cl::desc("Widen the loop induction variables, if possible, so " "overflow checks won't reject flattening"))
static bool setLoopComponents(Value *&TC, Value *&TripCount, BinaryOperator *&Increment, SmallPtrSetImpl< Instruction * > &IterationInstructions)
static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI, LPMUpdater *U, MemorySSAUpdater *MSSAU)
static bool checkIVUsers(FlattenInfo &FI)
static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI)
static bool findLoopComponents(Loop *L, SmallPtrSetImpl< Instruction * > &IterationInstructions, PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened)
static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, AssumptionCache *AC)
static bool checkPHIs(FlattenInfo &FI, const TargetTransformInfo *TTI)
static cl::opt< unsigned > RepeatedInstructionThreshold("loop-flatten-cost-threshold", cl::Hidden, cl::init(2), cl::desc("Limit on the cost of instructions that can be repeated due to " "loop flattening"))
static cl::opt< bool > AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, cl::init(false), cl::desc("Assume that the product of the two iteration " "trip counts will never overflow"))
static bool checkOuterLoopInsts(FlattenInfo &FI, SmallPtrSetImpl< Instruction * > &IterationInstructions, const TargetTransformInfo *TTI)
This file defines the interface for the loop nest analysis.
This header provides classes for managing a pipeline of passes over loops in LLVM IR.
Module.h This file contains the declarations for the Module class.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Virtual Register Rewriter
A container for analyses that lazily runs them and caches their results.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
bool isUnconditional() const
Value * getCondition() const
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
@ ICMP_ULT
unsigned less than
This is the shared class of boolean and integer constants.
IntegerType * getType() const
getType - Specialize the getType() method to always return an IntegerType, which reduces the amount o...
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
const APInt & getValue() const
Return the constant as an APInt value reference.
A parsed version of the target data layout string in and methods for querying it.
void deleteEdge(NodeT *From, NodeT *To)
Inform the dominator tree about a CFG edge deletion and update the tree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
This instruction compares its operands according to the predicate given to the constructor.
Value * CreateTrunc(Value *V, Type *DestTy, const Twine &Name="")
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
const BasicBlock * getParent() const
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
BasicBlock * getSuccessor(unsigned Idx) const LLVM_READONLY
Return the specified successor. This instruction must be a terminator.
This class provides an interface for updating the loop pass manager based on mutations to the loop ne...
LoopT * getParentLoop() const
Return the parent loop if it exists or nullptr for top level loops.
PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &LAM, LoopStandardAnalysisResults &AR, LPMUpdater &U)
void erase(Loop *L)
Update LoopInfo after removing the last backedge from a loop.
This class represents a loop nest and can be used to query its properties.
ArrayRef< Loop * > getLoops() const
Get the loops in the nest.
Represents a single loop in the control flow graph.
An analysis that produces MemorySSA for a function.
void removeEdge(BasicBlock *From, BasicBlock *To)
Update the MemoryPhi in To following an edge deletion between From and To.
void verifyMemorySSA(VerificationLevel=VerificationLevel::Fast) const
Verify that MemorySSA is self consistent (IE definitions dominate all uses, uses appear in the right ...
A Module instance is used to store all the information related to an LLVM module.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
Value * hasConstantValue() const
If the specified PHI node always merges together the same value, return the value,...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
This class uses information about analyze scalars to rewrite expressions in canonical form.
This class represents an analyzed expression in the program.
Type * getType() const
Return the LLVM type of this SCEV expression.
The main scalar evolution driver.
const SCEV * getBackedgeTakenCount(const Loop *L, ExitCountKind Kind=Exact)
If the specified loop has a predictable backedge-taken count, return it, otherwise return a SCEVCould...
const SCEV * getSCEV(Value *V)
Return a SCEV expression for the full generality of the specified expression.
const SCEV * getTripCountFromExitCount(const SCEV *ExitCount)
A version of getTripCountFromExitCount below which always picks an evaluation type which can not resu...
void forgetLoop(const Loop *L)
This method should be called by the client when it has changed a loop in a way that may effect Scalar...
const SCEV * getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth=0)
void forgetBlockAndLoopDispositions(Value *V=nullptr)
Called when the client has changed the disposition of values in a loop or block.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
iterator_range< user_iterator > users()
LLVMContext & getContext() const
All values hold a context through their type.
void dump() const
Support for debugging, callable in GDB: V->dump()
bool match(Val *V, const Pattern &P)
specificval_ty m_Specific(const Value *V)
Match if we have a specific specified value.
CastOperator_match< OpTy, Instruction::Trunc > m_Trunc(const OpTy &Op)
Matches Trunc.
BinaryOp_match< LHS, RHS, Instruction::Add, true > m_c_Add(const LHS &L, const RHS &R)
Matches a Add with LHS and RHS in either order.
class_match< Value > m_Value()
Match an arbitrary value and ignore it.
BinaryOp_match< LHS, RHS, Instruction::Mul, true > m_c_Mul(const LHS &L, const RHS &R)
Matches a Mul with LHS and RHS in either order.
initializer< Ty > init(const Ty &Val)
NodeAddr< PhiNode * > Phi
This is an optimization pass for GlobalISel generic memory operations.
PHINode * createWideIV(const WideIVInfo &WI, LoopInfo *LI, ScalarEvolution *SE, SCEVExpander &Rewriter, DominatorTree *DT, SmallVectorImpl< WeakTrackingVH > &DeadInsts, unsigned &NumElimExt, unsigned &NumWidened, bool HasGuards, bool UsePostIncrementRanges)
Widen Induction Variables - Extend the width of an IV to cover its widest uses.
OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const SimplifyQuery &SQ)
bool isGuaranteedToExecuteForEveryIteration(const Instruction *I, const Loop *L)
Return true if this function can prove that the instruction I is executed for every iteration of the ...
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool VerifyMemorySSA
Enables verification of MemorySSA.
bool isSafeToSpeculativelyExecute(const Instruction *I, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if the instruction does not have any effects besides calculating the result and does not ...
auto count_if(R &&Range, UnaryPredicate P)
Wrapper function around std::count_if to count the number of times an element satisfying a given pred...
PreservedAnalyses getLoopPassPreservedAnalyses()
Returns the minimum set of Analyses that all loop passes must preserve.
bool RecursivelyDeleteDeadPHINode(PHINode *PN, const TargetLibraryInfo *TLI=nullptr, MemorySSAUpdater *MSSAU=nullptr)
If the specified value is an effectively dead PHI node, due to being a def-use chain of single-use no...
The adaptor from a function pass to a loop pass computes these analyses and makes them available to t...
TargetTransformInfo & TTI
Collect information about induction variables that are used by sign/zero extend operations.