68#define DEBUG_TYPE "mergeicmps"
76 :
GEP(
GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}
78 BCEAtom(
const BCEAtom &) =
delete;
79 BCEAtom &operator=(
const BCEAtom &) =
delete;
81 BCEAtom(BCEAtom &&that) =
default;
82 BCEAtom &operator=(BCEAtom &&that) {
88 Offset = std::move(that.Offset);
103 return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
114class BaseIdentifier {
120 const auto Insertion = BaseToIndex.try_emplace(
Base, Order);
121 if (Insertion.second)
123 return Insertion.first->second;
134BCEAtom visitICmpLoadOperand(
Value *
const Val, BaseIdentifier &BaseId) {
135 auto *
const LoadI = dyn_cast<LoadInst>(Val);
139 if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
144 if (!LoadI->isSimple()) {
149 if (
Addr->getType()->getPointerAddressSpace() != 0) {
153 const auto &
DL = LoadI->getModule()->getDataLayout();
163 auto *
GEP = dyn_cast<GetElementPtrInst>(
Addr);
166 if (
GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
172 Base =
GEP->getPointerOperand();
189 BCECmp(BCEAtom L, BCEAtom R,
int SizeBits,
const ICmpInst *CmpI)
204 BCECmpBlock(BCECmp Cmp,
BasicBlock *BB, InstructionSet BlockInsts)
207 const BCEAtom &Lhs()
const {
return Cmp.Lhs; }
208 const BCEAtom &Rhs()
const {
return Cmp.Rhs; }
209 int SizeBits()
const {
return Cmp.SizeBits; }
212 bool doesOtherWork()
const;
232 InstructionSet BlockInsts;
234 bool RequireSplit =
false;
236 unsigned OrigOrder = 0;
242bool BCECmpBlock::canSinkBCECmpInst(
const Instruction *Inst,
247 auto MayClobber = [&](
LoadInst *LI) {
253 if (MayClobber(
Cmp.Lhs.LoadI) || MayClobber(
Cmp.Rhs.LoadI))
259 const Instruction *OpI = dyn_cast<Instruction>(Op);
260 return OpI && BlockInsts.contains(OpI);
267 if (BlockInsts.count(&Inst))
269 assert(canSinkBCECmpInst(&Inst, AA) &&
"Split unsplittable block");
282 if (!BlockInsts.count(&Inst)) {
283 if (!canSinkBCECmpInst(&Inst, AA))
290bool BCECmpBlock::doesOtherWork()
const {
296 if (!BlockInsts.count(&Inst))
304std::optional<BCECmp> visitICmp(
const ICmpInst *
const CmpI,
306 BaseIdentifier &BaseId) {
319 << (ExpectedPredicate == ICmpInst::ICMP_EQ ?
"eq" :
"ne")
321 auto Lhs = visitICmpLoadOperand(CmpI->
getOperand(0), BaseId);
324 auto Rhs = visitICmpLoadOperand(CmpI->
getOperand(1), BaseId);
328 return BCECmp(std::move(Lhs), std::move(Rhs),
334std::optional<BCECmpBlock> visitCmpBlock(
Value *
const Val,
337 BaseIdentifier &BaseId) {
340 auto *
const BranchI = dyn_cast<BranchInst>(
Block->getTerminator());
346 if (BranchI->isUnconditional()) {
352 ExpectedPredicate = ICmpInst::ICMP_EQ;
356 const auto *
const Const = cast<ConstantInt>(Val);
358 if (!
Const->isZero())
361 assert(BranchI->getNumSuccessors() == 2 &&
"expecting a cond branch");
362 BasicBlock *
const FalseBlock = BranchI->getSuccessor(1);
363 Cond = BranchI->getCondition();
365 FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
368 auto *CmpI = dyn_cast<ICmpInst>(
Cond);
373 std::optional<BCECmp>
Result = visitICmp(CmpI, ExpectedPredicate, BaseId);
382 BlockInsts.insert(
Result->Rhs.GEP);
383 return BCECmpBlock(std::move(*Result), Block, BlockInsts);
386static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
387 BCECmpBlock &&Comparison) {
389 <<
"': Found cmp of " << Comparison.SizeBits()
390 <<
" bits between " << Comparison.Lhs().BaseId <<
" + "
391 << Comparison.Lhs().Offset <<
" and "
392 << Comparison.Rhs().BaseId <<
" + "
393 << Comparison.Rhs().Offset <<
"\n");
395 Comparison.OrigOrder = Comparisons.size();
396 Comparisons.push_back(std::move(Comparison));
402 using ContiguousBlocks = std::vector<BCECmpBlock>;
404 BCECmpChain(
const std::vector<BasicBlock *> &
Blocks,
PHINode &Phi,
410 bool atLeastOneMerged()
const {
411 return any_of(MergedBlocks_,
418 std::vector<ContiguousBlocks> MergedBlocks_;
423static bool areContiguous(
const BCECmpBlock &First,
const BCECmpBlock &Second) {
424 return First.Lhs().BaseId == Second.Lhs().BaseId &&
425 First.Rhs().BaseId == Second.Rhs().BaseId &&
426 First.Lhs().Offset +
First.SizeBits() / 8 == Second.Lhs().Offset &&
427 First.Rhs().Offset +
First.SizeBits() / 8 == Second.Rhs().Offset;
430static unsigned getMinOrigOrder(
const BCECmpChain::ContiguousBlocks &
Blocks) {
431 unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
432 for (
const BCECmpBlock &Block :
Blocks)
433 MinOrigOrder = std::min(MinOrigOrder,
Block.OrigOrder);
439static std::vector<BCECmpChain::ContiguousBlocks>
440mergeBlocks(std::vector<BCECmpBlock> &&
Blocks) {
441 std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;
445 [](
const BCECmpBlock &LhsBlock,
const BCECmpBlock &RhsBlock) {
446 return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
447 std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
450 BCECmpChain::ContiguousBlocks *LastMergedBlock =
nullptr;
451 for (BCECmpBlock &Block :
Blocks) {
452 if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
453 MergedBlocks.emplace_back();
454 LastMergedBlock = &MergedBlocks.back();
457 << LastMergedBlock->back().BB->getName() <<
"\n");
459 LastMergedBlock->push_back(std::move(Block));
464 llvm::sort(MergedBlocks, [](
const BCECmpChain::ContiguousBlocks &LhsBlocks,
465 const BCECmpChain::ContiguousBlocks &RhsBlocks) {
466 return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
472BCECmpChain::BCECmpChain(
const std::vector<BasicBlock *> &
Blocks,
PHINode &Phi,
475 assert(!
Blocks.empty() &&
"a chain should have at least one block");
477 std::vector<BCECmpBlock> Comparisons;
478 BaseIdentifier BaseId;
480 assert(Block &&
"invalid block");
481 std::optional<BCECmpBlock> Comparison = visitCmpBlock(
484 LLVM_DEBUG(
dbgs() <<
"chain with invalid BCECmpBlock, no merge.\n");
487 if (Comparison->doesOtherWork()) {
489 <<
"' does extra work besides compare\n");
490 if (Comparisons.empty()) {
504 if (Comparison->canSplit(AA)) {
506 <<
"Split initial block '" << Comparison->BB->getName()
507 <<
"' that does extra work besides compare\n");
508 Comparison->RequireSplit =
true;
509 enqueueBlock(Comparisons, std::move(*Comparison));
512 <<
"ignoring initial block '" << Comparison->BB->getName()
513 <<
"' that does extra work besides compare\n");
542 enqueueBlock(Comparisons, std::move(*Comparison));
546 if (Comparisons.empty()) {
547 LLVM_DEBUG(
dbgs() <<
"chain with no BCE basic blocks, no merge\n");
550 EntryBlock_ = Comparisons[0].BB;
551 MergedBlocks_ = mergeBlocks(std::move(Comparisons));
558class MergedBlockName {
564 :
Name(makeName(Comparisons)) {}
571 if (Comparisons.
size() == 1)
572 return Comparisons[0].BB->getName();
573 const int size = std::accumulate(Comparisons.
begin(), Comparisons.
end(), 0,
574 [](
int i,
const BCECmpBlock &Cmp) {
575 return i + Cmp.BB->getName().size();
586 Scratch.
append(str.begin(), str.end());
588 append(Comparisons[0].BB->getName());
589 for (
int I = 1,
E = Comparisons.
size();
I <
E; ++
I) {
596 return Scratch.
str();
607 assert(!Comparisons.
empty() &&
"merging zero comparisons");
609 const BCECmpBlock &FirstCmp = Comparisons[0];
614 NextCmpBlock->
getParent(), InsertBefore);
618 if (FirstCmp.Lhs().GEP)
619 Lhs =
Builder.Insert(FirstCmp.Lhs().GEP->clone());
621 Lhs = FirstCmp.Lhs().LoadI->getPointerOperand();
622 if (FirstCmp.Rhs().GEP)
623 Rhs =
Builder.Insert(FirstCmp.Rhs().GEP->clone());
625 Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
627 Value *IsEqual =
nullptr;
635 Comparisons, [](
const BCECmpBlock &
B) {
return B.RequireSplit; });
636 if (ToSplit != Comparisons.
end()) {
638 ToSplit->split(BB, AA);
641 if (Comparisons.
size() == 1) {
649 IsEqual =
Builder.CreateICmpEQ(LhsLoad, RhsLoad);
651 const unsigned TotalSizeBits = std::accumulate(
652 Comparisons.
begin(), Comparisons.
end(), 0u,
653 [](
int Size,
const BCECmpBlock &
C) { return Size + C.SizeBits(); });
665 IsEqual =
Builder.CreateICmpEQ(
671 if (NextCmpBlock == PhiBB) {
678 Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
680 DTU.
applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
681 {DominatorTree::Insert, BB, PhiBB}});
688 assert(atLeastOneMerged() &&
"simplifying trivial BCECmpChain");
689 LLVM_DEBUG(
dbgs() <<
"Simplifying comparison chain starting at block "
690 << EntryBlock_->getName() <<
"\n");
697 InsertBefore = NextCmpBlock = mergeComparisons(
698 Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
709 DTU.
applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
710 {DominatorTree::Insert, Pred, NextCmpBlock}});
715 const bool ChainEntryIsFnEntry = EntryBlock_->isEntryBlock();
716 if (ChainEntryIsFnEntry && DTU.
hasDomTree()) {
718 << EntryBlock_->getName() <<
" to "
719 << NextCmpBlock->
getName() <<
"\n");
721 DTU.
applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
723 EntryBlock_ =
nullptr;
727 for (
const auto &
Blocks : MergedBlocks_) {
728 for (
const BCECmpBlock &Block :
Blocks) {
736 MergedBlocks_.clear();
740std::vector<BasicBlock *> getOrderedBlocks(
PHINode &Phi,
744 std::vector<BasicBlock *>
Blocks(NumBlocks);
745 assert(LastBlock &&
"invalid last block");
747 for (
int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
752 <<
" has its address taken\n");
755 Blocks[BlockIndex] = CurBlock;
757 if (!SinglePredecessor) {
760 <<
" has two or more predecessors\n");
766 <<
" does not link back to the phi\n");
769 CurBlock = SinglePredecessor;
816 <<
"skip: non-constant value not from cmp or not from last block.\n");
833 if (
Blocks.empty())
return false;
834 BCECmpChain CmpChain(
Blocks, Phi, AA);
836 if (!CmpChain.atLeastOneMerged()) {
841 return CmpChain.simplify(TLI, AA, DTU);
855 if (!TLI.
has(LibFunc_memcmp))
859 DomTreeUpdater::UpdateStrategy::Eager);
861 bool MadeChange =
false;
865 if (
auto *
const Phi = dyn_cast<PHINode>(&*BB.
begin()))
866 MadeChange |= processPhi(*Phi, TLI, AA, DTU);
882 const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
F);
883 const auto &
TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
F);
886 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
887 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
888 return runImpl(
F, TLI,
TTI, AA, DTWP ? &DTWP->getDomTree() :
nullptr);
903char MergeICmpsLegacyPass::ID = 0;
905 "Merge contiguous icmps into a memcmp",
false,
false)
920 const bool MadeChanges =
runImpl(
F, TLI,
TTI, AA, DT);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
SmallVector< MachineOperand, 4 > Cond
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
static Error split(StringRef Str, char Separator, std::pair< StringRef, StringRef > &Split)
Checked version of split, to ensure mandatory subparts.
DenseMap< Block *, BlockRelaxAux > Blocks
static bool runImpl(Function &F, const TargetLowering &TLI)
This is the interface for a simple mod/ref and alias analysis over globals.
Merge contiguous icmps into a memcmp
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A manager for alias analyses.
A wrapper pass to provide the legacy pass manager access to a suitably prepared AAResults object.
ModRefInfo getModRefInfo(const Instruction *I, const std::optional< MemoryLocation > &OptLoc)
Check whether or not an instruction may read or write the optionally specified memory location.
Class for arbitrary precision integers.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
AnalysisUsage & addPreserved()
Add the specified Pass class to the set of analyses preserved by this pass.
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory), i.e. a start pointer and a length.
size_t size() const
size - Get the array size.
bool empty() const
empty - Check if the array is empty.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
bool hasAddressTaken() const
Returns true if there are any uses of this basic block other than direct branches, switches, etc. to it.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const BasicBlock * getSingleSuccessor() const
Return the successor of this block if it has a single successor.
const Function * getParent() const
Return the enclosing method, or null if none.
LLVMContext & getContext() const
Get the context in which this basic block lives.
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well formed.
Predicate
This enumeration lists the possible predicates for CmpInst subclasses.
Predicate getPredicate() const
Return the predicate for this instruction.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
static ConstantInt * getFalse(LLVMContext &Context)
bool hasDomTree() const
Returns true if it holds a DominatorTree.
void applyUpdates(ArrayRef< DominatorTree::UpdateType > Updates)
Submit updates to all available trees.
DominatorTree & getDomTree()
Flush DomTree updates and return DomTree.
Analysis pass which computes a DominatorTree.
DomTreeNodeBase< NodeT > * setNewRoot(NodeT *BB)
Add a new node to the forward dominator tree and make it a new root.
Legacy analysis pass which computes a DominatorTree.
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overridden by subclasses to do the per-function processing of the pass.
bool skipFunction(const Function &F) const
Optional passes call this function to check whether the pass should be skipped.
an instruction for type-safe pointer arithmetic to access elements of arrays and structs
Legacy wrapper pass to provide the GlobalsAAResult object.
This instruction compares its operands according to the predicate given to the constructor.
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr if the function does not have a module.
const BasicBlock * getParent() const
bool comesBefore(const Instruction *Other) const
Given an instruction Other in the same basic block as this instruction, return true if this instruction comes before Other.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos lives in, right before MovePos.
This is an important class for using LLVM in a threaded context.
An instruction for reading from memory.
static MemoryLocation get(const LoadInst *LI)
Return a location with information about the memory reference by the given instruction.
const DataLayout & getDataLayout() const
Get the data layout for the module's target platform.
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
Value * getIncomingValue(unsigned i) const
Return incoming value number i.
int getBasicBlockIndex(const BasicBlock *BB) const
Return the first index of the specified basic block in the value list for this PHI.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at application launch and destroyed by llvm_shutdown.
Pass interface - Implemented by all 'passes'.
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overridden by passes that need analysis information to do their job.
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
void preserve()
Mark an analysis as preserved.
Implements a dense probed hash-table based set with some number of buckets stored inline.
SmallString - A SmallString is just a SmallVector with methods and accessors that make it work better...
void append(StringRef RHS)
Append from a StringRef.
StringRef str() const
Explicit conversion to StringRef.
void reserve(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
StringRef - Represent a constant reference to a string, i.e. a character array and a length, which need not be null terminated.
constexpr bool empty() const
empty - Check if the string is empty.
Analysis pass providing the TargetTransformInfo.
Analysis pass providing the TargetLibraryInfo.
Provides information about what library functions are available for the current target.
bool has(LibFunc F) const
Tests whether a library function is available.
unsigned getSizeTSize(const Module &M) const
Returns the size of the size_t type in bits.
unsigned getIntSize() const
Get size of a C-level int or unsigned int, in bits.
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Value * getOperand(unsigned i) const
LLVM Value Representation.
Type * getType() const
All values are typed, get the type of this value.
bool hasOneUse() const
Return true if there is exactly one use of this value.
StringRef getName() const
Return a constant reference to the value's name.
std::pair< iterator, bool > insert(const ValueT &V)
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
void append(SmallVectorImpl< char > &path, const Twine &a, const Twine &b="", const Twine &c="", const Twine &d="")
Append to path.
This is an optimization pass for GlobalISel generic memory operations.
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
bool operator<(int64_t V1, const APSInt &V2)
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
void initializeMergeICmpsLegacyPassPass(PassRegistry &)
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Value * emitMemCmp(Value *Ptr1, Value *Ptr2, Value *Len, IRBuilderBase &B, const DataLayout &DL, const TargetLibraryInfo *TLI)
Emit a call to the memcmp function.
auto reverse(ContainerTy &&C)
bool isModSet(const ModRefInfo MRI)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
bool none_of(R &&Range, UnaryPredicate P)
Provide wrappers to std::none_of which take ranges instead of having to pass begin/end explicitly.
Interval::pred_iterator pred_begin(Interval *I)
pred_begin/pred_end - define methods so that Intervals may be used just like BasicBlocks can with the...
Pass * createMergeICmpsLegacyPass()
bool isDereferenceablePointer(const Value *V, Type *Ty, const DataLayout &DL, const Instruction *CtxI=nullptr, AssumptionCache *AC=nullptr, const DominatorTree *DT=nullptr, const TargetLibraryInfo *TLI=nullptr)
Return true if this is always a dereferenceable pointer.
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
auto find_if(R &&Range, UnaryPredicate P)
Provide wrappers to std::find_if which take ranges instead of having to pass begin/end explicitly.
bool pred_empty(const BasicBlock *BB)
void DeleteDeadBlocks(ArrayRef< BasicBlock * > BBs, DomTreeUpdater *DTU=nullptr, bool KeepOneInputPHIs=false)
Delete the specified blocks from BB.
void swap(llvm::BitVector &LHS, llvm::BitVector &RHS)
Implement std::swap in terms of BitVector swap.
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)