58#include <forward_list>
64#define LLE_OPTION "loop-load-elim"
65#define DEBUG_TYPE LLE_OPTION
68 "runtime-check-per-loop-load-elim",
cl::Hidden,
69 cl::desc(
"Max number of memchecks allowed per eliminated load on average"),
74 cl::desc(
"The maximum number of SCEV checks allowed for Loop "
77STATISTIC(NumLoopLoadEliminted,
"Number of loads eliminated by LLE");
82struct StoreToLoadForwardingCandidate {
87 : Load(Load), Store(Store) {}
92 bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE, Loop *L,
93 const DominatorTree &DT)
const {
94 Value *LoadPtr = Load->getPointerOperand();
95 Value *StorePtr = Store->getPointerOperand();
97 auto &
DL = Load->getDataLayout();
101 DL.getTypeSizeInBits(LoadType) ==
103 "Should be a known dependence");
106 getPtrStride(PSE, LoadType, LoadPtr, L, DT).value_or(0);
107 int64_t StrideStore =
108 getPtrStride(PSE, LoadType, StorePtr, L, DT).value_or(0);
109 if (!StrideLoad || !StrideStore || StrideLoad != StrideStore)
119 if (std::abs(StrideLoad) != 1)
122 unsigned TypeByteSize =
DL.getTypeAllocSize(LoadType);
133 const APInt &Val = Dist->getAPInt();
134 return Val == TypeByteSize * StrideLoad;
137 Value *getLoadPtr()
const {
return Load->getPointerOperand(); }
140 friend raw_ostream &
operator<<(raw_ostream &OS,
141 const StoreToLoadForwardingCandidate &Cand) {
142 OS << *Cand.Store <<
" -->\n";
143 OS.
indent(2) << *Cand.Load <<
"\n";
156 L->getLoopLatches(Latches);
164 return Load->getParent() != L->getHeader();
170class LoadEliminationForLoop {
172 LoadEliminationForLoop(Loop *L, LoopInfo *LI,
const LoopAccessInfo &LAI,
173 DominatorTree *DT, BlockFrequencyInfo *BFI,
174 ProfileSummaryInfo* PSI)
175 : L(L), LI(LI), LAI(LAI), DT(DT), BFI(BFI), PSI(PSI), PSE(LAI.getPSE()) {}
182 std::forward_list<StoreToLoadForwardingCandidate>
183 findStoreToLoadDependences(
const LoopAccessInfo &LAI) {
184 std::forward_list<StoreToLoadForwardingCandidate> Candidates;
186 const auto &DepChecker = LAI.getDepChecker();
187 const auto *Deps = DepChecker.getDependences();
195 SmallPtrSet<Instruction *, 4> LoadsWithUnknownDependence;
197 for (
const auto &Dep : *Deps) {
199 Instruction *Destination = Dep.getDestination(DepChecker);
205 LoadsWithUnknownDependence.
insert(Source);
207 LoadsWithUnknownDependence.
insert(Destination);
211 if (Dep.isBackward())
217 assert(Dep.isForward() &&
"Needs to be a forward dependence");
229 Store->getDataLayout()))
232 Candidates.emplace_front(Load, Store);
235 if (!LoadsWithUnknownDependence.
empty())
236 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &
C) {
237 return LoadsWithUnknownDependence.
count(
C.Load);
244 unsigned getInstrIndex(Instruction *Inst) {
245 auto I = InstOrder.find(Inst);
246 assert(
I != InstOrder.end() &&
"No index for instruction");
269 void removeDependencesFromMultipleStores(
270 std::forward_list<StoreToLoadForwardingCandidate> &Candidates) {
273 using LoadToSingleCandT =
274 DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>;
275 LoadToSingleCandT LoadToSingleCand;
277 for (
const auto &Cand : Candidates) {
279 LoadToSingleCandT::iterator Iter;
281 std::tie(Iter, NewElt) =
282 LoadToSingleCand.insert(std::make_pair(Cand.Load, &Cand));
284 const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
286 if (OtherCand ==
nullptr)
292 if (Cand.Store->getParent() == OtherCand->Store->
getParent() &&
293 Cand.isDependenceDistanceOfOne(PSE, L, *DT) &&
294 OtherCand->isDependenceDistanceOfOne(PSE, L, *DT)) {
296 if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
303 Candidates.remove_if([&](
const StoreToLoadForwardingCandidate &Cand) {
304 if (LoadToSingleCand[Cand.Load] != &Cand) {
306 dbgs() <<
"Removing from candidates: \n"
308 <<
" The load may have multiple stores forwarding to "
321 bool needsChecking(
unsigned PtrIdx1,
unsigned PtrIdx2,
322 const SmallPtrSetImpl<Value *> &PtrsWrittenOnFwdingPath,
323 const SmallPtrSetImpl<Value *> &CandLoadPtrs) {
325 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue;
327 LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx2).PointerValue;
328 return ((PtrsWrittenOnFwdingPath.
count(Ptr1) && CandLoadPtrs.
count(Ptr2)) ||
329 (PtrsWrittenOnFwdingPath.
count(Ptr2) && CandLoadPtrs.
count(Ptr1)));
336 SmallPtrSet<Value *, 4> findPointersWrittenOnForwardingPath(
337 const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
357 [&](
const StoreToLoadForwardingCandidate &
A,
358 const StoreToLoadForwardingCandidate &
B) {
359 return getInstrIndex(
A.Load) <
360 getInstrIndex(
B.Load);
363 StoreInst *FirstStore =
365 [&](
const StoreToLoadForwardingCandidate &
A,
366 const StoreToLoadForwardingCandidate &
B) {
367 return getInstrIndex(
A.Store) <
368 getInstrIndex(
B.Store);
375 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath;
379 PtrsWrittenOnFwdingPath.insert(S->getPointerOperand());
381 const auto &MemInstrs = LAI.getDepChecker().getMemoryInstructions();
382 std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1,
383 MemInstrs.end(), InsertStorePtr);
384 std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)],
387 return PtrsWrittenOnFwdingPath;
392 SmallVector<RuntimePointerCheck, 4> collectMemchecks(
393 const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) {
395 SmallPtrSet<Value *, 4> PtrsWrittenOnFwdingPath =
396 findPointersWrittenOnForwardingPath(Candidates);
399 SmallPtrSet<Value *, 4> CandLoadPtrs;
400 for (
const auto &Candidate : Candidates)
401 CandLoadPtrs.
insert(Candidate.getLoadPtr());
403 const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks();
404 SmallVector<RuntimePointerCheck, 4> Checks;
406 copy_if(AllChecks, std::back_inserter(Checks),
408 for (
auto PtrIdx1 :
Check.first->Members)
409 for (
auto PtrIdx2 :
Check.second->Members)
410 if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath,
418 LLVM_DEBUG(LAI.getRuntimePointerChecking()->printChecks(
dbgs(), Checks));
425 propagateStoredValueToLoadUsers(
const StoreToLoadForwardingCandidate &Cand,
444 auto *PH = L->getLoopPreheader();
445 assert(PH &&
"Preheader should exist!");
446 Value *InitialPtr =
SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->
getType(),
447 PH->getTerminator());
449 new LoadInst(Cand.Load->
getType(), InitialPtr,
"load_initial",
451 PH->getTerminator()->getIterator());
459 PHI->insertBefore(L->getHeader()->begin());
460 PHI->addIncoming(Initial, PH);
467 assert(
DL.getTypeSizeInBits(LoadType) ==
DL.getTypeSizeInBits(StoreType) &&
468 "The type sizes should match!");
471 if (LoadType != StoreType) {
473 "store_forward_cast",
481 PHI->addIncoming(StoreValue, L->getLoopLatch());
490 LLVM_DEBUG(
dbgs() <<
"\nIn \"" << L->getHeader()->getParent()->getName()
491 <<
"\" checking " << *L <<
"\n");
512 auto StoreToLoadDependences = findStoreToLoadDependences(LAI);
513 if (StoreToLoadDependences.empty())
518 InstOrder = LAI.getDepChecker().generateInstructionOrderMap();
522 removeDependencesFromMultipleStores(StoreToLoadDependences);
523 if (StoreToLoadDependences.empty())
528 for (
const StoreToLoadForwardingCandidate &Cand : StoreToLoadDependences) {
544 if (!Cand.isDependenceDistanceOfOne(PSE, L, *DT))
548 "Loading from something other than indvar?");
551 "Storing to something other than indvar?");
557 <<
". Valid store-to-load forwarding across the loop backedge\n");
559 if (Candidates.
empty())
564 SmallVector<RuntimePointerCheck, 4> Checks = collectMemchecks(Candidates);
572 if (LAI.getPSE().getPredicate().getComplexity() >
578 if (!L->isLoopSimplifyForm()) {
583 if (!Checks.
empty() || !LAI.getPSE().getPredicate().isAlwaysTrue()) {
584 if (LAI.hasConvergentOp()) {
586 "convergent calls\n");
590 auto *HeaderBB = L->getHeader();
592 PGSOQueryType::IRPass)) {
594 dbgs() <<
"Versioning is needed but not allowed when optimizing "
602 LoopVersioning LV(LAI, Checks, L, LI, DT, PSE.getSE());
607 auto NoLongerGoodCandidate = [
this](
608 const StoreToLoadForwardingCandidate &Cand) {
619 SCEVExpander
SEE(*PSE.getSE(),
"storeforward");
620 for (
const auto &Cand : Candidates)
621 propagateStoredValueToLoadUsers(Cand,
SEE);
622 NumLoopLoadEliminted += Candidates.size();
632 DenseMap<Instruction *, unsigned> InstOrder;
636 const LoopAccessInfo &LAI;
638 BlockFrequencyInfo *BFI;
639 ProfileSummaryInfo *PSI;
640 PredicatedScalarEvolution PSE;
660 for (
Loop *TopLevelLoop : LI)
664 if (L->isInnermost())
669 for (
Loop *L : Worklist) {
671 if (!L->isRotatedForm() || !L->getExitingBlock())
674 LoadEliminationForLoop LEL(L, &LI, LAIs.
getInfo(*L), &DT, BFI, PSI);
694 auto *BFI = (PSI && PSI->hasProfileSummary()) ?
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
This file implements a class to represent arbitrary precision integral constant values and operations...
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseMap class.
This file builds on the ADT/GraphTraits.h file to build generic depth first graph iterator.
This is the interface for a simple mod/ref and alias analysis over globals.
This header defines various interfaces for pass management in LLVM.
This header provides classes for managing per-loop analyses.
static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, ScalarEvolution *SE, AssumptionCache *AC, LoopAccessInfoManager &LAIs)
static cl::opt< unsigned > LoadElimSCEVCheckThreshold("loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden, cl::desc("The maximum number of SCEV checks allowed for Loop " "Load Elimination"))
static bool isLoadConditional(LoadInst *Load, Loop *L)
Return true if the load is not executed on all paths in the loop.
static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, DominatorTree *DT)
Check if the store dominates all latches, so as long as there is no intervening store this value will...
static cl::opt< unsigned > CheckPerElim("runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), cl::init(1))
This header defines the LoopLoadEliminationPass object.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
A function analysis which provides an AssumptionCache.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
Analysis pass which computes BlockFrequencyInfo.
BlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate IR basic block frequen...
static LLVM_ABI bool isBitOrNoopPointerCastable(Type *SrcTy, Type *DestTy, const DataLayout &DL)
Check whether a bitcast, inttoptr, or ptrtoint cast between these types is valid and a no-op.
static LLVM_ABI CastInst * CreateBitOrPointerCast(Value *S, Type *Ty, const Twine &Name="", InsertPosition InsertBefore=nullptr)
Create a BitCast, a PtrToInt, or an IntToPtr cast instruction.
static DebugLoc getDropped()
Analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
LLVM_ABI bool dominates(const BasicBlock *BB, const Use &U) const
Return true if the (end of the) basic block BB dominates the use U.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const DataLayout & getDataLayout() const
Get the data layout of the module this instruction belongs to.
An instruction for reading from memory.
Value * getPointerOperand()
Align getAlign() const
Return the alignment of the access that is being performed.
This analysis provides dependence information for the memory accesses of a loop.
LLVM_ABI const LoopAccessInfo & getInfo(Loop &L, bool AllowPartial=false)
Analysis pass that exposes the LoopInfo for a function.
Represents a single loop in the control flow graph.
static PHINode * Create(Type *Ty, unsigned NumReservedValues, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
Constructors - NumReservedValues is a hint for the number of incoming edges that this phi node will h...
ScalarEvolution * getSE() const
Returns the ScalarEvolution analysis used.
LLVM_ABI const SCEV * getSCEV(Value *V)
Returns the SCEV expression of V, in the context of the current SCEV predicate.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
PreservedAnalyses & preserve()
Mark an analysis as preserved.
An analysis pass based on the new PM to deliver ProfileSummaryInfo.
Analysis providing profile information.
Analysis pass that exposes the ScalarEvolution for a function.
The main scalar evolution driver.
LLVM_ABI const SCEV * getMinusSCEV(SCEVUse LHS, SCEVUse RHS, SCEV::NoWrapFlags Flags=SCEV::FlagAnyWrap, unsigned Depth=0)
Return LHS-RHS.
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
An instruction for storing to memory.
Value * getValueOperand()
Value * getPointerOperand()
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
Type * getType() const
All values are typed, get the type of this value.
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
const ParentTy * getParent() const
self_iterator getIterator()
raw_ostream & indent(unsigned NumSpaces)
indent - Insert 'NumSpaces' spaces.
@ C
The default llvm calling convention, compatible with C.
initializer< Ty > init(const Ty &Val)
friend class Instruction
Iterator for Instructions in a `BasicBlock`.
This is an optimization pass for GlobalISel generic memory operations.
LLVM_ABI bool simplifyLoop(Loop *L, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, MemorySSAUpdater *MSSAU, bool PreserveLCSSA)
Simplify each loop in a loop nest recursively.
FunctionAddr VTableAddr Value
auto min_element(R &&Range)
Provide wrappers to std::min_element which take ranges instead of having to pass begin/end explicitly...
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
OuterAnalysisManagerProxy< ModuleAnalysisManager, Function > ModuleAnalysisManagerFunctionProxy
Provide the ModuleAnalysisManager to Function proxy.
LLVM_ABI bool shouldOptimizeForSize(const MachineFunction *MF, ProfileSummaryInfo *PSI, const MachineBlockFrequencyInfo *BFI, PGSOQueryType QueryType=PGSOQueryType::Other)
Returns true if machine function MF is suggested to be size-optimized based on the profile.
std::pair< const RuntimeCheckingPtrGroup *, const RuntimeCheckingPtrGroup * > RuntimePointerCheck
A memcheck which is made up of a pair of grouped pointers.
OutputIt copy_if(R &&Range, OutputIt Out, UnaryPredicate P)
Provide wrappers to std::copy_if which take ranges instead of having to pass begin/end explicitly.
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
auto max_element(R &&Range)
Provide wrappers to std::max_element which take ranges instead of having to pass begin/end explicitly...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
void erase_if(Container &C, UnaryPredicate P)
Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...
Type * getLoadStoreType(const Value *I)
A helper function that returns the type of a load or store instruction.
iterator_range< df_iterator< T > > depth_first(const T &G)
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI std::optional< int64_t > getPtrStride(PredicatedScalarEvolution &PSE, Type *AccessTy, Value *Ptr, const Loop *Lp, const DominatorTree &DT, const DenseMap< Value *, const SCEV * > &StridesMap=DenseMap< Value *, const SCEV * >(), bool Assume=false, bool ShouldCheckWrap=true)
If the pointer has a constant stride return it in units of the access type size.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)