LLVM 23.0.0git
BranchProbability.h
Go to the documentation of this file.
1//===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Definition of BranchProbability shared by IR and Machine Instructions.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H
14#define LLVM_SUPPORT_BRANCHPROBABILITY_H
15
16#include "llvm/ADT/ADL.h"
19#include <algorithm>
20#include <cassert>
21#include <iterator>
22#include <numeric>
23
24namespace llvm {
25
26class raw_ostream;
27
28// This class represents Branch Probability as a non-negative fraction that is
29// no greater than 1. It uses a fixed-point-like implementation, in which the
30// denominator is always a constant value (here we use 1<<31 for maximum
31// precision).
32class BranchProbability {
33 // Numerator
34 uint32_t N;
35
36 // Denominator, which is a constant value.
37 static constexpr uint32_t D = 1u << 31;
38 static constexpr uint32_t UnknownN = UINT32_MAX;
39
40 // Construct a BranchProbability with only numerator assuming the denominator
41 // is 1<<31. For internal use only.
42 explicit BranchProbability(uint32_t n) : N(n) {}
43
44public:
45 BranchProbability() : N(UnknownN) {}
46 LLVM_ABI BranchProbability(uint32_t Numerator, uint32_t Denominator);
47
48 bool isZero() const { return N == 0; }
49 bool isOne() const { return N == D; }
50 bool isUnknown() const { return N == UnknownN; }
51
52 static BranchProbability getZero() { return BranchProbability(0); }
53 static BranchProbability getOne() { return BranchProbability(D); }
54 static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
55 // Create a BranchProbability object with the given numerator and 1<<31
56 // as denominator.
57 static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
58 // Create a BranchProbability object from 64-bit integers.
60 uint64_t Denominator);
61 // Create a BranchProbability from a double, which must be from 0 to 1.
63
64 // Normalize the given probabilities so that their sum becomes approximately
65 // one.
66 template <class ProbabilityIter>
67 static void normalizeProbabilities(ProbabilityIter Begin,
68 ProbabilityIter End);
69
70 template <class ProbabilityContainer>
71 static void normalizeProbabilities(ProbabilityContainer &&R) {
73 }
74
75 uint32_t getNumerator() const { return N; }
76 static uint32_t getDenominator() { return D; }
77 double toDouble() const { return static_cast<double>(N) / D; }
78
79 // Return (1 - Probability).
80 BranchProbability getCompl() const { return BranchProbability(D - N); }
81
83
84#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
85 LLVM_DUMP_METHOD void dump() const;
86#endif
87
88 /// Scale a large integer.
89 ///
90 /// Scales \c Num. Guarantees full precision. Returns the floor of the
91 /// result.
92 ///
93 /// \return \c Num times \c this.
95
96 /// Scale a large integer by the inverse.
97 ///
98 /// Scales \c Num by the inverse of \c this. Guarantees full precision.
99 /// Returns the floor of the result.
100 ///
101 /// \return \c Num divided by \c this.
103
104 /// Compute pow(Probability, N).
105 LLVM_ABI BranchProbability pow(unsigned N) const;
106
107 BranchProbability &operator+=(BranchProbability RHS) {
108 assert(N != UnknownN && RHS.N != UnknownN &&
109 "Unknown probability cannot participate in arithmetics.");
110 // Saturate the result in case of overflow.
111 N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
112 return *this;
113 }
114
115 BranchProbability &operator-=(BranchProbability RHS) {
116 assert(N != UnknownN && RHS.N != UnknownN &&
117 "Unknown probability cannot participate in arithmetics.");
118 // Saturate the result in case of underflow.
119 N = N < RHS.N ? 0 : N - RHS.N;
120 return *this;
121 }
122
123 BranchProbability &operator*=(BranchProbability RHS) {
124 assert(N != UnknownN && RHS.N != UnknownN &&
125 "Unknown probability cannot participate in arithmetics.");
126 N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
127 return *this;
128 }
129
130 BranchProbability &operator*=(uint32_t RHS) {
131 assert(N != UnknownN &&
132 "Unknown probability cannot participate in arithmetics.");
133 N = (uint64_t(N) * RHS > D) ? D : N * RHS;
134 return *this;
135 }
136
137 BranchProbability &operator/=(BranchProbability RHS) {
138 assert(N != UnknownN && RHS.N != UnknownN &&
139 "Unknown probability cannot participate in arithmetics.");
140 N = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
141 return *this;
142 }
143
144 BranchProbability &operator/=(uint32_t RHS) {
145 assert(N != UnknownN &&
146 "Unknown probability cannot participate in arithmetics.");
147 assert(RHS > 0 && "The divider cannot be zero.");
148 N /= RHS;
149 return *this;
150 }
151
152 BranchProbability operator+(BranchProbability RHS) const {
153 BranchProbability Prob(*this);
154 Prob += RHS;
155 return Prob;
156 }
157
158 BranchProbability operator-(BranchProbability RHS) const {
159 BranchProbability Prob(*this);
160 Prob -= RHS;
161 return Prob;
162 }
163
164 BranchProbability operator*(BranchProbability RHS) const {
165 BranchProbability Prob(*this);
166 Prob *= RHS;
167 return Prob;
168 }
169
170 BranchProbability operator*(uint32_t RHS) const {
171 BranchProbability Prob(*this);
172 Prob *= RHS;
173 return Prob;
174 }
175
176 BranchProbability operator/(BranchProbability RHS) const {
177 BranchProbability Prob(*this);
178 Prob /= RHS;
179 return Prob;
180 }
181
182 BranchProbability operator/(uint32_t RHS) const {
183 BranchProbability Prob(*this);
184 Prob /= RHS;
185 return Prob;
186 }
187
188 bool operator==(BranchProbability RHS) const { return N == RHS.N; }
189 bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }
190
191 bool operator<(BranchProbability RHS) const {
192 assert(N != UnknownN && RHS.N != UnknownN &&
193 "Unknown probability cannot participate in comparisons.");
194 return N < RHS.N;
195 }
196
197 bool operator>(BranchProbability RHS) const {
198 assert(N != UnknownN && RHS.N != UnknownN &&
199 "Unknown probability cannot participate in comparisons.");
200 return RHS < *this;
201 }
202
203 bool operator<=(BranchProbability RHS) const {
204 assert(N != UnknownN && RHS.N != UnknownN &&
205 "Unknown probability cannot participate in comparisons.");
206 return !(RHS < *this);
207 }
208
209 bool operator>=(BranchProbability RHS) const {
210 assert(N != UnknownN && RHS.N != UnknownN &&
211 "Unknown probability cannot participate in comparisons.");
212 return !(*this < RHS);
213 }
214};
215
217 return Prob.print(OS);
218}
219
220template <class ProbabilityIter>
222 ProbabilityIter End) {
223 if (Begin == End)
224 return;
225
226 unsigned UnknownProbCount = 0;
227 uint64_t Sum = std::accumulate(Begin, End, uint64_t(0),
228 [&](uint64_t S, const BranchProbability &BP) {
229 if (!BP.isUnknown())
230 return S + BP.N;
231 UnknownProbCount++;
232 return S;
233 });
234
235 if (UnknownProbCount > 0) {
236 BranchProbability ProbForUnknown = BranchProbability::getZero();
237 // If the sum of all known probabilities is less than one, evenly distribute
238 // the complement of sum to unknown probabilities. Otherwise, set unknown
239 // probabilities to zeros and continue to normalize known probabilities.
241 ProbForUnknown = BranchProbability::getRaw(
242 (BranchProbability::getDenominator() - Sum) / UnknownProbCount);
243
244 std::replace_if(Begin, End,
245 [](const BranchProbability &BP) { return BP.isUnknown(); },
246 ProbForUnknown);
247
249 return;
250 }
251
252 if (Sum == 0) {
253 BranchProbability BP(1, std::distance(Begin, End));
254 std::fill(Begin, End, BP);
255 return;
256 }
257
258 for (auto I = Begin; I != End; ++I)
259 I->N = (I->N * uint64_t(D) + Sum / 2) / Sum;
260}
261
262}
263
264#endif
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
#define LLVM_ABI
Definition Compiler.h:213
#define LLVM_DUMP_METHOD
Mark debug helper function definitions like dump() that should not be stripped from debug builds.
Definition Compiler.h:661
#define I(x, y, z)
Definition MD5.cpp:57
Value * RHS
LLVM_DUMP_METHOD void dump() const
static LLVM_ABI BranchProbability getBranchProbability(uint64_t Numerator, uint64_t Denominator)
BranchProbability operator-(BranchProbability RHS) const
BranchProbability & operator-=(BranchProbability RHS)
static uint32_t getDenominator()
bool operator<(BranchProbability RHS) const
bool operator!=(BranchProbability RHS) const
static BranchProbability getRaw(uint32_t N)
bool operator==(BranchProbability RHS) const
BranchProbability operator/(uint32_t RHS) const
BranchProbability & operator/=(BranchProbability RHS)
LLVM_ABI BranchProbability pow(unsigned N) const
Compute pow(Probability, N).
bool operator<=(BranchProbability RHS) const
static BranchProbability getOne()
LLVM_ABI raw_ostream & print(raw_ostream &OS) const
BranchProbability & operator*=(BranchProbability RHS)
LLVM_ABI uint64_t scaleByInverse(uint64_t Num) const
Scale a large integer by the inverse.
BranchProbability operator*(BranchProbability RHS) const
static BranchProbability getUnknown()
BranchProbability operator/(BranchProbability RHS) const
uint32_t getNumerator() const
LLVM_ABI uint64_t scale(uint64_t Num) const
Scale a large integer.
static void normalizeProbabilities(ProbabilityContainer &&R)
BranchProbability operator+(BranchProbability RHS) const
bool operator>=(BranchProbability RHS) const
BranchProbability operator*(uint32_t RHS) const
BranchProbability & operator*=(uint32_t RHS)
BranchProbability getCompl() const
BranchProbability & operator+=(BranchProbability RHS)
BranchProbability & operator/=(uint32_t RHS)
static BranchProbability getZero()
bool operator>(BranchProbability RHS) const
static void normalizeProbabilities(ProbabilityIter Begin, ProbabilityIter End)
This class implements an extremely fast bulk output stream that can only output to a stream.
Definition raw_ostream.h:53
This is an optimization pass for GlobalISel generic memory operations.
constexpr auto adl_begin(RangeT &&range) -> decltype(adl_detail::begin_impl(std::forward< RangeT >(range)))
Returns the begin iterator to range using std::begin and function found through Argument-Dependent Lo...
Definition ADL.h:78
constexpr auto adl_end(RangeT &&range) -> decltype(adl_detail::end_impl(std::forward< RangeT >(range)))
Returns the end iterator to range using std::end and functions found through Argument-Dependent Looku...
Definition ADL.h:86
raw_ostream & operator<<(raw_ostream &OS, const APFixedPoint &FX)
#define N