50#define DEBUG_TYPE "lower-switch"
62bool IsInRanges(
const IntRange &R,
const std::vector<IntRange> &Ranges) {
69 Ranges, R, [](IntRange
A, IntRange
B) {
return A.High.slt(
B.High); });
70 return I !=
Ranges.end() &&
I->Low.sle(
R.Low);
79 :
Low(low),
High(high), BB(bb) {}
82using CaseVector = std::vector<CaseRange>;
83using CaseItr = std::vector<CaseRange>::iterator;
88 bool operator()(
const CaseRange &C1,
const CaseRange &C2) {
89 const ConstantInt *CI1 = cast<const ConstantInt>(C1.Low);
90 const ConstantInt *CI2 = cast<const ConstantInt>(C2.High);
100 for (CaseVector::const_iterator
B =
C.begin(),
E =
C.end();
B !=
E;) {
101 O <<
"[" <<
B->Low->getValue() <<
", " <<
B->High->getValue() <<
"]";
120 const APInt &NumMergedCases) {
121 for (
auto &
I : SuccBB->
phis()) {
126 APInt LocalNumMergedCases = NumMergedCases;
127 for (;
Idx !=
E && NewBB; ++
Idx) {
141 for (; LocalNumMergedCases.
ugt(0) &&
Idx <
E; ++
Idx)
144 LocalNumMergedCases -= 1;
166 if (Leaf.Low == Leaf.High) {
169 new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, Leaf.Low,
"SwitchLeaf");
172 if (Leaf.Low == LowerBound) {
174 Comp =
new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High,
176 }
else if (Leaf.High == UpperBound) {
178 Comp =
new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low,
180 }
else if (Leaf.Low->isZero()) {
182 Comp =
new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High,
188 Val, NegLo, Val->
getName() +
".off", NewLeaf);
190 Comp =
new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE,
Add, UpperBound,
211 APInt Range = Leaf.High->getValue() - Leaf.Low->getValue();
212 for (
APInt j(
Range.getBitWidth(), 0,
true);
j.slt(Range); ++j) {
217 assert(BlockIdx != -1 &&
"Switch didn't go to this successor??");
233 const std::vector<IntRange> &UnreachableRanges) {
234 assert(LowerBound && UpperBound &&
"Bounds must be initialized");
242 if (Begin->Low == LowerBound && Begin->High == UpperBound) {
244 FixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases);
247 return NewLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock,
251 unsigned Mid =
Size / 2;
252 std::vector<CaseRange>
LHS(Begin, Begin + Mid);
254 std::vector<CaseRange>
RHS(Begin + Mid,
End);
257 CaseRange &Pivot = *(Begin + Mid);
258 LLVM_DEBUG(
dbgs() <<
"Pivot ==> [" << Pivot.Low->getValue() <<
", "
259 << Pivot.High->getValue() <<
"]\n");
272 if (!UnreachableRanges.empty()) {
274 APInt GapLow =
LHS.back().High->getValue() + 1;
276 IntRange Gap = {GapLow, GapHigh};
277 if (GapHigh.
sge(GapLow) && IsInRanges(Gap, UnreachableRanges))
278 NewUpperBound =
LHS.back().High;
282 << NewUpperBound->
getValue() <<
"]\n"
283 <<
"RHS Bounds ==> [" << NewLowerBound->
getValue() <<
", "
284 << UpperBound->
getValue() <<
"]\n");
294 SwitchConvert(
LHS.begin(),
LHS.end(), LowerBound, NewUpperBound, Val,
295 NewNode, OrigBlock,
Default, UnreachableRanges);
297 SwitchConvert(
RHS.begin(),
RHS.end(), NewLowerBound, UpperBound, Val,
298 NewNode, OrigBlock,
Default, UnreachableRanges);
310unsigned Clusterify(CaseVector &Cases,
SwitchInst *SI) {
311 unsigned NumSimpleCases = 0;
314 for (
auto Case :
SI->cases()) {
315 if (Case.getCaseSuccessor() ==
SI->getDefaultDest())
317 Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(),
318 Case.getCaseSuccessor()));
325 if (Cases.size() >= 2) {
326 CaseItr
I = Cases.begin();
327 for (CaseItr J = std::next(
I),
E = Cases.end(); J !=
E; ++J) {
328 const APInt &nextValue = J->Low->getValue();
329 const APInt ¤tValue =
I->High->getValue();
336 "Cases should be strictly ascending");
337 if ((nextValue == currentValue + 1) && (currentBB == nextBB)) {
340 }
else if (++
I != J) {
344 Cases.erase(std::next(
I), Cases.end());
347 return NumSimpleCases;
357 Value *Val =
SI->getCondition();
362 if ((OrigBlock != &
F->getEntryBlock() &&
pred_empty(OrigBlock)) ||
364 DeleteList.
insert(OrigBlock);
370 const unsigned NumSimpleCases = Clusterify(Cases, SI);
377 LLVM_DEBUG(
dbgs() <<
"Clusterify finished. Total clusters: " << Cases.size()
378 <<
". Total non-default cases: " << NumSimpleCases
379 <<
"\nCase clusters: " << Cases <<
"\n");
385 FixPhis(
Default, OrigBlock, OrigBlock, UnsignedMax);
386 SI->eraseFromParent();
392 bool DefaultIsUnreachableFromSwitch =
false;
394 if (isa<UnreachableInst>(
Default->getFirstNonPHIOrDbg())) {
398 LowerBound = Cases.front().Low;
399 UpperBound = Cases.back().High;
400 DefaultIsUnreachableFromSwitch =
true;
422 const APInt &
Low = Cases.front().Low->getValue();
423 const APInt &
High = Cases.back().High->getValue();
429 DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max);
432 std::vector<IntRange> UnreachableRanges;
434 if (DefaultIsUnreachableFromSwitch) {
436 APInt MaxPop(UnsignedZero);
441 IntRange
R = {SignedMin, SignedMax};
442 UnreachableRanges.push_back(R);
443 for (
const auto &
I : Cases) {
447 IntRange &LastRange = UnreachableRanges.back();
448 if (LastRange.Low.eq(
Low)) {
450 UnreachableRanges.pop_back();
454 LastRange.High =
Low - 1;
456 if (
High.ne(SignedMax)) {
457 IntRange
R = {
High + 1, SignedMax};
458 UnreachableRanges.push_back(R);
462 assert(
High.sge(
Low) &&
"Popularity shouldn't be negative.");
466 if ((Pop +=
N).ugt(MaxPop)) {
473 for (
auto I = UnreachableRanges.begin(),
E = UnreachableRanges.end();
478 assert(Next->Low.sgt(
I->High));
485 const unsigned NumDefaultEdges =
SI->getNumCases() + 1 - NumSimpleCases;
486 for (
unsigned I = 0;
I < NumDefaultEdges; ++
I)
487 Default->removePredecessor(OrigBlock);
493 [PopSucc](
const CaseRange &R) {
return R.BB == PopSucc; });
498 SI->eraseFromParent();
501 if (!MaxPop.isZero())
502 for (
APInt I(UnsignedZero);
I.ult(MaxPop - 1); ++
I)
510 Val =
SI->getCondition();
514 SwitchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val,
515 OrigBlock, OrigBlock,
Default, UnreachableRanges);
523 FixPhis(
Default, OrigBlock,
nullptr, UnsignedMax);
530 SI->eraseFromParent();
534 DeleteList.
insert(OldDefault);
538 bool Changed =
false;
545 if (DeleteList.
count(&Cur))
548 if (
SwitchInst *SI = dyn_cast<SwitchInst>(Cur.getTerminator())) {
550 ProcessSwitchInst(SI, DeleteList, AC, LVI);
581char LowerSwitchLegacyPass::ID = 0;
587 "Lower SwitchInst's to branches",
false,
false)
595 return new LowerSwitchLegacyPass();
598bool LowerSwitchLegacyPass::runOnFunction(
Function &
F) {
599 LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
600 auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>();
602 return LowerSwitch(
F, LVI, AC);
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static cl::opt< ITMode > IT(cl::desc("IT block support"), cl::Hidden, cl::init(DefaultIT), cl::values(clEnumValN(DefaultIT, "arm-default-it", "Generate any type of IT block"), clEnumValN(RestrictedIT, "arm-restrict-it", "Disallow complex IT blocks")))
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define LLVM_ATTRIBUTE_USED
This file contains the declarations for the subclasses of Constant, which represent the different fla...
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
This file defines the DenseMap class.
This file provides various utilities for inspecting and working with the control flow graph in LLVM I...
Lower SwitchInst s to branches
This header defines various interfaces for pass management in LLVM.
#define INITIALIZE_PASS_DEPENDENCY(depName)
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
Class for arbitrary precision integers.
static APInt getMaxValue(unsigned numBits)
Gets maximum unsigned value of APInt for specific bit width.
bool sgt(const APInt &RHS) const
Signed greater than comparison.
bool ugt(const APInt &RHS) const
Unsigned greater than comparison.
static APInt getSignedMaxValue(unsigned numBits)
Gets maximum signed value of APInt for a specific bit width.
static APInt getSignedMinValue(unsigned numBits)
Gets minimum signed value of APInt for a specific bit width.
bool slt(const APInt &RHS) const
Signed less than comparison.
bool sge(const APInt &RHS) const
Signed greater or equal comparison.
A container for analyses that lazily runs them and caches their results.
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
A function analysis which provides an AssumptionCache.
An immutable pass that tracks lazily created AssumptionCache objects.
A cache of @llvm.assume calls within a function.
LLVM Basic Block Representation.
iterator begin()
Instruction iterator methods.
iterator_range< const_phi_iterator > phis() const
Returns a range that iterates over the phis in the basic block.
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
const BasicBlock * getSinglePredecessor() const
Return the predecessor of this block if it has a single predecessor block.
const Function * getParent() const
Return the enclosing method, or null if none.
InstListType::iterator iterator
Instruction iterators...
void removePredecessor(BasicBlock *Pred, bool KeepOneInputPHIs=false)
Update PHI nodes in this BasicBlock before removal of predecessor Pred.
static BranchInst * Create(BasicBlock *IfTrue, Instruction *InsertBefore=nullptr)
static Constant * getAdd(Constant *C1, Constant *C2, bool HasNUW=false, bool HasNSW=false)
static Constant * getNeg(Constant *C, bool HasNUW=false, bool HasNSW=false)
This is the shared class of boolean and integer constants.
static Constant * get(Type *Ty, uint64_t V, bool IsSigned=false)
If Ty is a vector type, return a Constant with a splat of the given value.
const APInt & getValue() const
Return the constant as an APInt value reference.
This class represents a range of values.
static ConstantRange fromKnownBits(const KnownBits &Known, bool IsSigned)
Initialize a range based on a known bits constraint.
APInt getSignedMin() const
Return the smallest signed value contained in the ConstantRange.
ConstantRange intersectWith(const ConstantRange &CR, PreferredRangeType Type=Smallest) const
Return the range that results from the intersection of this range with another range.
APInt getSignedMax() const
Return the largest signed value contained in the ConstantRange.
This is an important base class in LLVM.
A parsed version of the target data layout string in and methods for querying it.
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
FunctionPass class - This class is used to implement most global optimizations.
virtual bool runOnFunction(Function &F)=0
runOnFunction - Virtual method overriden by subclasses to do the per-function processing of the pass.
This instruction compares its operands according to the predicate given to the constructor.
SymbolTableList< Instruction >::iterator insertInto(BasicBlock *ParentBB, SymbolTableList< Instruction >::iterator It)
Inserts an unlinked instruction into ParentBB at position It and returns the iterator of the inserted...
Class to represent integer types.
Analysis to compute lazy value information.
Wrapper around LazyValueInfo.
This pass computes, caches, and vends lazy value constraint information.
void eraseBlock(BasicBlock *BB)
Inform the analysis cache that we have erased a block.
ConstantRange getConstantRange(Value *V, Instruction *CxtI, bool UndefAllowed=true)
Return the ConstantRange constraint that is known to hold for the specified value at the specified in...
void addIncoming(Value *V, BasicBlock *BB)
Add an incoming value to the end of the PHI list.
void setIncomingBlock(unsigned i, BasicBlock *BB)
Value * removeIncomingValue(unsigned Idx, bool DeletePHIIfEmpty=true)
Remove an incoming value.
Value * getIncomingValueForBlock(const BasicBlock *BB) const
BasicBlock * getIncomingBlock(unsigned i) const
Return incoming basic block number i.
int getBasicBlockIndex(const BasicBlock *BB) const
Return the first index of the specified basic block in the value list for this PHI.
unsigned getNumIncomingValues() const
Return the number of incoming edges.
static PassRegistry * getPassRegistry()
getPassRegistry - Access the global registry object, which is automatically initialized at applicatio...
virtual void getAnalysisUsage(AnalysisUsage &) const
getAnalysisUsage - This function should be overriden by passes that need analysis information to do t...
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
LLVM Value Representation.
LLVMContext & getContext() const
All values hold a context through their type.
StringRef getName() const
Return a constant reference to the value's name.
self_iterator getIterator()
This class implements an extremely fast bulk output stream that can only output to a stream.
const APInt & smin(const APInt &A, const APInt &B)
Determine the smaller of two APInts considered to be signed.
const APInt & smax(const APInt &A, const APInt &B)
Determine the larger of two APInts considered to be signed.
unsigned ID
LLVM IR allows to use arbitrary numbers as calling convention identifiers.
@ C
The default llvm calling convention, compatible with C.
This is an optimization pass for GlobalISel generic memory operations.
@ Low
Lower the current thread's priority such that it does not affect foreground tasks significantly.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
void DeleteDeadBlock(BasicBlock *BB, DomTreeUpdater *DTU=nullptr, bool KeepOneInputPHIs=false)
Delete the specified block, which must have no predecessors.
auto reverse(ContainerTy &&C)
void sort(IteratorTy Start, IteratorTy End)
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
auto lower_bound(R &&Range, T &&Value)
Provide wrappers to std::lower_bound which take ranges instead of having to pass begin/end explicitly...
void computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth=0, AssumptionCache *AC=nullptr, const Instruction *CxtI=nullptr, const DominatorTree *DT=nullptr, bool UseInstrInfo=true)
Determine which bits of V are known to be either zero or one and return them in the KnownZero/KnownOn...
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
constexpr unsigned BitWidth
void erase_if(Container &C, UnaryPredicate P)
Provide a container algorithm similar to C++ Library Fundamentals v2's erase_if which is equivalent t...
FunctionPass * createLowerSwitchPass()
bool pred_empty(const BasicBlock *BB)
void initializeLowerSwitchLegacyPassPass(PassRegistry &)
PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)