using namespace SwitchCG;
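// getJumpTableRange: number of distinct values covered by Clusters[First..Last],
// clamped so that later density arithmetic on the range cannot overflow uint64_t.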
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
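// findJumpTables: Clusters must be non-empty, sorted, and contain only
// CC_Range clusters; the asserts below check this in debug builds.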
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
  assert(TLI && "TLI not set!");
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;
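  // Bail if not enough cases.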
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;
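  // TotalCases[i] accumulates the number of cases in Clusters[0..i].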
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }
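  // Cheap case: the whole range may be suitable for a single jump table.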
      Clusters[0] = JTCluster;
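  // Split Clusters into a minimum number of dense partitions, filling the
  // MinPartitions/LastElement/PartitionsScore arrays in reverse order (the
  // same idea as Kannan & Proebsting's "Correction to 'Producing Good Code
  // for the Case Statement'"). MinPartitions[i] is the minimum number of
  // partitions of Clusters[i..N-1]; LastElement[i] is the last element of the
  // partition starting at i; PartitionsScore[i] breaks ties between
  // partitionings with the same number of partitions. For the score, a small
  // number of comparisons is considered as good as a jump table, and a single
  // comparison is considered better than one.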
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;
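  // Note: loop indexes are signed to avoid underflow.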
  for (int64_t i = N - 2; i >= 0; i--) {
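    // Baseline: put Clusters[i] into a partition on its own.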
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;
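    // Search for a solution that results in fewer partitions by trying to
    // build a single partition out of Clusters[i..j].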
    for (int64_t j = N - 1; j > i; j--) {
      assert(Range >= NumCases);
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;
        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;
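        // Prefer the partitioning that needs fewer partitions, or the same
        // number of partitions with a better score.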
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          PartitionsScore[i] = Score;
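  // Iterate over the partitions, replacing some with jump tables in-place.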
  unsigned DstIndex = 0;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  Clusters.resize(DstIndex);
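// buildJumpTable: build a dense table of destination blocks for
// Clusters[First..Last], filling holes between clusters with the default
// destination.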
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
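    // Fill the gap between this cluster and the previous one with the default
    // destination.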
    if (I != First) {
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
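  // If the whole range would be better lowered as bit tests, reject the jump
  // table so the bit-test lowering can pick it up instead.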
  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    return false;
  }
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
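  // Set up the jump table info.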
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));
  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
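// findBitTestClusters: partition Clusters into as few subsets as possible,
// where each subset has a range that fits in a machine word and branches to
// at most three distinct destinations.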
  assert(!Clusters.empty());
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
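  // If the target does not have a legal shift-left, do not emit bit tests at all.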
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;
  const int64_t N = Clusters.size();
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
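  // Note: loop indexes are signed to avoid underflow; limiting the inner
  // search to i + BitWidth - 1 bounds the time complexity.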
  for (int64_t i = N - 2; i >= 0; --i) {
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
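      // Try building a partition from Clusters[i..j]: bit tests apply only if
      // the whole range fits in a machine word and all clusters are plain
      // ranges targeting few distinct destinations.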
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        MinPartitions[i] = NumPartitions;
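  // Iterate over the partitions, replacing with bit-test clusters in-place.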
  unsigned DstIndex = 0;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  Clusters.resize(DstIndex);
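// buildBitTests: try to lower Clusters[First..Last] as a single bit-test
// cluster, giving each destination a bit mask over (case value - LowBound).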
  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  unsigned NumDests = Dests.count();
  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;
  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");
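  // ContiguousRange is true when no value in [Low, High] falls through to the
  // default destination.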
  bool ContiguousRange = true;
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
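  // When the case values already fit in a word without subtracting the minimum
  // value, the subtraction is optimized away and the range is treated as
  // non-contiguous.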
    ContiguousRange = false;
  for (unsigned i = First; i <= Last; ++i) {
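    // Find the CaseBits entry for this cluster's destination block.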
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
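    // Set bits Lo..Hi (relative to LowBound) in this destination's mask.
    // E.g. Lo = 2, Hi = 4: (-1ULL >> (63 - 2)) is 0b111; shifted left by 2
    // this gives 0b11100.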
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    TotalProb += Clusters[i].Prob;
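  // Order the bit tests by descending probability first, then by number of
  // bits, then by mask value.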
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
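  // Each CaseBits entry gets its own basic block to hold its bit test.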
  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);
  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
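// sortAndRangeify: sort the single-case clusters by case value and merge
// neighbouring cases that share a destination into ranges.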
    assert(CC.Low == CC.High && "Input clusters must be single-case");
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
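    // If this case has the same successor as the previous cluster and is its
    // immediate neighbour, merge it into that cluster.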
    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);