LLVM 20.0.0git
MIRSampleProfile.cpp
Go to the documentation of this file.
1//===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides the implementation of the MIRSampleProfile loader, mainly
10// for flow sensitive SampleFDO.
11//
12//===----------------------------------------------------------------------===//
13
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/DenseSet.h"
25#include "llvm/CodeGen/Passes.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/PseudoProbe.h"
30#include "llvm/Support/Debug.h"
35#include <optional>
36
37using namespace llvm;
38using namespace sampleprof;
39using namespace llvm::sampleprofutil;
41
42#define DEBUG_TYPE "fs-profile-loader"
43
45 "show-fs-branchprob", cl::Hidden, cl::init(false),
46 cl::desc("Print setting flow sensitive branch probabilities"));
48 "fs-profile-debug-prob-diff-threshold", cl::init(10),
49 cl::desc("Only show debug message if the branch probility is greater than "
50 "this value (in percentage)."));
51
53 "fs-profile-debug-bw-threshold", cl::init(10000),
54 cl::desc("Only show debug message if the source branch weight is greater "
55 " than this value."));
56
57static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
58 cl::init(false),
59 cl::desc("View BFI before MIR loader"));
60static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
61 cl::init(false),
62 cl::desc("View BFI after MIR loader"));
63
64namespace llvm {
66}
68
70 "Load MIR Sample Profile",
71 /* cfg = */ false, /* is_analysis = */ false)
78 /* cfg = */ false, /* is_analysis = */ false)
79
81
83llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
85 IntrusiveRefCntPtr<vfs::FileSystem> FS) {
86 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
87}
88
89namespace llvm {
90
91// Internal option used to control BFI display only after MBP pass.
92// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
93// -view-block-layout-with-bfi={none | fraction | integer | count}
95
96// Command line option to specify the name of the function for CFG dump
97// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
99
100std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
101 if (MI.isPseudoProbe()) {
102 PseudoProbe Probe;
103 Probe.Id = MI.getOperand(1).getImm();
104 Probe.Type = MI.getOperand(2).getImm();
105 Probe.Attr = MI.getOperand(3).getImm();
106 Probe.Factor = 1;
107 DILocation *DebugLoc = MI.getDebugLoc();
108 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
109 return Probe;
110 }
111
112 // Ignore callsite probes since they do not have FS discriminators.
113 return std::nullopt;
114}
115
116namespace afdo_detail {
117template <> struct IRTraits<MachineBasicBlock> {
133 static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
136 }
138 return BB->predecessors();
139 }
141 return BB->successors();
142 }
143};
144} // namespace afdo_detail
145
147 : public SampleProfileLoaderBaseImpl<MachineFunction> {
148public:
152 DT = MDT;
153 PDT = MPDT;
154 LI = MLI;
155 BFI = MBFI;
156 ORE = MORE;
157 }
159 P = Pass;
162 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
163 }
164
168 std::move(FS)) {}
169
172 bool doInitialization(Module &M);
173 bool isValid() const { return ProfileIsValid; }
174
175protected:
177
178 /// Hold the information of the basic block frequency.
180
181 /// PassNum is the sequence number this pass is called, start from 1.
183
184 // LowBit in the FS discriminator used by this instance. Note the number is
185 // 0-based. Base discriminators use bits 0 to 11.
186 unsigned LowBit;
187 // HighBit in the FS discriminator used by this instance. Note the number
188 // is 0-based.
189 unsigned HighBit;
190
191 bool ProfileIsValid = true;
194 return getProbeWeight(MI);
195 if (ImprovedFSDiscriminator && MI.isMetaInstruction())
196 return std::error_code();
197 return getInstWeightImpl(MI);
198 }
199};
200
201template <>
203 MachineFunction &F) {}
204
206 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
207 for (auto &BI : F) {
208 MachineBasicBlock *BB = &BI;
209 if (BB->succ_size() < 2)
210 continue;
211 const MachineBasicBlock *EC = EquivalenceClass[BB];
212 uint64_t BBWeight = BlockWeights[EC];
213 uint64_t SumEdgeWeight = 0;
214 for (MachineBasicBlock *Succ : BB->successors()) {
215 Edge E = std::make_pair(BB, Succ);
216 SumEdgeWeight += EdgeWeights[E];
217 }
218
219 if (BBWeight != SumEdgeWeight) {
220 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
221 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
222 << "\n");
223 BBWeight = SumEdgeWeight;
224 }
225 if (BBWeight == 0) {
226 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
227 continue;
228 }
229
230#ifndef NDEBUG
231 uint64_t BBWeightOrig = BBWeight;
232#endif
233 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
234 uint32_t Factor = 1;
235 if (BBWeight > MaxWeight) {
236 Factor = BBWeight / MaxWeight + 1;
237 BBWeight /= Factor;
238 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
239 }
240
242 SE = BB->succ_end();
243 SI != SE; ++SI) {
244 MachineBasicBlock *Succ = *SI;
245 Edge E = std::make_pair(BB, Succ);
246 uint64_t EdgeWeight = EdgeWeights[E];
247 EdgeWeight /= Factor;
248
249 assert(BBWeight >= EdgeWeight &&
250 "BBweight is larger than EdgeWeight -- should not happen.\n");
251
252 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
253 BranchProbability NewProb(EdgeWeight, BBWeight);
254 if (OldProb == NewProb)
255 continue;
256 BB->setSuccProbability(SI, NewProb);
257#ifndef NDEBUG
258 if (!ShowFSBranchProb)
259 continue;
260 bool Show = false;
262 if (OldProb > NewProb)
263 Diff = OldProb - NewProb;
264 else
265 Diff = NewProb - OldProb;
267 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
268
269 auto DIL = BB->findBranchDebugLoc();
270 auto SuccDIL = Succ->findBranchDebugLoc();
271 if (Show) {
272 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
273 << Succ->getNumber() << "): ";
274 if (DIL)
275 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
276 << DIL->getColumn();
277 if (SuccDIL)
278 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
279 << ":" << SuccDIL->getColumn();
280 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
281 << "\n";
282 }
283#endif
284 }
285 }
286}
287
289 auto &Ctx = M.getContext();
290
292 Filename, Ctx, *FS, P, RemappingFilename);
293 if (std::error_code EC = ReaderOrErr.getError()) {
294 std::string Msg = "Could not open profile: " + EC.message();
295 Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
296 return false;
297 }
298
299 Reader = std::move(ReaderOrErr.get());
300 Reader->setModule(&M);
302
303 // Load pseudo probe descriptors for probe-based function samples.
304 if (Reader->profileIsProbeBased()) {
305 ProbeManager = std::make_unique<PseudoProbeManager>(M);
306 if (!ProbeManager->moduleIsProbed(M)) {
307 return false;
308 }
309 }
310
311 return true;
312}
313
315 // Do not load non-FS profiles. A line or probe can get a zero-valued
316 // discriminator at certain pass which could result in accidentally loading
317 // the corresponding base counter in the non-FS profile, while a non-zero
318 // discriminator would end up getting zero samples. This could in turn undo
319 // the sample distribution effort done by previous BFI maintenance and the
320 // probe distribution factor work for pseudo probes.
321 if (!Reader->profileIsFS())
322 return false;
323
324 Function &Func = MF.getFunction();
325 clearFunctionData(false);
326 Samples = Reader->getSamplesFor(Func);
327 if (!Samples || Samples->empty())
328 return false;
329
331 if (!ProbeManager->profileIsValid(MF.getFunction(), *Samples))
332 return false;
333 } else {
334 if (getFunctionLoc(MF) == 0)
335 return false;
336 }
337
338 DenseSet<GlobalValue::GUID> InlinedGUIDs;
339 bool Changed = computeAndPropagateWeights(MF, InlinedGUIDs);
340
341 // Set the new BPI, BFI.
342 setBranchProbs(MF);
343
344 return Changed;
345}
346
347} // namespace llvm
348
350 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
352 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
353 LowBit = getFSPassBitBegin(P);
354 HighBit = getFSPassBitEnd(P);
355
356 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
357 MIRSampleLoader = std::make_unique<MIRProfileLoader>(
358 FileName, RemappingFileName, std::move(VFS));
359 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
360}
361
362bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
363 if (!MIRSampleLoader->isValid())
364 return false;
365
366 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
367 << MF.getFunction().getName() << "\n");
368 MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
369 auto *MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
370 auto *MPDT =
371 &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
372
373 MF.RenumberBlocks();
374 MDT->updateBlockNumbers();
375 MPDT->updateBlockNumbers();
376
377 MIRSampleLoader->setInitVals(
378 MDT, MPDT, &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI,
379 &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
380
382 (ViewBlockFreqFuncName.empty() ||
384 MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);
385 }
386
387 bool Changed = MIRSampleLoader->runOnFunction(MF);
388 if (Changed)
389 MBFI->calculate(MF, *MBFI->getMBPI(),
390 *&getAnalysis<MachineLoopInfoWrapperPass>().getLI());
391
393 (ViewBlockFreqFuncName.empty() ||
395 MBFI->view("MIR_prof_loader_a." + MF.getName(), false);
396 }
397
398 return Changed;
399}
400
401bool MIRProfileLoaderPass::doInitialization(Module &M) {
402 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
403 << "\n");
404
405 MIRSampleLoader->setFSPass(P);
406 return MIRSampleLoader->doInitialization(M);
407}
408
409void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
410 AU.setPreservesAll();
417}
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
#define LLVM_DEBUG(X)
Definition: Debug.h:101
This file defines the DenseMap class.
This file defines the DenseSet and SmallDenseSet classes.
std::string Name
IRTranslator LLVM IR MI
#define F(x, y, z)
Definition: MD5.cpp:55
Load MIR Sample Profile
static cl::opt< bool > ShowFSBranchProb("show-fs-branchprob", cl::Hidden, cl::init(false), cl::desc("Print setting flow sensitive branch probabilities"))
static cl::opt< bool > ViewBFIAfter("fs-viewbfi-after", cl::Hidden, cl::init(false), cl::desc("View BFI after MIR loader"))
static cl::opt< unsigned > FSProfileDebugBWThreshold("fs-profile-debug-bw-threshold", cl::init(10000), cl::desc("Only show debug message if the source branch weight is greater " " than this value."))
#define DEBUG_TYPE
static cl::opt< bool > ViewBFIBefore("fs-viewbfi-before", cl::Hidden, cl::init(false), cl::desc("View BFI before MIR loader"))
static cl::opt< unsigned > FSProfileDebugProbDiffThreshold("fs-profile-debug-prob-diff-threshold", cl::init(10), cl::desc("Only show debug message if the branch probility is greater than " "this value (in percentage)."))
===- MachineOptimizationRemarkEmitter.h - Opt Diagnostics -*- C++ -*-—===//
#define P(N)
#define INITIALIZE_PASS_DEPENDENCY(depName)
Definition: PassSupport.h:55
#define INITIALIZE_PASS_END(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:57
#define INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis)
Definition: PassSupport.h:52
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file provides the interface for the sampled PGO profile loader base implementation.
This file provides the utility functions for the sampled PGO loader base implementation.
Defines the virtual file system interface vfs::FileSystem.
Represent the analysis usage information of a pass.
AnalysisUsage & addRequired()
void setPreservesAll()
Set by analyses that do not transform their input at all.
AnalysisUsage & addRequiredTransitive()
Debug location.
A debug info location.
Definition: DebugLoc.h:33
Implements a dense probed hash-table based set.
Definition: DenseSet.h:271
Diagnostic information for the sample profiler.
Represents either an error or a value T.
Definition: ErrorOr.h:56
FunctionPass class - This class is used to implement most global optimizations.
Definition: Pass.h:310
Class to represent profile counts.
Definition: Function.h:296
A smart pointer to a reference-counted object that inherits from RefCountedBase or ThreadSafeRefCount...
MIRProfileLoaderPass(std::string FileName="", std::string RemappingFileName="", FSDiscriminatorPass P=FSDiscriminatorPass::Pass1, IntrusiveRefCntPtr< vfs::FileSystem > FS=nullptr)
FS bits will only use the '1' bits in the Mask.
MIRProfileLoader(StringRef Name, StringRef RemapName, IntrusiveRefCntPtr< vfs::FileSystem > FS)
void setBranchProbs(MachineFunction &F)
ErrorOr< uint64_t > getInstWeight(const MachineInstr &MI) override
bool runOnFunction(MachineFunction &F)
MachineBlockFrequencyInfo * BFI
Hold the information of the basic block frequency.
FSDiscriminatorPass P
PassNum is the sequence number this pass is called, start from 1.
bool doInitialization(Module &M)
void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, MachineOptimizationRemarkEmitter *MORE)
void setFSPass(FSDiscriminatorPass Pass)
int getNumber() const
MachineBasicBlocks are uniquely numbered at the function level, unless they're not in a MachineFuncti...
void setSuccProbability(succ_iterator I, BranchProbability Prob)
Set successor probability of a given iterator.
unsigned succ_size() const
SmallVectorImpl< MachineBasicBlock * >::iterator succ_iterator
DebugLoc findBranchDebugLoc()
Find and return the merged DebugLoc of the branch instructions of the block.
iterator_range< succ_iterator > successors()
iterator_range< pred_iterator > predecessors()
MachineBlockFrequencyInfo pass uses BlockFrequencyInfoImpl implementation to estimate machine basic b...
void view(const Twine &Name, bool isSimple=true) const
Pop up a ghostview window with the current block frequency propagation rendered using dot.
const MachineBranchProbabilityInfo * getMBPI() const
void calculate(const MachineFunction &F, const MachineBranchProbabilityInfo &MBPI, const MachineLoopInfo &MLI)
calculate - compute block frequency info for the given function.
BranchProbability getEdgeProbability(const MachineBasicBlock *Src, const MachineBasicBlock *Dst) const
Analysis pass which computes a MachineDominatorTree.
DominatorTree Class - Concrete subclass of DominatorTreeBase that is used to compute a normal dominat...
MachineFunctionPass - This class adapts the FunctionPass interface to allow convenient creation of pa...
void getAnalysisUsage(AnalysisUsage &AU) const override
getAnalysisUsage - Subclasses that override getAnalysisUsage must call this.
StringRef getName() const
getName - Return the name of the corresponding LLVM function.
Function & getFunction()
Return the LLVM function that this machine code represents.
void RenumberBlocks(MachineBasicBlock *MBBFrom=nullptr)
RenumberBlocks - This discards all of the MachineBasicBlock numbers and recomputes them.
Representation of each machine instruction.
Definition: MachineInstr.h:69
Diagnostic information for optimization analysis remarks.
MachinePostDominatorTree - an analysis pass wrapper for DominatorTree used to compute the post-domina...
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
Pass interface - Implemented by all 'passes'.
Definition: Pass.h:94
bool computeAndPropagateWeights(FunctionT &F, const DenseSet< GlobalValue::GUID > &InlinedGUIDs)
Generate branch weight metadata for all branches in F.
void computeDominanceAndLoopInfo(FunctionT &F)
IntrusiveRefCntPtr< vfs::FileSystem > FS
VirtualFileSystem to load profile files from.
EdgeWeightMap EdgeWeights
Map edges to their computed weights.
OptRemarkEmitterT * ORE
Optimization Remark Emitter used to emit diagnostic remarks.
unsigned getFunctionLoc(FunctionT &Func)
Get the line number for the function header.
ErrorOr< uint64_t > getInstWeightImpl(const InstructionT &Inst)
EquivalenceClassMap EquivalenceClass
Equivalence classes for block weights.
std::unique_ptr< SampleProfileReader > Reader
Profile reader object.
DominatorTreePtrT DT
Dominance, post-dominance and loop information.
std::string Filename
Name of the profile file to load.
virtual ErrorOr< uint64_t > getProbeWeight(const InstructionT &Inst)
std::string RemappingFilename
Name of the profile remapping file to load.
FunctionSamples * Samples
Samples collected for the body of this function.
std::pair< const BasicBlockT *, const BasicBlockT * > Edge
void clearFunctionData(bool ResetDT=true)
Clear all the per-function data used to load samples and propagate weights.
BlockWeightMap BlockWeights
Map basic blocks to their computed weights.
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
A range adaptor for a pair of iterators.
static ErrorOr< std::unique_ptr< SampleProfileReader > > create(StringRef Filename, LLVMContext &C, vfs::FileSystem &FS, FSDiscriminatorPass P=FSDiscriminatorPass::Base, StringRef RemapFilename="")
Create a sample profile reader appropriate to the file format.
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
IntrusiveRefCntPtr< FileSystem > getRealFileSystem()
Gets an vfs::FileSystem for the 'real' file system, as seen by the operating system.
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
static unsigned getFSPassBitBegin(sampleprof::FSDiscriminatorPass P)
Definition: Discriminator.h:94
char & MIRProfileLoaderPassID
This pass reads flow sensitive profile.
static unsigned getFSPassBitEnd(sampleprof::FSDiscriminatorPass P)
Definition: Discriminator.h:87
cl::opt< std::string > ViewBlockFreqFuncName("view-bfi-func-name", cl::Hidden, cl::desc("The option to specify " "the name of the function " "whose CFG will be displayed."))
std::optional< PseudoProbe > extractProbe(const Instruction &Inst)
Definition: PseudoProbe.cpp:56
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
cl::opt< GVDAGType > ViewBlockLayoutWithBFI("view-block-layout-with-bfi", cl::Hidden, cl::desc("Pop up a window to show a dag displaying MBP layout and associated " "block frequencies of the CFG."), cl::values(clEnumValN(GVDT_None, "none", "do not display graphs."), clEnumValN(GVDT_Fraction, "fraction", "display a graph using the " "fractional block frequency representation."), clEnumValN(GVDT_Integer, "integer", "display a graph using the raw " "integer fractional block frequency representation."), clEnumValN(GVDT_Count, "count", "display a graph using the real " "profile count if available.")))
cl::opt< bool > ImprovedFSDiscriminator("improved-fs-discriminator", cl::Hidden, cl::init(false), cl::desc("New FS discriminators encoding (incompatible with the original " "encoding)"))
FunctionPass * createMIRProfileLoaderPass(std::string File, std::string RemappingFile, sampleprof::FSDiscriminatorPass P, IntrusiveRefCntPtr< vfs::FileSystem > FS)
Read Flow Sensitive Profile.
OutputIt move(R &&Range, OutputIt Out)
Provide wrappers to std::move which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1849
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define MORE()
Definition: regcomp.c:252
uint32_t Discriminator
Definition: PseudoProbe.h:121
static PredRangeT getPredecessors(MachineBasicBlock *BB)
static SuccRangeT getSuccessors(MachineBasicBlock *BB)
static const MachineBasicBlock * getEntryBB(const MachineFunction *F)
static Function & getFunction(MachineFunction &F)