23#define DEBUG_TYPE "ctx-instr-lower"
28 "A function name, assumed to be global, which will be treated as the "
29 "root of an interesting graph, which will be profiled independently "
30 "from other similar graphs."));
39static auto StartCtx =
"__llvm_ctx_profile_start_context";
40static auto ReleaseCtx =
"__llvm_ctx_profile_release_context";
41static auto GetCtx =
"__llvm_ctx_profile_get_context";
48class CtxInstrumentationLowerer final {
51 Type *ContextNodeTy =
nullptr;
52 Type *ContextRootTy =
nullptr;
71std::pair<uint32_t, uint32_t> getNrCountersAndCallsites(
const Function &
F) {
74 for (
const auto &BB :
F) {
75 for (
const auto &
I : BB) {
76 if (
const auto *Incr = dyn_cast<InstrProfIncrementInst>(&
I)) {
78 static_cast<uint32_t>(Incr->getNumCounters()->getZExtValue());
79 assert((!NrCounters || V == NrCounters) &&
80 "expected all llvm.instrprof.increment[.step] intrinsics to "
81 "have the same total nr of counters parameter");
83 }
else if (
const auto *CSIntr = dyn_cast<InstrProfCallsite>(&
I)) {
85 static_cast<uint32_t>(CSIntr->getNumCounters()->getZExtValue());
86 assert((!NrCallsites || V == NrCallsites) &&
87 "expected all llvm.instrprof.callsite intrinsics to have the "
88 "same total nr of callsites parameter");
92 if (NrCounters && NrCallsites)
93 return std::make_pair(NrCounters, NrCallsites);
97 return {NrCounters, NrCallsites};
104CtxInstrumentationLowerer::CtxInstrumentationLowerer(
Module &M,
107 auto *
PointerTy = PointerType::get(
M.getContext(), 0);
131 if (
const auto *
F =
M.getFunction(Fname)) {
132 if (
F->isDeclaration())
134 auto *
G =
M.getOrInsertGlobal(Fname +
"_ctx_root", ContextRootTy);
135 cast<GlobalVariable>(
G)->setInitializer(
137 ContextRootMap.insert(std::make_pair(
F,
G));
138 for (
const auto &BB : *
F)
139 for (
const auto &
I : BB)
140 if (
const auto *CB = dyn_cast<CallBase>(&
I))
141 if (CB->isMustTailCall()) {
142 M.getContext().emitError(
143 "The function " + Fname +
144 " was indicated as a context root, but it features musttail "
145 "calls, which is not supported.");
152 M.getOrInsertFunction(
154 FunctionType::get(ContextNodeTy->getPointerTo(),
155 {ContextRootTy->getPointerTo(),
162 FunctionType::get(ContextNodeTy->getPointerTo(),
170 M.getOrInsertFunction(
174 ContextRootTy->getPointerTo(),
183 CallsiteInfoTLS->setThreadLocal(
true);
194 CtxInstrumentationLowerer Lowerer(M,
MAM);
195 bool Changed =
false;
197 Changed |= Lowerer.lowerFunction(
F);
201bool CtxInstrumentationLowerer::lowerFunction(
Function &
F) {
202 if (
F.isDeclaration())
208 auto [NrCounters, NrCallsites] = getNrCountersAndCallsites(
F);
211 Value *RealContext =
nullptr;
214 Value *TheRootContext =
nullptr;
215 Value *ExpectedCalleeTLSAddr =
nullptr;
216 Value *CallsiteInfoTLSAddr =
nullptr;
218 auto &Head =
F.getEntryBlock();
219 for (
auto &
I : Head) {
221 if (
auto *Mark = dyn_cast<InstrProfIncrementInst>(&
I)) {
222 assert(Mark->getIndex()->isZero());
226 Guid = Builder.getInt64(
F.getGUID());
232 {ContextNodeTy, ArrayType::get(Builder.getInt64Ty(), NrCounters),
233 ArrayType::get(Builder.getPtrTy(), NrCallsites)});
239 auto Iter = ContextRootMap.find(&
F);
240 if (Iter != ContextRootMap.end()) {
241 TheRootContext = Iter->second;
242 Context = Builder.CreateCall(StartCtx, {TheRootContext,
Guid,
243 Builder.getInt32(NrCounters),
244 Builder.getInt32(NrCallsites)});
249 Builder.CreateCall(GetCtx, {&
F,
Guid, Builder.getInt32(NrCounters),
250 Builder.getInt32(NrCallsites)});
256 auto *CtxAsInt = Builder.CreatePtrToInt(Context, Builder.getInt64Ty());
257 if (NrCallsites > 0) {
260 auto *
Index = Builder.CreateAnd(CtxAsInt, Builder.getInt64(1));
262 ExpectedCalleeTLSAddr = Builder.CreateGEP(
263 Builder.getInt8Ty()->getPointerTo(),
264 Builder.CreateThreadLocalAddress(ExpectedCalleeTLS), {
Index});
265 CallsiteInfoTLSAddr = Builder.CreateGEP(
266 Builder.getInt32Ty(),
267 Builder.CreateThreadLocalAddress(CallsiteInfoTLS), {
Index});
274 RealContext = Builder.CreateIntToPtr(
275 Builder.CreateAnd(CtxAsInt, Builder.getInt64(-2)),
284 <<
"Function doesn't have instrumentation, skipping";
289 bool ContextWasReleased =
false;
292 if (
auto *Instr = dyn_cast<InstrProfCntrInstBase>(&
I)) {
294 switch (
Instr->getIntrinsicID()) {
295 case llvm::Intrinsic::instrprof_increment:
296 case llvm::Intrinsic::instrprof_increment_step: {
299 auto *AsStep = cast<InstrProfIncrementInst>(Instr);
300 auto *
GEP = Builder.CreateGEP(
301 ThisContextType, RealContext,
302 {Builder.getInt32(0), Builder.getInt32(1), AsStep->getIndex()});
304 Builder.CreateAdd(Builder.CreateLoad(Builder.getInt64Ty(),
GEP),
308 case llvm::Intrinsic::instrprof_callsite:
312 auto *CSIntrinsic = dyn_cast<InstrProfCallsite>(Instr);
313 Builder.CreateStore(CSIntrinsic->getCallee(), ExpectedCalleeTLSAddr,
325 Builder.CreateGEP(ThisContextType, Context,
326 {Builder.getInt32(0), Builder.getInt32(2),
327 CSIntrinsic->getIndex()}),
328 CallsiteInfoTLSAddr,
true);
332 }
else if (TheRootContext && isa<ReturnInst>(
I)) {
335 Builder.CreateCall(ReleaseCtx, {TheRootContext});
336 ContextWasReleased =
true;
344 if (TheRootContext && !ContextWasReleased)
345 F.getContext().emitError(
346 "[ctx_prof] An entrypoint was instrumented but it has no `ret` "
347 "instructions above which to release the context: " +
static cl::list< std::string > ContextRoots("profile-context-root", cl::Hidden, cl::desc("A function name, assumed to be global, which will be treated as the " "root of an interesting graph, which will be profiled independently " "from other similar graphs."))
FunctionAnalysisManager FAM
ModuleAnalysisManager MAM
This header defines various interfaces for pass management in LLVM.
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
A container for analyses that lazily runs them and caches their results.
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
@ HiddenVisibility
The GV is hidden.
@ ExternalLinkage
Externally visible function.
This provides a uniform API for creating instructions and inserting them into a basic block: either at the end of a BasicBlock, or at a specific iterator location in a block.
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR unit.
A Module instance is used to store all the information related to an LLVM module.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM)
static bool isContextualIRPGOEnabled()
A set of analyses that are preserved following a run of a transformation pass.
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Class to represent struct types.
static StructType * get(LLVMContext &Context, ArrayRef< Type * > Elements, bool isPacked=false)
This static method is the primary way to create a literal StructType.
The instances of the Type class are immutable: once they are created, they are never changed.
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static IntegerType * getInt64Ty(LLVMContext &C)
LLVM Value Representation.
Pass manager infrastructure for declaring and invalidating analyses.
static auto ExpectedCalleeTLS
NodeAddr< InstrNode * > Instr
This is an optimization pass for GlobalISel generic memory operations.
iterator_range< early_inc_iterator_impl< detail::IterOfRange< RangeT > > > make_early_inc_range(RangeT &&Range)
Make a range that does early increment to allow mutation of the underlying range without disrupting iteration.