LLVM 23.0.0git
JumpTableToSwitch.cpp
Go to the documentation of this file.
1//===- JumpTableToSwitch.cpp ----------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "llvm/ADT/DenseSet.h"
11#include "llvm/ADT/STLExtras.h"
13#include "llvm/ADT/Statistic.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/IR/LLVMContext.h"
23#include "llvm/Support/Error.h"
25#include <limits>
26
27using namespace llvm;
28
30 JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden,
31 cl::desc("Only split jump tables with size less or "
32 "equal than JumpTableSizeThreshold."),
33 cl::init(10));
34
35// TODO: Consider adding a cost model for profitability analysis of this
36// transformation. Currently we replace a jump table with a switch only if every
37// function in the jump table has at most FunctionSizeThreshold instructions.
39 "jump-table-to-switch-function-size-threshold", cl::Hidden,
40 cl::desc("Only split jump tables containing functions whose sizes are less "
41 "or equal than this threshold."),
42 cl::init(50));
43
44namespace llvm {
46} // end namespace llvm
47
48#define DEBUG_TYPE "jump-table-to-switch"
49
50STATISTIC(NumEligibleJumpTables, "The number of jump tables seen by the pass "
51 "that can be converted if deemed profitable.");
52STATISTIC(NumJumpTablesConverted,
53 "The number of jump tables converted into switches.");
54
55namespace {
56struct JumpTableTy {
57 Value *Index;
59};
60} // anonymous namespace
61
/// Try to interpret \p GEP as an index into a constant "jump table": a global
/// whose entries are function pointers, addressed by exactly one variable
/// index term with a zero constant offset.
///
/// \param GEP   Address computation whose result is loaded to obtain the
///              callee of an indirect call.
/// \param PtrTy Type of the loaded callee pointer; every table entry is read
///              back as a constant of this type.
/// \returns The variable index value plus one target Function per table
///          entry (in table order), or std::nullopt when: the base is not a
///          constant global with a definitive initializer; the offset is not a
///          single variable term with zero constant part; the global's size is
///          not a multiple of the stride; any entry is not a defined Function
///          with at most FunctionSizeThreshold instructions; or (presumably)
///          the entry count exceeds JumpTableSizeThreshold.
///
/// Side effect: NumEligibleJumpTables is incremented as soon as the
/// size/stride check passes, i.e. before the per-entry checks, so it also
/// counts tables that are rejected afterwards.
62static std::optional<JumpTableTy> parseJumpTable(GetElementPtrInst *GEP,
63 PointerType *PtrTy) {
// The table base must be a constant global with a definitive initializer.
64 Constant *Ptr = dyn_cast<Constant>(GEP->getPointerOperand());
65 if (!Ptr)
66 return std::nullopt;
67
69 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
70 return std::nullopt;
71
// Decompose the GEP offset; only the shape `Index * Stride` is accepted
// (exactly one variable term and a zero constant part).
72 Function &F = *GEP->getParent()->getParent();
73 const DataLayout &DL = F.getDataLayout();
74 const unsigned BitWidth =
75 DL.getIndexSizeInBits(GEP->getPointerAddressSpace());
77 APInt ConstantOffset(BitWidth, 0);
78 if (!GEP->collectOffset(DL, BitWidth, VariableOffsets, ConstantOffset))
79 return std::nullopt;
80 if (VariableOffsets.size() != 1)
81 return std::nullopt;
82 // TODO: consider supporting more general patterns
83 if (!ConstantOffset.isZero())
84 return std::nullopt;
// NOTE(review): the % and / below assume collectOffset never reports a zero
// stride and that the stride fits in 64 bits (getZExtValue) -- confirm. Also,
// a zero-sized global passes the divisibility check with N == 0; confirm that
// case is rejected or is intentionally turned into a table with no entries.
85 APInt StrideBytes = VariableOffsets.front().second;
86 const uint64_t JumpTableSizeBytes = GV->getGlobalSize(DL);
87 if (JumpTableSizeBytes % StrideBytes.getZExtValue() != 0)
88 return std::nullopt;
89 ++NumEligibleJumpTables;
90 const uint64_t N = JumpTableSizeBytes / StrideBytes.getZExtValue();
92 return std::nullopt;
93
// Read every entry at byte offset Index * StrideBytes; each one must be a
// function defined in this module and no larger than FunctionSizeThreshold.
94 JumpTableTy JumpTable;
95 JumpTable.Index = VariableOffsets.front().first;
96 JumpTable.Funcs.reserve(N);
97 for (uint64_t Index = 0; Index < N; ++Index) {
98 // ConstantOffset is zero.
99 APInt Offset = Index * StrideBytes;
100 Constant *C =
102 auto *Func = dyn_cast_or_null<Function>(C);
103 if (!Func || Func->isDeclaration() ||
104 Func->getInstructionCount() > FunctionSizeThreshold)
105 return std::nullopt;
106 JumpTable.Funcs.push_back(Func);
107 }
108 return JumpTable;
109}
110
/// Replace the indirect call \p CB, whose callee is loaded from the jump table
/// \p JT, with a switch on JT.Index that dispatches to one direct call per
/// table entry.
///
/// CB's block is split at CB. The original block then ends in the switch,
/// whose default destination is a new block containing only `unreachable`.
/// Case i goes to a new block "call.<i>" holding a call whose callee is set to
/// JT.Funcs[i] (with its MD_prof dropped). That block continues into the
/// split-off tail. For a non-void call, a PHI in the tail merges the per-case
/// results.
///
/// If CB carried value-profile MD_prof data, the switch receives branch
/// weights derived from those counts (looked up by GUID). This happens only
/// when at least one weight is non-zero and ProfcheckDisableMetadataFixes is
/// off; the other path is handled by the else branch.
///
/// \param CB  The indirect call to expand; it is erased before returning.
/// \param JT  Parsed jump table supplying the switch condition and targets.
/// \param DTU Receives the CFG edge insertions/deletions performed here.
/// \param ORE Used to emit a "ReplacedJumpTableWithSwitch" remark.
/// \param GetGuidForFunction Maps a target function to the GUID used to look
///        up its count in CB's value-profile data.
/// \returns The split-off tail block (named "<bb>.tail").
111static BasicBlock *
112expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU,
115 GetGuidForFunction) {
116 ++NumJumpTablesConverted;
117 const bool IsVoid = CB->getType() == Type::getVoidTy(CB->getContext());
118
120 BasicBlock *BB = CB->getParent();
121 BasicBlock *Tail = SplitBlock(BB, CB, &DTU, nullptr, nullptr,
122 BB->getName() + Twine(".tail"));
123 DTUpdates.push_back({DominatorTree::Delete, BB, Tail});
125
126 Function &F = *BB->getParent();
127 BasicBlock *BBUnreachable = BasicBlock::Create(
128 F.getContext(), "default.switch.case.unreachable", &F, Tail);
129 IRBuilder<> BuilderUnreachable(BBUnreachable);
130 BuilderUnreachable.CreateUnreachable();
131
132 IRBuilder<> Builder(BB);
133 SwitchInst *Switch = Builder.CreateSwitch(JT.Index, BBUnreachable);
134 DTUpdates.push_back({DominatorTree::Insert, BB, BBUnreachable});
135
136 IRBuilder<> BuilderTail(CB);
137 PHINode *PHI =
138 IsVoid ? nullptr : BuilderTail.CreatePHI(CB->getType(), JT.Funcs.size());
139 const auto *ProfMD = CB->getMetadata(LLVMContext::MD_prof);
140
// Switch weights in case order: when HadProfile, slot 0 is the default
// (unreachable) destination, followed by one count per JT.Funcs entry.
// Per-entry values are appended even without a profile, but they are only
// read when HadProfile is true.
141 SmallVector<uint64_t> BranchWeights;
143 const bool HadProfile = isValueProfileMD(ProfMD);
144 if (HadProfile) {
145 // The assumptions, coming in, are that the functions in JT.Funcs are
146 // defined in this module (from parseJumpTable).
148 JT.Funcs, [](const Function *F) { return F && !F->isDeclaration(); }));
149 BranchWeights.reserve(JT.Funcs.size() + 1);
150 // The first is the default target, which is the unreachable block created
151 // above.
152 BranchWeights.push_back(0U);
153 uint64_t TotalCount = 0;
154 auto Targets = getValueProfDataFromInst(
155 *CB, InstrProfValueKind::IPVK_IndirectCallTarget,
156 std::numeric_limits<uint32_t>::max(), TotalCount);
157
158 for (const auto &[G, C] : Targets) {
159 [[maybe_unused]] auto It = GuidToCounter.insert({G, C});
160 // We should always be inserting, as it is a verifier-enforced IR invariant
161 // that VP metadata does not have duplicate values.
162 assert(It.second);
163 }
164 }
165 for (auto [Index, Func] : llvm::enumerate(JT.Funcs)) {
166 BasicBlock *B = BasicBlock::Create(Func->getContext(),
167 "call." + Twine(Index), &F, Tail);
168 DTUpdates.push_back({DominatorTree::Insert, BB, B});
169 DTUpdates.push_back({DominatorTree::Insert, B, Tail});
170
172 // The MD_prof metadata (VP kind), if it existed, can be dropped; it doesn't
173 // make sense on a direct call. Note that the values are used for the branch
174 // weights of the switch.
175 Call->setMetadata(LLVMContext::MD_prof, nullptr);
176 Call->setCalledFunction(Func);
177 Call->insertInto(B, B->end());
// NOTE(review): the case value is built in JT.Index's own type from the
// enumerate counter. If that type is too narrow to hold every Index as a
// non-negative value (e.g. an i1 or i8 GEP index with a large table), the
// constant would be truncated (or assert, depending on ConstantInt::get's
// truncation policy) and cases could collide -- confirm that parseJumpTable
// or the table-size limit rules this out.
178 Switch->addCase(
179 cast<ConstantInt>(ConstantInt::get(JT.Index->getType(), Index)), B);
180 GlobalValue::GUID FctID = GetGuidForFunction(*Func);
181 // It'd be OK to _not_ find target functions in GuidToCounter, e.g. suppose
182 // just some of the jump targets are taken (for the given profile).
183 BranchWeights.push_back(FctID == 0U ? 0U
184 : GuidToCounter.lookup_or(FctID, 0U));
186 if (PHI)
187 PHI->addIncoming(Call, B);
188 }
189 DTU.applyUpdates(DTUpdates);
190 ORE.emit([&]() {
191 return OptimizationRemark(DEBUG_TYPE, "ReplacedJumpTableWithSwitch", CB)
192 << "expanded indirect call into switch";
193 });
194 // Only set branch weights on the switch if we have non-zero branch weights.
195 // We can have no non-zero branch weights while having VP metadata if, for
196 // example, all of the functions are external and not instrumented.
197 if (HadProfile && !ProfcheckDisableMetadataFixes &&
198 llvm::any_of(BranchWeights, not_equal_to(0))) {
199 setBranchWeights(*Switch, downscaleWeights(BranchWeights),
200 /*IsExpected=*/false);
201 } else
203 if (PHI)
205 CB->eraseFromParent();
206 return Tail;
207}
208
215 DomTreeUpdater DTU(DT, PDT, DomTreeUpdater::UpdateStrategy::Lazy);
216 bool Changed = false;
217 auto FuncToGuid = [&](const Function &Fct) {
218 if (const auto MaybeGUID = Fct.getGUIDIfAssigned(); MaybeGUID)
219 return *MaybeGUID;
220
222 getIRPGOFuncName(Fct, InLTO));
223 };
224
225 for (BasicBlock &BB : make_early_inc_range(F)) {
226 BasicBlock *CurrentBB = &BB;
227 while (CurrentBB) {
228 BasicBlock *SplittedOutTail = nullptr;
229 for (Instruction &I : make_early_inc_range(*CurrentBB)) {
230 auto *Call = dyn_cast<CallInst>(&I);
231 if (!Call || Call->getCalledFunction() || Call->isMustTailCall())
232 continue;
233 auto *L = dyn_cast<LoadInst>(Call->getCalledOperand());
234 // Skip atomic or volatile loads.
235 if (!L || !L->isSimple())
236 continue;
237 auto *GEP = dyn_cast<GetElementPtrInst>(L->getPointerOperand());
238 if (!GEP)
239 continue;
240 auto *PtrTy = dyn_cast<PointerType>(L->getType());
241 assert(PtrTy && "call operand must be a pointer");
242 std::optional<JumpTableTy> JumpTable = parseJumpTable(GEP, PtrTy);
243 if (!JumpTable)
244 continue;
245 SplittedOutTail =
246 expandToSwitch(Call, *JumpTable, DTU, ORE, FuncToGuid);
247 Changed = true;
248 break;
249 }
250 CurrentBB = SplittedOutTail ? SplittedOutTail : nullptr;
251 }
252 }
253
254 if (!Changed)
255 return PreservedAnalyses::all();
256
258 if (DT)
260 if (PDT)
262 return PA;
263}
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
Rewrite undef for PHI
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
static GCRegistry::Add< OcamlGC > B("ocaml", "ocaml 3.10-compatible GC")
This file defines the DenseSet and SmallDenseSet classes.
#define DEBUG_TYPE
Hexagon Common GEP
static cl::opt< unsigned > FunctionSizeThreshold("jump-table-to-switch-function-size-threshold", cl::Hidden, cl::desc("Only split jump tables containing functions whose sizes are less " "or equal than this threshold."), cl::init(50))
static BasicBlock * expandToSwitch(CallBase *CB, const JumpTableTy &JT, DomTreeUpdater &DTU, OptimizationRemarkEmitter &ORE, llvm::function_ref< GlobalValue::GUID(const Function &)> GetGuidForFunction)
static cl::opt< unsigned > JumpTableSizeThreshold("jump-table-to-switch-size-threshold", cl::Hidden, cl::desc("Only split jump tables with size less or " "equal than JumpTableSizeThreshold."), cl::init(10))
static std::optional< JumpTableTy > parseJumpTable(GetElementPtrInst *GEP, PointerType *PtrTy)
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
#define G(x, y, z)
Definition MD5.cpp:55
This file contains the declarations for profiling metadata utility functions.
This file contains some templates that are useful if you are working with the STL at all.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
Class for arbitrary precision integers.
Definition APInt.h:78
uint64_t getZExtValue() const
Get zero extended value.
Definition APInt.h:1563
bool isZero() const
Determine if this value is zero, i.e. all bits are clear.
Definition APInt.h:381
PassT::Result * getCachedResult(IRUnitT &IR) const
Get the cached result of an analysis pass for a given IR unit.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
LLVM Basic Block Representation.
Definition BasicBlock.h:62
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction; assumes that the block is well-formed.
Definition BasicBlock.h:237
Base class for all callable instructions (InvokeInst and CallInst). Holds everything related to callin...
This is an important base class in LLVM.
Definition Constant.h:43
A parsed version of the target data layout string, and methods for querying it.
Definition DataLayout.h:64
ValueT lookup_or(const_arg_type_t< KeyT > Val, U &&Default) const
Definition DenseMap.h:215
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:239
Analysis pass which computes a DominatorTree.
Definition Dominators.h:278
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition Dominators.h:159
void applyUpdates(ArrayRef< UpdateT > Updates)
Submit updates to all available trees.
An instruction for type-safe pointer arithmetic to access elements of arrays and structs.
static LLVM_ABI GUID getGUIDAssumingExternalLinkage(StringRef GlobalName)
Return a 64-bit global unique ID constructed from the name of a global symbol.
Definition Globals.cpp:80
uint64_t GUID
Declare a type to represent a global unique identifier for a global value.
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI uint64_t getGlobalSize(const DataLayout &DL) const
Get the size of this global variable in bytes.
Definition Globals.cpp:624
bool isConstant() const
If the value is a global constant, its value is immutable throughout the runtime execution of the pro...
bool hasDefinitiveInitializer() const
hasDefinitiveInitializer - Whether the global variable has an initializer, and any other instances of...
UnreachableInst * CreateUnreachable()
Definition IRBuilder.h:1380
PHINode * CreatePHI(Type *Ty, unsigned NumReservedValues, const Twine &Name="")
Definition IRBuilder.h:2539
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition IRBuilder.h:2858
LLVM_ABI Instruction * clone() const
Create a copy of 'this' instruction that is identical in all ways except the following:
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
MDNode * getMetadata(unsigned KindID) const
Get the metadata of given kind attached to this Instruction.
LLVM_ABI PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM)
Run the pass over the function.
size_type size() const
Definition MapVector.h:58
std::pair< KeyT, ValueT > & front()
Definition MapVector.h:81
The optimization diagnostic interface.
LLVM_ABI void emit(DiagnosticInfoOptimizationBase &OptDiag)
Output the remark via the diagnostic handler and to the optimization record file.
Diagnostic information for applied optimization remarks.
Analysis pass which computes a PostDominatorTree.
PostDominatorTree Class - Concrete subclass of DominatorTree that is used to compute the post-dominat...
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
PreservedAnalyses & preserve()
Mark an analysis as preserved.
Definition Analysis.h:132
void reserve(size_type N)
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Multiway switch.
Twine - A lightweight data structure for efficiently representing the concatenation of temporary valu...
Definition Twine.h:82
static LLVM_ABI Type * getVoidTy(LLVMContext &C)
Definition Type.cpp:282
static UncondBrInst * Create(BasicBlock *Target, InsertPosition InsertBefore=nullptr)
LLVM Value Representation.
Definition Value.h:75
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:255
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:549
LLVMContext & getContext() const
All values hold a context through their type.
Definition Value.h:258
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:318
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition ilist_node.h:34
CallInst * Call
Changed
@ Tail
Attempts to make calls as fast as possible while guaranteeing that tail call optimization can always b...
Definition CallingConv.h:76
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
initializer< Ty > init(const Ty &Val)
This is an optimization pass for GlobalISel generic memory operations.
@ Offset
Definition DWP.cpp:558
constexpr auto not_equal_to(T &&Arg)
Functor variant of std::not_equal_to that can be used as a UnaryPredicate in functional algorithms li...
Definition STLExtras.h:2179
cl::opt< bool > ProfcheckDisableMetadataFixes
Definition LoopInfo.cpp:60
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1738
LLVM_ABI std::string getIRPGOFuncName(const Function &F, bool InLTO=false)
auto enumerate(FirstRange &&First, RestRanges &&...Rest)
Given two or more input ranges, returns a new range whose values are tuples (A, B,...
Definition STLExtras.h:2553
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting i...
Definition STLExtras.h:633
LLVM_ABI void setExplicitlyUnknownBranchWeights(Instruction &I, StringRef PassName)
Specify that the branch weights for this terminator cannot be known at compile time.
LLVM_ABI void setBranchWeights(Instruction &I, ArrayRef< uint32_t > Weights, bool IsExpected, bool ElideAllZero=false)
Create a new branch_weights metadata node and add or overwrite a prof metadata reference to instructi...
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
bool any_of(R &&range, UnaryPredicate P)
Provide wrappers to std::any_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1745
LLVM_ABI Constant * ConstantFoldLoadFromConst(Constant *C, Type *Ty, const APInt &Offset, const DataLayout &DL)
Extract value of C at the given Offset reinterpreted as Ty.
LLVM_ABI SmallVector< InstrProfValueData, 4 > getValueProfDataFromInst(const Instruction &Inst, InstrProfValueKind ValueKind, uint32_t MaxNumValueData, uint64_t &TotalC, bool GetNoICPValue=false)
Extract the value profile data from Inst and return them if Inst is annotated with value profile dat...
LLVM_ABI bool isValueProfileMD(const MDNode *ProfileData)
Checks if an MDNode contains value profiling Metadata.
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="")
Split the specified block at the specified instruction.
constexpr unsigned BitWidth
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
LLVM_ABI SmallVector< uint32_t > downscaleWeights(ArrayRef< uint64_t > Weights, std::optional< uint64_t > KnownMaxCount=std::nullopt)
Downscale the given weights, preserving their ratio.
#define N
A MapVector that performs no allocations if smaller than a certain size.
Definition MapVector.h:334