#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

#define DEBUG_TYPE "openmp-opt"

static cl::opt<bool> DisableOpenMPOptimizations(
    "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> EnableParallelRegionMerging(
    "openmp-opt-enable-merging",
    cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
    cl::init(false));

static cl::opt<bool>
    DisableInternalization("openmp-opt-disable-internalization",
                           cl::desc("Disable function internalization."),
                           cl::Hidden, cl::init(false));

static cl::opt<bool> HideMemoryTransferLatency(
    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"
             " transfers"),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptDeglobalization(
    "openmp-opt-disable-deglobalization",
    cl::desc("Disable OpenMP optimizations involving deglobalization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptSPMDization(
    "openmp-opt-disable-spmdization",
    cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptFolding(
    "openmp-opt-disable-folding",
    cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
    cl::init(false));

static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
    "openmp-opt-disable-state-machine-rewrite",
    cl::desc("Disable OpenMP optimizations that replace the state machine."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> DisableOpenMPOptBarrierElimination(
    "openmp-opt-disable-barrier-elimination",
    cl::desc("Disable OpenMP optimizations that eliminate barriers."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> PrintModuleAfterOptimizations(
    "openmp-opt-print-module-after",
    cl::desc("Print the current module after OpenMP optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> PrintModuleBeforeOptimizations(
    "openmp-opt-print-module-before",
    cl::desc("Print the current module before OpenMP optimizations."),
    cl::Hidden, cl::init(false));

static cl::opt<bool> AlwaysInlineDeviceFunctions(
    "openmp-opt-inline-device",
    cl::desc("Inline all applicable functions on the device."), cl::Hidden,
    cl::init(false));

static cl::opt<unsigned>
    SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
                          cl::desc("Maximal number of attributor iterations."),
                          cl::init(256));

static cl::opt<unsigned>
    SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
                      cl::desc("Maximum amount of shared memory to use."),
                      cl::init(std::numeric_limits<unsigned>::max()));
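// A usage sketch (assuming the standard LLVM option plumbing): each flag above
// is a regular cl::opt, so it can be passed to `opt` directly, e.g.
//   opt -passes=openmp-opt -openmp-opt-disable-spmdization input.ll
// or forwarded from clang with `-mllvm -openmp-opt-print-module-after`.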
152 "Number of OpenMP runtime calls deduplicated");
154 "Number of OpenMP parallel regions deleted");
156 "Number of OpenMP runtime functions identified");
158 "Number of OpenMP runtime function uses identified");
160 "Number of OpenMP target region entry points (=kernels) identified");
162 "Number of non-OpenMP target region kernels identified");
164 "Number of OpenMP target region entry points (=kernels) executed in "
165 "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167 "Number of OpenMP target region entry points (=kernels) executed in "
168 "generic-mode without a state machines");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170 "Number of OpenMP target region entry points (=kernels) executed in "
171 "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173 "Number of OpenMP target region entry points (=kernels) executed in "
174 "generic-mode with customized state machines without fallback");
176 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
179 "Number of OpenMP parallel regions merged");
181 "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated,
"Number of redundant barriers eliminated");
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }

#undef KERNEL_ENVIRONMENT_GETTER

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
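// For illustration, an instantiation such as
//   KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
// expands to
//   ConstantStruct *
//   getConfigurationFromKernelEnvironment(ConstantStruct *KernelEnvC) {
//     return cast<ConstantStruct>(
//         KernelEnvC->getAggregateElement(ConfigurationIdx));
//   }
// which is the accessor the CONFIGURATION getters above build on.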
261 constexpr const int InitKernelEnvironmentArgNo = 0;
262 return cast<GlobalVariable>(
276struct AAHeapToShared;
287 OpenMPPostLink(OpenMPPostLink) {
290 OMPBuilder.initialize();
291 initializeRuntimeFunctions(M);
292 initializeInternalControlVars();
  struct InternalControlVarInfo {

  struct RuntimeFunctionInfo {

    void clearUsesMap() { UsesMap.clear(); }

    operator bool() const { return Declaration; }

    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    size_t getNumArgs() const { return ArgumentTypes.size(); }

      UseVector &UV = getOrCreateUseVector(F);

      while (!ToBeDeleted.empty()) {

    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;
438 void initializeInternalControlVars() {
439#define ICV_RT_SET(_Name, RTL) \
441 auto &ICV = ICVs[_Name]; \
444#define ICV_RT_GET(Name, RTL) \
446 auto &ICV = ICVs[Name]; \
449#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
451 auto &ICV = ICVs[Enum]; \
454 ICV.InitKind = Init; \
455 ICV.EnvVarName = _EnvVarName; \
456 switch (ICV.InitKind) { \
457 case ICV_IMPLEMENTATION_DEFINED: \
458 ICV.InitValue = nullptr; \
461 ICV.InitValue = ConstantInt::get( \
462 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
465 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
471#include "llvm/Frontend/OpenMP/OMPKinds.def"
477 static bool declMatchesRTFTypes(
Function *
F,
Type *RTFRetType,
484 if (
F->getReturnType() != RTFRetType)
486 if (
F->arg_size() != RTFArgTypes.
size())
489 auto *RTFTyIt = RTFArgTypes.
begin();
491 if (Arg.getType() != *RTFTyIt)
501 unsigned collectUses(RuntimeFunctionInfo &RFI,
bool CollectStats =
true) {
502 unsigned NumUses = 0;
503 if (!RFI.Declaration)
508 NumOpenMPRuntimeFunctionsIdentified += 1;
509 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
513 for (
Use &U : RFI.Declaration->uses()) {
514 if (
Instruction *UserI = dyn_cast<Instruction>(
U.getUser())) {
515 if (!
CGSCC ||
CGSCC->empty() ||
CGSCC->contains(UserI->getFunction())) {
516 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
520 RFI.getOrCreateUseVector(
nullptr).push_back(&U);
529 auto &RFI = RFIs[RTF];
531 collectUses(RFI,
false);
535 void recollectUses() {
536 for (
int Idx = 0;
Idx < RFIs.size(); ++
Idx)
556 RuntimeFunctionInfo &RFI = RFIs[Fn];
558 if (RFI.Declaration && RFI.Declaration->isDeclaration())
566 void initializeRuntimeFunctions(
Module &M) {
569#define OMP_TYPE(VarName, ...) \
570 Type *VarName = OMPBuilder.VarName; \
573#define OMP_ARRAY_TYPE(VarName, ...) \
574 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
576 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
577 (void)VarName##PtrTy;
579#define OMP_FUNCTION_TYPE(VarName, ...) \
580 FunctionType *VarName = OMPBuilder.VarName; \
582 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
585#define OMP_STRUCT_TYPE(VarName, ...) \
586 StructType *VarName = OMPBuilder.VarName; \
588 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
591#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
593 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
594 Function *F = M.getFunction(_Name); \
595 RTLFunctions.insert(F); \
596 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
597 RuntimeFunctionIDMap[F] = _Enum; \
598 auto &RFI = RFIs[_Enum]; \
601 RFI.IsVarArg = _IsVarArg; \
602 RFI.ReturnType = OMPBuilder._ReturnType; \
603 RFI.ArgumentTypes = std::move(ArgsTypes); \
604 RFI.Declaration = F; \
605 unsigned NumUses = collectUses(RFI); \
608 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
610 if (RFI.Declaration) \
611 dbgs() << TAG << "-> got " << NumUses << " uses in " \
612 << RFI.getNumFunctionsWithUses() \
613 << " different functions.\n"; \
617#include "llvm/Frontend/OpenMP/OMPKinds.def"
    for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
      if (F.hasFnAttribute(Attribute::NoInline) &&
          F.getName().starts_with(Prefix) &&
          !F.hasFnAttribute(Attribute::OptimizeNone))
        F.removeFnAttr(Attribute::NoInline);
638 bool OpenMPPostLink =
false;
template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

  /// A set to keep track of elements.
  SetVector<Ty> Set;

  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
  /// Flag to track if we are in a fixpoint state.
  bool IsAtFixpoint = false;

  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  BooleanStateWithPtrSetVector<Instruction, /* InsertInvalidates */ false>
      SPMDCompatibilityTracker;

  bool IsKernelEntry = false;

  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  bool NestedParallelism = false;
  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state.
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }
  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates "
                         "OpenMP-Opt assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates "
                         "OpenMP-Opt assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates "
                         "OpenMP-Opt assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }
845 OffloadArray() =
default;
852 if (!
Array.getAllocatedType()->isArrayTy())
855 if (!getValues(Array,
Before))
858 this->Array = &
Array;
  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;
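  // These constants name the operand positions of the
  // __tgt_target_data_begin_mapper call whose offload arrays are inspected in
  // getValuesInOffloadArrays below; DeviceIDArgNum is reused when the call is
  // split into an issue/wait pair in splitTargetDataBeginRTC.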
873 const uint64_t NumValues =
Array.getAllocatedType()->getArrayNumElements();
874 StoredValues.
assign(NumValues,
nullptr);
875 LastAccesses.
assign(NumValues,
nullptr);
880 if (BB !=
Before.getParent())
890 if (!isa<StoreInst>(&
I))
893 auto *S = cast<StoreInst>(&
I);
900 LastAccesses[
Idx] = S;
910 const unsigned NumValues = StoredValues.
size();
911 for (
unsigned I = 0;
I < NumValues; ++
I) {
912 if (!StoredValues[
I] || !LastAccesses[
I])
922 using OptimizationRemarkGetter =
926 OptimizationRemarkGetter OREGetter,
927 OMPInformationCache &OMPInfoCache,
Attributor &A)
929 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache),
A(
A) {}
932 bool remarksEnabled() {
933 auto &Ctx =
M.getContext();
934 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(
DEBUG_TYPE);
938 bool run(
bool IsModulePass) {
942 bool Changed =
false;
948 Changed |= runAttributor(IsModulePass);
951 OMPInfoCache.recollectUses();
954 Changed |= rewriteDeviceCodeStateMachine();
956 if (remarksEnabled())
957 analysisGlobalization();
964 Changed |= runAttributor(IsModulePass);
967 OMPInfoCache.recollectUses();
969 Changed |= deleteParallelRegions();
972 Changed |= hideMemTransfersLatency();
973 Changed |= deduplicateRuntimeCalls();
975 if (mergeParallelRegions()) {
976 deduplicateRuntimeCalls();
982 if (OMPInfoCache.OpenMPPostLink)
983 Changed |= removeRuntimeSymbols();
990 void printICVs()
const {
995 for (
auto ICV : ICVs) {
996 auto ICVInfo = OMPInfoCache.ICVs[ICV];
998 return ORA <<
"OpenMP ICV " <<
ore::NV(
"OpenMPICV", ICVInfo.Name)
1000 << (ICVInfo.InitValue
1001 ?
toString(ICVInfo.InitValue->getValue(), 10,
true)
1002 :
"IMPLEMENTATION_DEFINED");
1005 emitRemark<OptimizationRemarkAnalysis>(
F,
"OpenMPICVTracker",
Remark);
1011 void printKernels()
const {
1017 return ORA <<
"OpenMP GPU kernel "
1018 <<
ore::NV(
"OpenMPGPUKernel",
F->getName()) <<
"\n";
1021 emitRemark<OptimizationRemarkAnalysis>(
F,
"OpenMPGPU",
Remark);
1027 static CallInst *getCallIfRegularCall(
1028 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI =
nullptr) {
1029 CallInst *CI = dyn_cast<CallInst>(
U.getUser());
1039 static CallInst *getCallIfRegularCall(
1040 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI =
nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(&V);
1051 bool mergeParallelRegions() {
1052 const unsigned CallbackCalleeOperand = 2;
1053 const unsigned CallbackFirstArgOperand = 3;
1057 OMPInformationCache::RuntimeFunctionInfo &RFI =
1058 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1060 if (!RFI.Declaration)
1064 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1069 bool Changed =
false;
1075 BasicBlock *StartBB =
nullptr, *EndBB =
nullptr;
1076 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1079 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080 assert(StartBB !=
nullptr &&
"StartBB should not be null");
1082 assert(EndBB !=
nullptr &&
"EndBB should not be null");
1083 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1086 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
Value &,
1087 Value &Inner,
Value *&ReplacementValue) -> InsertPointTy {
1088 ReplacementValue = &Inner;
1092 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1096 auto CreateSequentialRegion = [&](
Function *OuterFn,
1104 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1108 SplitBlock(ParentBB, SeqStartI, DT, LI,
nullptr,
"seq.par.merged");
1111 "Expected a different CFG");
1115 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1118 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119 assert(SeqStartBB !=
nullptr &&
"SeqStartBB should not be null");
1121 assert(SeqEndBB !=
nullptr &&
"SeqEndBB should not be null");
1124 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1130 for (
User *Usr :
I.users()) {
1138 OutsideUsers.
insert(&UsrI);
1141 if (OutsideUsers.
empty())
1148 I.getType(),
DL.getAllocaAddrSpace(),
nullptr,
1149 I.getName() +
".seq.output.alloc", &OuterFn->
front().
front());
1153 new StoreInst(&
I, AllocaI, SeqStartBB->getTerminator());
1159 I.getType(), AllocaI,
I.getName() +
".seq.output.load", UsrI);
1165 InsertPointTy(ParentBB, ParentBB->
end()),
DL);
1166 InsertPointTy SeqAfterIP =
1167 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1169 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1188 assert(MergableCIs.
size() > 1 &&
"Assumed multiple mergable CIs");
1191 OR <<
"Parallel region merged with parallel region"
1192 << (MergableCIs.
size() > 2 ?
"s" :
"") <<
" at ";
1195 if (CI != MergableCIs.
back())
1201 emitRemark<OptimizationRemark>(MergableCIs.
front(),
"OMP150",
Remark);
1205 <<
" parallel regions in " << OriginalFn->
getName()
1209 EndBB =
SplitBlock(BB, MergableCIs.
back()->getNextNode(), DT, LI);
1211 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1215 assert(BB->getUniqueSuccessor() == StartBB &&
"Expected a different CFG");
1216 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1221 for (
auto *It = MergableCIs.
begin(), *
End = MergableCIs.
end() - 1;
1230 CreateSequentialRegion(OriginalFn, BB, ForkCI->
getNextNode(),
1241 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1242 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
nullptr,
nullptr,
1243 OMP_PROC_BIND_default,
false);
1247 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1254 for (
auto *CI : MergableCIs) {
1256 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1260 for (
unsigned U = CallbackFirstArgOperand,
E = CI->
arg_size(); U <
E;
1269 for (
unsigned U = CallbackFirstArgOperand,
E = CI->
arg_size(); U <
E;
1273 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1276 if (CI != MergableCIs.back()) {
1279 OMPInfoCache.OMPBuilder.createBarrier(
1288 assert(OutlinedFn != OriginalFn &&
"Outlining failed");
1289 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1290 CGUpdater.reanalyzeFunction(*OriginalFn);
1292 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1300 CallInst *CI = getCallIfRegularCall(U, &RFI);
1307 RFI.foreachUse(SCC, DetectPRsCB);
1313 for (
auto &It : BB2PRMap) {
1314 auto &CIs = It.getSecond();
1329 auto IsMergable = [&](
Instruction &
I,
bool IsBeforeMergableRegion) {
1332 if (
I.isTerminator())
1335 if (!isa<CallInst>(&
I))
1339 if (IsBeforeMergableRegion) {
1341 if (!CalledFunction)
1348 for (
const auto &RFI : UnmergableCallsInfo) {
1349 if (CalledFunction == RFI.Declaration)
1357 if (!isa<IntrinsicInst>(CI))
1368 if (CIs.count(&
I)) {
1374 if (IsMergable(
I, MergableCIs.
empty()))
1379 for (; It !=
End; ++It) {
1381 if (CIs.count(&SkipI)) {
1383 <<
" due to " <<
I <<
"\n");
1390 if (MergableCIs.
size() > 1) {
1391 MergableCIsVector.
push_back(MergableCIs);
1393 <<
" parallel regions in block " << BB->
getName()
1398 MergableCIs.
clear();
1401 if (!MergableCIsVector.
empty()) {
1404 for (
auto &MergableCIs : MergableCIsVector)
1405 Merge(MergableCIs, BB);
1406 MergableCIsVector.clear();
1413 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1414 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1415 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1416 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1423 bool deleteParallelRegions() {
1424 const unsigned CallbackCalleeOperand = 2;
1426 OMPInformationCache::RuntimeFunctionInfo &RFI =
1427 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1429 if (!RFI.Declaration)
1432 bool Changed =
false;
1434 CallInst *CI = getCallIfRegularCall(U);
1437 auto *Fn = dyn_cast<Function>(
1441 if (!Fn->onlyReadsMemory())
1443 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1450 return OR <<
"Removing parallel region with no side-effects.";
1452 emitRemark<OptimizationRemark>(CI,
"OMP160",
Remark);
1454 CGUpdater.removeCallSite(*CI);
1457 ++NumOpenMPParallelRegionsDeleted;
1461 RFI.foreachUse(SCC, DeleteCallCB);
1467 bool deduplicateRuntimeCalls() {
1468 bool Changed =
false;
1471 OMPRTL_omp_get_num_threads,
1472 OMPRTL_omp_in_parallel,
1473 OMPRTL_omp_get_cancellation,
1474 OMPRTL_omp_get_thread_limit,
1475 OMPRTL_omp_get_supported_active_levels,
1476 OMPRTL_omp_get_level,
1477 OMPRTL_omp_get_ancestor_thread_num,
1478 OMPRTL_omp_get_team_size,
1479 OMPRTL_omp_get_active_level,
1480 OMPRTL_omp_in_final,
1481 OMPRTL_omp_get_proc_bind,
1482 OMPRTL_omp_get_num_places,
1483 OMPRTL_omp_get_num_procs,
1484 OMPRTL_omp_get_place_num,
1485 OMPRTL_omp_get_partition_num_places,
1486 OMPRTL_omp_get_partition_place_nums};
1490 collectGlobalThreadIdArguments(GTIdArgs);
1492 <<
" global thread ID arguments\n");
1495 for (
auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1496 Changed |= deduplicateRuntimeCalls(
1497 *
F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1501 Value *GTIdArg =
nullptr;
1503 if (GTIdArgs.
count(&Arg)) {
1507 Changed |= deduplicateRuntimeCalls(
1508 *
F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1515 bool removeRuntimeSymbols() {
1521 if (!
GV->getType()->isPointerTy())
1529 GlobalVariable *Client = dyn_cast<GlobalVariable>(
C->stripPointerCasts());
1538 GV->eraseFromParent();
1551 bool hideMemTransfersLatency() {
1552 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1553 bool Changed =
false;
1555 auto *RTCall = getCallIfRegularCall(U, &RFI);
1559 OffloadArray OffloadArrays[3];
1560 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1563 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1566 bool WasSplit =
false;
1567 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1568 if (WaitMovementPoint)
1569 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1571 Changed |= WasSplit;
1574 if (OMPInfoCache.runtimeFnsAvailable(
1575 {OMPRTL___tgt_target_data_begin_mapper_issue,
1576 OMPRTL___tgt_target_data_begin_mapper_wait}))
1577 RFI.foreachUse(SCC, SplitMemTransfers);
1582 void analysisGlobalization() {
1583 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1585 auto CheckGlobalization = [&](
Use &
U,
Function &Decl) {
1586 if (
CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1589 <<
"Found thread data sharing on the GPU. "
1590 <<
"Expect degraded performance due to data globalization.";
1592 emitRemark<OptimizationRemarkMissed>(CI,
"OMP112",
Remark);
1598 RFI.foreachUse(SCC, CheckGlobalization);
1603 bool getValuesInOffloadArrays(
CallInst &RuntimeCall,
1605 assert(OAs.
size() == 3 &&
"Need space for three offload arrays!");
1615 Value *BasePtrsArg =
1624 if (!isa<AllocaInst>(V))
1626 auto *BasePtrsArray = cast<AllocaInst>(V);
1627 if (!OAs[0].
initialize(*BasePtrsArray, RuntimeCall))
1632 if (!isa<AllocaInst>(V))
1634 auto *PtrsArray = cast<AllocaInst>(V);
1635 if (!OAs[1].
initialize(*PtrsArray, RuntimeCall))
1641 if (isa<GlobalValue>(V))
1642 return isa<Constant>(V);
1643 if (!isa<AllocaInst>(V))
1646 auto *SizesArray = cast<AllocaInst>(V);
1647 if (!OAs[2].
initialize(*SizesArray, RuntimeCall))
1658 assert(OAs.
size() == 3 &&
"There are three offload arrays to debug!");
1661 std::string ValuesStr;
1663 std::string Separator =
" --- ";
1665 for (
auto *BP : OAs[0].StoredValues) {
1672 for (
auto *
P : OAs[1].StoredValues) {
1679 for (
auto *S : OAs[2].StoredValues) {
1693 bool IsWorthIt =
false;
1716 bool splitTargetDataBeginRTC(
CallInst &RuntimeCall,
1721 auto &
IRBuilder = OMPInfoCache.OMPBuilder;
1725 Entry.getFirstNonPHIOrDbgOrAlloca());
1727 IRBuilder.AsyncInfo,
nullptr,
"handle");
1735 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1739 for (
auto &Arg : RuntimeCall.
args())
1740 Args.push_back(Arg.get());
1741 Args.push_back(Handle);
1745 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1751 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1753 Value *WaitParams[2] = {
1755 OffloadArray::DeviceIDArgNum),
1759 WaitDecl, WaitParams,
"", &WaitMovementPoint);
1760 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1765 static Value *combinedIdentStruct(
Value *CurrentIdent,
Value *NextIdent,
1766 bool GlobalOnly,
bool &SingleChoice) {
1767 if (CurrentIdent == NextIdent)
1768 return CurrentIdent;
1772 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773 SingleChoice = !CurrentIdent;
1785 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 bool SingleChoice =
true;
1788 Value *Ident =
nullptr;
1790 CallInst *CI = getCallIfRegularCall(U, &RFI);
1791 if (!CI || &
F != &Caller)
1794 true, SingleChoice);
1797 RFI.foreachUse(SCC, CombineIdentStruct);
1799 if (!Ident || !SingleChoice) {
1802 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 &
F.getEntryBlock(),
F.getEntryBlock().begin()));
1809 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1817 bool deduplicateRuntimeCalls(
Function &
F,
1818 OMPInformationCache::RuntimeFunctionInfo &RFI,
1819 Value *ReplVal =
nullptr) {
1820 auto *UV = RFI.getUseVector(
F);
1821 if (!UV || UV->size() + (ReplVal !=
nullptr) < 2)
1825 dbgs() <<
TAG <<
"Deduplicate " << UV->size() <<
" uses of " << RFI.Name
1826 << (ReplVal ?
" with an existing value\n" :
"\n") <<
"\n");
1828 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829 cast<Argument>(ReplVal)->
getParent() == &
F)) &&
1830 "Unexpected replacement value!");
1833 auto CanBeMoved = [
this](
CallBase &CB) {
1834 unsigned NumArgs = CB.arg_size();
1837 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 for (
unsigned U = 1;
U < NumArgs; ++
U)
1840 if (isa<Instruction>(CB.getArgOperand(U)))
1851 for (
Use *U : *UV) {
1852 if (
CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1857 if (!CanBeMoved(*CI))
1865 assert(IP &&
"Expected insertion point!");
1866 cast<Instruction>(ReplVal)->moveBefore(IP);
1872 if (
CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1875 Value *Ident = getCombinedIdentFromCallUsesIn(RFI,
F,
1881 bool Changed =
false;
1883 CallInst *CI = getCallIfRegularCall(U, &RFI);
1884 if (!CI || CI == ReplVal || &
F != &Caller)
1889 return OR <<
"OpenMP runtime call "
1890 <<
ore::NV(
"OpenMPOptRuntime", RFI.Name) <<
" deduplicated.";
1893 emitRemark<OptimizationRemark>(CI,
"OMP170",
Remark);
1895 emitRemark<OptimizationRemark>(&
F,
"OMP170",
Remark);
1897 CGUpdater.removeCallSite(*CI);
1900 ++NumOpenMPRuntimeCallsDeduplicated;
1904 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1918 if (!
F.hasLocalLinkage())
1920 for (
Use &U :
F.uses()) {
1921 if (
CallInst *CI = getCallIfRegularCall(U)) {
1923 if (CI == &RefCI || GTIdArgs.
count(ArgOp) ||
1924 getCallIfRegularCall(
1925 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
    auto AddUserArgs = [&](Value &GTId) {
1935 for (
Use &U : GTId.uses())
1936 if (
CallInst *CI = dyn_cast<CallInst>(
U.getUser()))
1939 if (CallArgOpIsGTId(*Callee,
U.getOperandNo(), *CI))
1944 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1945 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1947 GlobThreadNumRFI.foreachUse(SCC, [&](
Use &U,
Function &
F) {
1948 if (
CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1956 for (
unsigned U = 0;
U < GTIdArgs.
size(); ++
U)
1957 AddUserArgs(*GTIdArgs[U]);
1972 return getUniqueKernelFor(*
I.getFunction());
1977 bool rewriteDeviceCodeStateMachine();
1993 template <
typename RemarkKind,
typename RemarkCallBack>
1995 RemarkCallBack &&RemarkCB)
const {
1997 auto &ORE = OREGetter(
F);
2001 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I))
2002 <<
" [" << RemarkName <<
"]";
2006 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I)); });
2010 template <
typename RemarkKind,
typename RemarkCallBack>
2012 RemarkCallBack &&RemarkCB)
const {
2013 auto &ORE = OREGetter(
F);
2017 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F))
2018 <<
" [" << RemarkName <<
"]";
2022 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F)); });
2036 OptimizationRemarkGetter OREGetter;
2039 OMPInformationCache &OMPInfoCache;
2045 bool runAttributor(
bool IsModulePass) {
2049 registerAAs(IsModulePass);
2054 <<
" functions, result: " << Changed <<
".\n");
2056 return Changed == ChangeStatus::CHANGED;
2063 void registerAAs(
bool IsModulePass);
2072 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2073 !OMPInfoCache.CGSCC->contains(&
F))
2078 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&
F];
2080 return *CachedKernel;
2087 return *CachedKernel;
2090 CachedKernel =
nullptr;
2091 if (!
F.hasLocalLinkage()) {
2095 return ORA <<
"Potentially unknown OpenMP target region caller.";
2097 emitRemark<OptimizationRemarkAnalysis>(&
F,
"OMP100",
Remark);
2103 auto GetUniqueKernelForUse = [&](
const Use &
U) ->
Kernel {
2104 if (
auto *Cmp = dyn_cast<ICmpInst>(
U.getUser())) {
2106 if (
Cmp->isEquality())
2107 return getUniqueKernelFor(*Cmp);
2110 if (
auto *CB = dyn_cast<CallBase>(
U.getUser())) {
2112 if (CB->isCallee(&U))
2113 return getUniqueKernelFor(*CB);
2115 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2116 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2118 if (OpenMPOpt::getCallIfRegularCall(*
U.getUser(), &KernelParallelRFI))
2119 return getUniqueKernelFor(*CB);
2128 OMPInformationCache::foreachUse(
F, [&](
const Use &U) {
2129 PotentialKernels.
insert(GetUniqueKernelForUse(U));
2133 if (PotentialKernels.
size() == 1)
2134 K = *PotentialKernels.
begin();
2137 UniqueKernelMap[&
F] =
K;
2142bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2143 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2144 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2146 bool Changed =
false;
2147 if (!KernelParallelRFI)
2158 bool UnknownUse =
false;
2159 bool KernelParallelUse =
false;
2160 unsigned NumDirectCalls = 0;
2163 OMPInformationCache::foreachUse(*
F, [&](
Use &U) {
2164 if (
auto *CB = dyn_cast<CallBase>(
U.getUser()))
2165 if (CB->isCallee(&U)) {
2170 if (isa<ICmpInst>(
U.getUser())) {
2171 ToBeReplacedStateMachineUses.push_back(&U);
2177 OpenMPOpt::getCallIfRegularCall(*
U.getUser(), &KernelParallelRFI);
2178 const unsigned int WrapperFunctionArgNo = 6;
2179 if (!KernelParallelUse && CI &&
2181 KernelParallelUse = true;
2182 ToBeReplacedStateMachineUses.push_back(&U);
2190 if (!KernelParallelUse)
2196 if (UnknownUse || NumDirectCalls != 1 ||
2197 ToBeReplacedStateMachineUses.
size() > 2) {
2199 return ORA <<
"Parallel region is used in "
2200 << (UnknownUse ?
"unknown" :
"unexpected")
2201 <<
" ways. Will not attempt to rewrite the state machine.";
2203 emitRemark<OptimizationRemarkAnalysis>(
F,
"OMP101",
Remark);
2212 return ORA <<
"Parallel region is not called from a unique kernel. "
2213 "Will not attempt to rewrite the state machine.";
2215 emitRemark<OptimizationRemarkAnalysis>(
F,
"OMP102",
Remark);
2231 for (
Use *U : ToBeReplacedStateMachineUses)
2233 ID,
U->get()->getType()));
2235 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2244struct AAICVTracker :
public StateWrapper<BooleanState, AbstractAttribute> {
2249 bool isAssumedTracked()
const {
return getAssumed(); }
2252 bool isKnownTracked()
const {
return getAssumed(); }
2261 return std::nullopt;
2267 virtual std::optional<Value *>
2275 const std::string
getName()
const override {
return "AAICVTracker"; }
2278 const char *getIdAddr()
const override {
return &
ID; }
2285 static const char ID;
2288struct AAICVTrackerFunction :
public AAICVTracker {
2290 : AAICVTracker(IRP,
A) {}
2293 const std::string getAsStr(
Attributor *)
const override {
2294 return "ICVTrackerFunction";
2298 void trackStatistics()
const override {}
2302 return ChangeStatus::UNCHANGED;
2307 InternalControlVar::ICV___last>
2308 ICVReplacementValuesMap;
2315 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2318 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2320 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2322 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2328 if (ValuesMap.insert(std::make_pair(CI, CI->
getArgOperand(0))).second)
2329 HasChanged = ChangeStatus::CHANGED;
2335 std::optional<Value *> ReplVal = getValueForCall(
A,
I, ICV);
2336 if (ReplVal && ValuesMap.insert(std::make_pair(&
I, *ReplVal)).second)
2337 HasChanged = ChangeStatus::CHANGED;
2343 SetterRFI.foreachUse(TrackValues,
F);
2345 bool UsedAssumedInformation =
false;
2346 A.checkForAllInstructions(CallCheck, *
this, {Instruction::Call},
2347 UsedAssumedInformation,
2353 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2354 ValuesMap.insert(std::make_pair(Entry,
nullptr));
2365 const auto *CB = dyn_cast<CallBase>(&
I);
2366 if (!CB || CB->hasFnAttr(
"no_openmp") ||
2367 CB->hasFnAttr(
"no_openmp_routines"))
2368 return std::nullopt;
2370 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2371 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2372 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2373 Function *CalledFunction = CB->getCalledFunction();
2376 if (CalledFunction ==
nullptr)
2378 if (CalledFunction == GetterRFI.Declaration)
2379 return std::nullopt;
2380 if (CalledFunction == SetterRFI.Declaration) {
2381 if (ICVReplacementValuesMap[ICV].
count(&
I))
2382 return ICVReplacementValuesMap[ICV].
lookup(&
I);
2391 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2394 if (ICVTrackingAA->isAssumedTracked()) {
2395 std::optional<Value *> URV =
2396 ICVTrackingAA->getUniqueReplacementValue(ICV);
2407 std::optional<Value *>
2409 return std::nullopt;
2416 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2417 if (ValuesMap.count(
I))
2418 return ValuesMap.lookup(
I);
2424 std::optional<Value *> ReplVal;
2426 while (!Worklist.
empty()) {
2428 if (!Visited.
insert(CurrInst).second)
2436 if (ValuesMap.count(CurrInst)) {
2437 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2440 ReplVal = NewReplVal;
2446 if (ReplVal != NewReplVal)
2452 std::optional<Value *> NewReplVal = getValueForCall(
A, *CurrInst, ICV);
2458 ReplVal = NewReplVal;
2464 if (ReplVal != NewReplVal)
2469 if (CurrBB ==
I->getParent() && ReplVal)
2474 if (
const Instruction *Terminator = Pred->getTerminator())
2482struct AAICVTrackerFunctionReturned : AAICVTracker {
2484 : AAICVTracker(IRP,
A) {}
2487 const std::string getAsStr(
Attributor *)
const override {
2488 return "ICVTrackerFunctionReturned";
2492 void trackStatistics()
const override {}
2496 return ChangeStatus::UNCHANGED;
2501 InternalControlVar::ICV___last>
2502 ICVReplacementValuesMap;
2505 std::optional<Value *>
2507 return ICVReplacementValuesMap[ICV];
2512 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2515 if (!ICVTrackingAA->isAssumedTracked())
2516 return indicatePessimisticFixpoint();
2519 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2520 std::optional<Value *> UniqueICVValue;
2523 std::optional<Value *> NewReplVal =
2524 ICVTrackingAA->getReplacementValue(ICV, &
I,
A);
2527 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2530 UniqueICVValue = NewReplVal;
2535 bool UsedAssumedInformation =
false;
2536 if (!
A.checkForAllInstructions(CheckReturnInst, *
this, {Instruction::Ret},
2537 UsedAssumedInformation,
2539 UniqueICVValue =
nullptr;
2541 if (UniqueICVValue == ReplVal)
2544 ReplVal = UniqueICVValue;
2545 Changed = ChangeStatus::CHANGED;
2552struct AAICVTrackerCallSite : AAICVTracker {
2554 : AAICVTracker(IRP,
A) {}
2557 assert(getAnchorScope() &&
"Expected anchor function");
2561 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2563 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2564 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2565 if (Getter.Declaration == getAssociatedFunction()) {
2566 AssociatedICV = ICVInfo.Kind;
2572 indicatePessimisticFixpoint();
2576 if (!ReplVal || !*ReplVal)
2577 return ChangeStatus::UNCHANGED;
2580 A.deleteAfterManifest(*getCtxI());
2582 return ChangeStatus::CHANGED;
2586 const std::string getAsStr(
Attributor *)
const override {
2587 return "ICVTrackerCallSite";
2591 void trackStatistics()
const override {}
2594 std::optional<Value *> ReplVal;
2597 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2601 if (!ICVTrackingAA->isAssumedTracked())
2602 return indicatePessimisticFixpoint();
2604 std::optional<Value *> NewReplVal =
2605 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(),
A);
2607 if (ReplVal == NewReplVal)
2608 return ChangeStatus::UNCHANGED;
2610 ReplVal = NewReplVal;
2611 return ChangeStatus::CHANGED;
2616 std::optional<Value *>
2622struct AAICVTrackerCallSiteReturned : AAICVTracker {
2624 : AAICVTracker(IRP,
A) {}
2627 const std::string getAsStr(
Attributor *)
const override {
2628 return "ICVTrackerCallSiteReturned";
2632 void trackStatistics()
const override {}
2636 return ChangeStatus::UNCHANGED;
2641 InternalControlVar::ICV___last>
2642 ICVReplacementValuesMap;
2646 std::optional<Value *>
2648 return ICVReplacementValuesMap[ICV];
2653 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2655 DepClassTy::REQUIRED);
2658 if (!ICVTrackingAA->isAssumedTracked())
2659 return indicatePessimisticFixpoint();
2662 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2663 std::optional<Value *> NewReplVal =
2664 ICVTrackingAA->getUniqueReplacementValue(ICV);
2666 if (ReplVal == NewReplVal)
2669 ReplVal = NewReplVal;
2670 Changed = ChangeStatus::CHANGED;
2678static bool hasFunctionEndAsUniqueSuccessor(
const BasicBlock *BB) {
2684 return hasFunctionEndAsUniqueSuccessor(
Successor);
2691 ~AAExecutionDomainFunction() {
delete RPOT; }
2695 assert(
F &&
"Expected anchor function");
2700 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2701 for (
auto &It : BEDMap) {
2705 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2706 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2707 It.getSecond().IsReachingAlignedBarrierOnly;
2709 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) +
"/" +
2710 std::to_string(AlignedBlocks) +
" of " +
2711 std::to_string(TotalBlocks) +
2712 " executed by initial thread / aligned";
2724 << BB.
getName() <<
" is executed by a single thread.\n";
2734 auto HandleAlignedBarrier = [&](
CallBase *CB) {
2735 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[
nullptr];
2736 if (!ED.IsReachedFromAlignedBarrierOnly ||
2737 ED.EncounteredNonLocalSideEffect)
2739 if (!ED.EncounteredAssumes.empty() && !
A.isModulePass())
2750 DeletedBarriers.
insert(CB);
2751 A.deleteAfterManifest(*CB);
2752 ++NumBarriersEliminated;
2754 }
else if (!ED.AlignedBarriers.empty()) {
2757 ED.AlignedBarriers.end());
2759 while (!Worklist.
empty()) {
2761 if (!Visited.
insert(LastCB))
2765 if (!hasFunctionEndAsUniqueSuccessor(LastCB->
getParent()))
2767 if (!DeletedBarriers.
count(LastCB)) {
2768 ++NumBarriersEliminated;
2769 A.deleteAfterManifest(*LastCB);
2775 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2776 Worklist.
append(LastED.AlignedBarriers.begin(),
2777 LastED.AlignedBarriers.end());
2783 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2784 for (
auto *AssumeCB : ED.EncounteredAssumes)
2785 A.deleteAfterManifest(*AssumeCB);
2788 for (
auto *CB : AlignedBarriers)
2789 HandleAlignedBarrier(CB);
2793 HandleAlignedBarrier(
nullptr);
2805 mergeInPredecessorBarriersAndAssumptions(
Attributor &
A, ExecutionDomainTy &ED,
2806 const ExecutionDomainTy &PredED);
2811 bool mergeInPredecessor(
Attributor &
A, ExecutionDomainTy &ED,
2812 const ExecutionDomainTy &PredED,
2813 bool InitialEdgeOnly =
false);
2816 bool handleCallees(
Attributor &
A, ExecutionDomainTy &EntryBBED);
2826 assert(BB.
getParent() == getAnchorScope() &&
"Block is out of scope!");
2827 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2832 assert(
I.getFunction() == getAnchorScope() &&
2833 "Instruction is out of scope!");
2837 bool ForwardIsOk =
true;
2843 auto *CB = dyn_cast<CallBase>(CurI);
2846 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB)))
2848 const auto &It = CEDMap.find({CB, PRE});
2849 if (It == CEDMap.end())
2851 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2852 ForwardIsOk =
false;
2856 if (!CurI && !BEDMap.lookup(
I.getParent()).IsReachingAlignedBarrierOnly)
2857 ForwardIsOk =
false;
2862 auto *CB = dyn_cast<CallBase>(CurI);
2865 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB)))
2867 const auto &It = CEDMap.find({CB, POST});
2868 if (It == CEDMap.end())
2870 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2883 return BEDMap.lookup(
nullptr).IsReachedFromAlignedBarrierOnly;
2885 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2897 "No request should be made against an invalid state!");
2898 return BEDMap.lookup(&BB);
2900 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2903 "No request should be made against an invalid state!");
2904 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2908 "No request should be made against an invalid state!");
2909 return InterProceduralED;
2923 if (!Cmp || !
Cmp->isTrueWhenEqual() || !
Cmp->isEquality())
2931 if (
C->isAllOnesValue()) {
2932 auto *CB = dyn_cast<CallBase>(
Cmp->getOperand(0));
2933 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2934 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2935 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2941 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2947 if (
auto *II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2948 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2952 if (
auto *II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2953 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2961 ExecutionDomainTy InterProceduralED;
2973 static bool setAndRecord(
bool &R,
bool V) {
2984void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2985 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED) {
2986 for (
auto *EA : PredED.EncounteredAssumes)
2987 ED.addAssumeInst(
A, *EA);
2989 for (
auto *AB : PredED.AlignedBarriers)
2990 ED.addAlignedBarrier(
A, *AB);
2993bool AAExecutionDomainFunction::mergeInPredecessor(
2994 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED,
2995 bool InitialEdgeOnly) {
2997 bool Changed =
false;
2999 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3000 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3001 ED.IsExecutedByInitialThreadOnly));
3003 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3004 ED.IsReachedFromAlignedBarrierOnly &&
3005 PredED.IsReachedFromAlignedBarrierOnly);
3006 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3007 ED.EncounteredNonLocalSideEffect |
3008 PredED.EncounteredNonLocalSideEffect);
3010 if (ED.IsReachedFromAlignedBarrierOnly)
3011 mergeInPredecessorBarriersAndAssumptions(
A, ED, PredED);
3013 ED.clearAssumeInstAndAlignedBarriers();
3017bool AAExecutionDomainFunction::handleCallees(
Attributor &
A,
3018 ExecutionDomainTy &EntryBBED) {
3023 DepClassTy::OPTIONAL);
3024 if (!EDAA || !EDAA->getState().isValidState())
3027 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3031 ExecutionDomainTy ExitED;
3032 bool AllCallSitesKnown;
3033 if (
A.checkForAllCallSites(PredForCallSite, *
this,
3035 AllCallSitesKnown)) {
3036 for (
const auto &[CSInED, CSOutED] : CallSiteEDs) {
3037 mergeInPredecessor(
A, EntryBBED, CSInED);
3038 ExitED.IsReachingAlignedBarrierOnly &=
3039 CSOutED.IsReachingAlignedBarrierOnly;
3046 EntryBBED.IsExecutedByInitialThreadOnly =
false;
3047 EntryBBED.IsReachedFromAlignedBarrierOnly =
true;
3048 EntryBBED.EncounteredNonLocalSideEffect =
false;
3049 ExitED.IsReachingAlignedBarrierOnly =
false;
3051 EntryBBED.IsExecutedByInitialThreadOnly =
false;
3052 EntryBBED.IsReachedFromAlignedBarrierOnly =
false;
3053 EntryBBED.EncounteredNonLocalSideEffect =
true;
3054 ExitED.IsReachingAlignedBarrierOnly =
false;
3058 bool Changed =
false;
3059 auto &FnED = BEDMap[
nullptr];
3060 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3061 FnED.IsReachedFromAlignedBarrierOnly &
3062 EntryBBED.IsReachedFromAlignedBarrierOnly);
3063 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3064 FnED.IsReachingAlignedBarrierOnly &
3065 ExitED.IsReachingAlignedBarrierOnly);
3066 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3067 EntryBBED.IsExecutedByInitialThreadOnly);
3073 bool Changed =
false;
3078 auto HandleAlignedBarrier = [&](
CallBase &CB, ExecutionDomainTy &ED) {
3079 Changed |= AlignedBarriers.insert(&CB);
3081 auto &CallInED = CEDMap[{&CB, PRE}];
3082 Changed |= mergeInPredecessor(
A, CallInED, ED);
3083 CallInED.IsReachingAlignedBarrierOnly =
true;
3085 ED.EncounteredNonLocalSideEffect =
false;
3086 ED.IsReachedFromAlignedBarrierOnly =
true;
3088 ED.clearAssumeInstAndAlignedBarriers();
3089 ED.addAlignedBarrier(
A, CB);
3090 auto &CallOutED = CEDMap[{&CB, POST}];
3091 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3095 A.getAAFor<
AAIsDead>(*
this, getIRPosition(), DepClassTy::OPTIONAL);
3102 for (
auto &RIt : *RPOT) {
3105 bool IsEntryBB = &BB == &EntryBB;
3108 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3109 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3110 ExecutionDomainTy ED;
3113 Changed |= handleCallees(
A, ED);
3117 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3121 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3123 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3124 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3125 mergeInPredecessor(
A, ED, BEDMap[PredBB], InitialEdgeOnly);
3132 bool UsedAssumedInformation;
3133 if (
A.isAssumedDead(
I, *
this, LivenessAA, UsedAssumedInformation,
3134 false, DepClassTy::OPTIONAL,
3140 if (
auto *II = dyn_cast<IntrinsicInst>(&
I)) {
3141 if (
auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3142 ED.addAssumeInst(
A, *AI);
3146 if (II->isAssumeLikeIntrinsic())
3150 if (
auto *FI = dyn_cast<FenceInst>(&
I)) {
3151 if (!ED.EncounteredNonLocalSideEffect) {
3153 if (ED.IsReachedFromAlignedBarrierOnly)
3158 case AtomicOrdering::NotAtomic:
3160 case AtomicOrdering::Unordered:
3162 case AtomicOrdering::Monotonic:
3164 case AtomicOrdering::Acquire:
3166 case AtomicOrdering::Release:
3168 case AtomicOrdering::AcquireRelease:
3170 case AtomicOrdering::SequentiallyConsistent:
3174 NonNoOpFences.insert(FI);
3177 auto *CB = dyn_cast<CallBase>(&
I);
3179 bool IsAlignedBarrier =
3183 AlignedBarrierLastInBlock &= IsNoSync;
3184 IsExplicitlyAligned &= IsNoSync;
3190 if (IsAlignedBarrier) {
3191 HandleAlignedBarrier(*CB, ED);
3192 AlignedBarrierLastInBlock =
true;
3193 IsExplicitlyAligned =
true;
3198 if (isa<MemIntrinsic>(&
I)) {
3199 if (!ED.EncounteredNonLocalSideEffect &&
3201 ED.EncounteredNonLocalSideEffect =
true;
3203 ED.IsReachedFromAlignedBarrierOnly =
false;
3211 auto &CallInED = CEDMap[{CB, PRE}];
3212 Changed |= mergeInPredecessor(
A, CallInED, ED);
3218 if (!IsNoSync && Callee && !
Callee->isDeclaration()) {
3221 if (EDAA && EDAA->getState().isValidState()) {
3224 CalleeED.IsReachedFromAlignedBarrierOnly;
3225 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3226 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3227 ED.EncounteredNonLocalSideEffect |=
3228 CalleeED.EncounteredNonLocalSideEffect;
3230 ED.EncounteredNonLocalSideEffect =
3231 CalleeED.EncounteredNonLocalSideEffect;
3232 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3234 setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3237 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3238 mergeInPredecessorBarriersAndAssumptions(
A, ED, CalleeED);
3239 auto &CallOutED = CEDMap[{CB, POST}];
3240 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3245 ED.IsReachedFromAlignedBarrierOnly =
false;
3246 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3249 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3251 auto &CallOutED = CEDMap[{CB, POST}];
3252 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3255 if (!
I.mayHaveSideEffects() && !
I.mayReadFromMemory())
3269 if (MemAA && MemAA->getState().isValidState() &&
3270 MemAA->checkForAllAccessesToMemoryKind(
3275 auto &InfoCache =
A.getInfoCache();
3276 if (!
I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(
I))
3279 if (
auto *LI = dyn_cast<LoadInst>(&
I))
3280 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3283 if (!ED.EncounteredNonLocalSideEffect &&
3285 ED.EncounteredNonLocalSideEffect =
true;
3288 bool IsEndAndNotReachingAlignedBarriersOnly =
false;
3289 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3290 !BB.getTerminator()->getNumSuccessors()) {
3292 Changed |= mergeInPredecessor(
A, InterProceduralED, ED);
3294 auto &FnED = BEDMap[
nullptr];
3295 if (IsKernel && !IsExplicitlyAligned)
3296 FnED.IsReachingAlignedBarrierOnly =
false;
3297 Changed |= mergeInPredecessor(
A, FnED, ED);
3299 if (!FnED.IsReachingAlignedBarrierOnly) {
3300 IsEndAndNotReachingAlignedBarriersOnly =
true;
3301 SyncInstWorklist.
push_back(BB.getTerminator());
3302 auto &BBED = BEDMap[&BB];
3303 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly,
false);
3307 ExecutionDomainTy &StoredED = BEDMap[&BB];
3308 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3309 !IsEndAndNotReachingAlignedBarriersOnly;
3315 if (ED.IsExecutedByInitialThreadOnly !=
3316 StoredED.IsExecutedByInitialThreadOnly ||
3317 ED.IsReachedFromAlignedBarrierOnly !=
3318 StoredED.IsReachedFromAlignedBarrierOnly ||
3319 ED.EncounteredNonLocalSideEffect !=
3320 StoredED.EncounteredNonLocalSideEffect)
3324 StoredED = std::move(ED);
3330 while (!SyncInstWorklist.
empty()) {
3333 bool HitAlignedBarrierOrKnownEnd =
false;
3335 auto *CB = dyn_cast<CallBase>(CurInst);
3338 auto &CallOutED = CEDMap[{CB, POST}];
3339 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly,
false);
3340 auto &CallInED = CEDMap[{CB, PRE}];
3341 HitAlignedBarrierOrKnownEnd =
3342 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3343 if (HitAlignedBarrierOrKnownEnd)
3345 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3347 if (HitAlignedBarrierOrKnownEnd)
3351 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3353 if (!Visited.
insert(PredBB))
3355 auto &PredED = BEDMap[PredBB];
3356 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly,
false)) {
3358 SyncInstWorklist.
push_back(PredBB->getTerminator());
3361 if (SyncBB != &EntryBB)
3364 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly,
false);
3367 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3372struct AAHeapToShared :
public StateWrapper<BooleanState, AbstractAttribute> {
3377 static AAHeapToShared &createForPosition(
const IRPosition &IRP,
3381 virtual bool isAssumedHeapToShared(
CallBase &CB)
const = 0;
3385 virtual bool isAssumedHeapToSharedRemovedFree(
CallBase &CB)
const = 0;
3388 const std::string
getName()
const override {
return "AAHeapToShared"; }
3391 const char *getIdAddr()
const override {
return &
ID; }
3400 static const char ID;
3403struct AAHeapToSharedFunction :
public AAHeapToShared {
3405 : AAHeapToShared(IRP,
A) {}
3407 const std::string getAsStr(
Attributor *)
const override {
3408 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3409 " malloc calls eligible.";
3413 void trackStatistics()
const override {}
3417 void findPotentialRemovedFreeCalls(
Attributor &
A) {
3418 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3419 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3421 PotentialRemovedFreeCalls.clear();
3425 for (
auto *U : CB->
users()) {
3427 if (
C &&
C->getCalledFunction() == FreeRFI.Declaration)
3431 if (FreeCalls.
size() != 1)
3434 PotentialRemovedFreeCalls.insert(FreeCalls.
front());
3440 indicatePessimisticFixpoint();
3444 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3445 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3446 if (!RFI.Declaration)
3451 bool &) -> std::optional<Value *> {
return nullptr; };
3454 for (
User *U : RFI.Declaration->
users())
3455 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3458 MallocCalls.insert(CB);
3463 findPotentialRemovedFreeCalls(
A);
3466 bool isAssumedHeapToShared(
CallBase &CB)
const override {
3467 return isValidState() && MallocCalls.count(&CB);
3470 bool isAssumedHeapToSharedRemovedFree(
CallBase &CB)
const override {
3471 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3475 if (MallocCalls.empty())
3476 return ChangeStatus::UNCHANGED;
3478 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3479 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3483 DepClassTy::OPTIONAL);
3488 if (HS &&
HS->isAssumedHeapToStack(*CB))
3493 for (
auto *U : CB->
users()) {
3495 if (
C &&
C->getCalledFunction() == FreeCall.Declaration)
3498 if (FreeCalls.
size() != 1)
3505 <<
" with shared memory."
3506 <<
" Shared memory usage is limited to "
3512 <<
" with " << AllocSize->getZExtValue()
3513 <<
" bytes of shared memory\n");
3519 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3524 static_cast<unsigned>(AddressSpace::Shared));
3529 return OR <<
"Replaced globalized variable with "
3530 <<
ore::NV(
"SharedMemory", AllocSize->getZExtValue())
3531 << (AllocSize->isOne() ?
" byte " :
" bytes ")
3532 <<
"of shared memory.";
3538 "HeapToShared on allocation without alignment attribute");
3539 SharedMem->setAlignment(*Alignment);
3542 A.deleteAfterManifest(*CB);
3543 A.deleteAfterManifest(*FreeCalls.
front());
3545 SharedMemoryUsed += AllocSize->getZExtValue();
3546 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3547 Changed = ChangeStatus::CHANGED;
3554 if (MallocCalls.empty())
3555 return indicatePessimisticFixpoint();
3556 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3557 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3558 if (!RFI.Declaration)
3559 return ChangeStatus::UNCHANGED;
3563 auto NumMallocCalls = MallocCalls.size();
3566 for (
User *U : RFI.Declaration->
users()) {
3567 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3568 if (CB->getCaller() !=
F)
3570 if (!MallocCalls.count(CB))
3572 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3573 MallocCalls.remove(CB);
3578 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3579 MallocCalls.remove(CB);
3583 findPotentialRemovedFreeCalls(
A);
3585 if (NumMallocCalls != MallocCalls.size())
3586 return ChangeStatus::CHANGED;
3588 return ChangeStatus::UNCHANGED;
3596 unsigned SharedMemoryUsed = 0;
3599struct AAKernelInfo :
public StateWrapper<KernelInfoState, AbstractAttribute> {
3605 static bool requiresCalleeForCallBase() {
return false; }
3608 void trackStatistics()
const override {}
3611 const std::string getAsStr(
Attributor *)
const override {
3612 if (!isValidState())
3614 return std::string(SPMDCompatibilityTracker.isAssumed() ?
"SPMD"
3616 std::string(SPMDCompatibilityTracker.isAtFixpoint() ?
" [FIX]"
3618 std::string(
" #PRs: ") +
3619 (ReachedKnownParallelRegions.isValidState()
3620 ? std::to_string(ReachedKnownParallelRegions.size())
3622 ", #Unknown PRs: " +
3623 (ReachedUnknownParallelRegions.isValidState()
3626 ", #Reaching Kernels: " +
3627 (ReachingKernelEntries.isValidState()
3631 (ParallelLevels.isValidState()
3634 ", NestedPar: " + (NestedParallelism ?
"yes" :
"no");
3641 const std::string
getName()
const override {
return "AAKernelInfo"; }
3644 const char *getIdAddr()
const override {
return &
ID; }
3651 static const char ID;
3656struct AAKernelInfoFunction : AAKernelInfo {
3658 : AAKernelInfo(IRP,
A) {}
3663 return GuardedInstructions;
3666 void setConfigurationOfKernelEnvironment(
ConstantStruct *ConfigC) {
3668 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3669 assert(NewKernelEnvC &&
"Failed to create new kernel environment");
3670 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3673#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3674 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3675 ConstantStruct *ConfigC = \
3676 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3677 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3678 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3679 assert(NewConfigC && "Failed to create new configuration environment"); \
3680 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3691#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
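// For reference, a sketch of what the setter macro above expands to for the
// ExecMode member (the spelling follows the macro verbatim; other members are
// analogous):
//
//   void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
//     ConstantStruct *ConfigC =
//         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
//     Constant *NewConfigC = ConstantFoldInsertValueInstruction(
//         ConfigC, NewVal, {KernelInfo::ExecModeIdx});
//     assert(NewConfigC && "Failed to create new configuration environment");
//     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
//   }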
3698 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3702 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3703 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3704 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3705 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3709 auto StoreCallBase = [](Use &U,
3710 OMPInformationCache::RuntimeFunctionInfo &RFI,
3712 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3714 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3716 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3722 StoreCallBase(U, InitRFI, KernelInitCB);
3726 DeinitRFI.foreachUse(
3728 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3734 if (!KernelInitCB || !KernelDeinitCB)
3738 ReachingKernelEntries.insert(Fn);
3739 IsKernelEntry = true;
3747 KernelConfigurationSimplifyCB =
3749 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3750 if (!isAtFixpoint()) {
3753 UsedAssumedInformation = true;
3754 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3759 A.registerGlobalVariableSimplificationCallback(
3760 *KernelEnvGV, KernelConfigurationSimplifyCB);
3764 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3769 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3773 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3775 setExecModeOfKernelEnvironment(AssumedExecModeC);
3785 auto [MinTeams, MaxTeams] =
3793 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3795 MayUseNestedParallelismC->getType(), NestedParallelism);
3796 setMayUseNestedParallelismOfKernelEnvironment(
3797 AssumedMayUseNestedParallelismC);
3801 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3805 setUseGenericStateMachineOfKernelEnvironment(
3806 AssumedUseGenericStateMachineC);
3812 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3814 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3818 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3821 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3835 if (SPMDCompatibilityTracker.isValidState())
3836 return AddDependence(A, this, QueryingAA);
3838 if (!ReachedKnownParallelRegions.isValidState())
3839 return AddDependence(A, this, QueryingAA);
3845 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3846 CustomStateMachineUseCB);
3847 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3848 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3849 CustomStateMachineUseCB);
3850 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3851 CustomStateMachineUseCB);
3852 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3853 CustomStateMachineUseCB);
3857 if (SPMDCompatibilityTracker.isAtFixpoint())
3864 if (!SPMDCompatibilityTracker.isValidState())
3865 return AddDependence(A, this, QueryingAA);
3868 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3877 if (!SPMDCompatibilityTracker.isValidState())
3878 return AddDependence(A, this, QueryingAA);
3879 if (SPMDCompatibilityTracker.empty())
3880 return AddDependence(A, this, QueryingAA);
3881 if (!mayContainParallelRegion())
3882 return AddDependence(A, this, QueryingAA);
3885 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3889 static std::string sanitizeForGlobalName(std::string S) {
return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3894 (C >= '0' && C <= '9') || C == '_');
3905 if (!KernelInitCB || !KernelDeinitCB)
3906 return ChangeStatus::UNCHANGED;
3910 bool HasBuiltStateMachine = true;
3911 if (!changeToSPMDMode(A, Changed)) {
3913 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3915 HasBuiltStateMachine = false;
3922 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3923 ExistingKernelEnvC);
3924 if (!HasBuiltStateMachine)
3925 setUseGenericStateMachineOfKernelEnvironment(
3926 OldUseGenericStateMachineVal);
3933 Changed = ChangeStatus::CHANGED;
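// insertInstructionGuardsHelper (below) wraps each run of SPMD-incompatible
// instructions in a guarded region so that only thread 0 of the block
// executes it; other threads wait at a barrier and values escaping the
// region are broadcast through shared memory. A hedged sketch of the CFG
// produced per region (block names match the SplitBlock calls below):
//
//   region.check.tid:
//     %tid = call i32 @__kmpc_get_hardware_thread_id_in_block()
//     br i1 <%tid is null>, label %region.guarded, label %region.barrier
//   region.guarded:          ; original instructions, executed by thread 0
//     ...                    ; escaping values stored to *.guarded.output.alloc
//     br label %region.guarded.end
//   region.guarded.end:
//     br label %region.barrier
//   region.barrier:
//     call void @__kmpc_barrier_simple_spmd(ptr %ident, i32 %tid)
//     br label %region.exit  ; all threads reload the broadcast values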
3939 void insertInstructionGuardsHelper(Attributor &A) {
3940 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3942 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3976 DT, LI, MSU, "region.guarded.end");
3979 MSU, "region.barrier");
3982 DT, LI, MSU, "region.exit");
3984 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3987 "Expected a different CFG");
3990 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3993 A.registerManifestAddedBasicBlock(*RegionEndBB);
3994 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3995 A.registerManifestAddedBasicBlock(*RegionExitBB);
3996 A.registerManifestAddedBasicBlock(*RegionStartBB);
3997 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3999 bool HasBroadcastValues = false;
4004 for (Use &U : I.uses()) {
4010 if (OutsideUses.empty())
4013 HasBroadcastValues = true;
4018 M, I.getType(), false,
4020 sanitizeForGlobalName(
4021 (I.getName() + ".guarded.output.alloc").str()),
4023 static_cast<unsigned>(AddressSpace::Shared));
4029 I.getName() + ".guarded.output.load",
4033 for (Use *U : OutsideUses)
4034 A.changeUseAfterManifest(*U, *LoadI);
4037 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4043 InsertPointTy(ParentBB, ParentBB->end()), DL);
4044 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4047 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4049 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4055 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4056 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4058 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4059 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4061 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4063 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4064 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4065 OMPInfoCache.OMPBuilder.Builder
4066 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4072 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4073 M, OMPRTL___kmpc_barrier_simple_spmd);
4074 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4077 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4079 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4082 if (HasBroadcastValues) {
4086 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4090 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4092 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4094 if (!Visited.insert(BB).second)
4100 while (++IP != IPEnd) {
4101 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4104 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4106 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4107 LastEffect = nullptr;
4114 for (auto &Reorder : Reorders)
4120 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4122 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4125 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4126 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4128 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4131 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4135 if (SPMDCompatibilityTracker.contains(&I)) {
4136 CalleeAAFunction.getGuardedInstructions().insert(&I);
4137 if (GuardedRegionStart)
4138 GuardedRegionEnd = &I;
4140 GuardedRegionStart = GuardedRegionEnd = &I;
4147 if (GuardedRegionStart) {
4149 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4150 GuardedRegionStart = nullptr;
4151 GuardedRegionEnd = nullptr;
4156 for (auto &GR : GuardedRegions)
4157 CreateGuardedRegion(GR.first, GR.second);
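// forceSingleThreadPerWorkgroupHelper (below) is the alternative used when
// the SPMD-ized kernel cannot contain a parallel region: instead of guarding
// individual regions, the whole user code is predicated on being the main
// thread of the workgroup (see the "thread.is_main" check and the
// "main.thread.user_code" block created here), so only one thread per
// workgroup executes it.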
4160 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4169 auto &Ctx = getAnchorValue().getContext();
4176 KernelInitCB->getNextNode(), "main.thread.user_code");
4181 A.registerManifestAddedBasicBlock(*InitBB);
4182 A.registerManifestAddedBasicBlock(*UserCodeBB);
4183 A.registerManifestAddedBasicBlock(*ReturnBB);
4192 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4194 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4195 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4200 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4207 "thread.is_main", InitBB);
4213 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4216 if (!OMPInfoCache.runtimeFnsAvailable(
4217 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4218 OMPRTL___kmpc_barrier_simple_spmd}))
4221 if (!SPMDCompatibilityTracker.isAssumed()) {
4222 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4223 if (!NonCompatibleI)
4227 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4228 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4232 ORA << "Value has potential side effects preventing SPMD-mode "
4234 if (isa<CallBase>(NonCompatibleI)) {
4235 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4236 "the called function to override";
4244 << *NonCompatibleI << "\n");
4256 Kernel = CB->getCaller();
4264 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4270 Changed = ChangeStatus::CHANGED;
4274 if (mayContainParallelRegion())
4275 insertInstructionGuardsHelper(A);
4277 forceSingleThreadPerWorkgroupHelper(A);
4282 "Initially non-SPMD kernel has SPMD exec mode!");
4286 ++NumOpenMPTargetRegionKernelsSPMD;
4289 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4301 if (!ReachedKnownParallelRegions.isValidState())
4304 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4305 if (!OMPInfoCache.runtimeFnsAvailable(
4306 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4307 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4308 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4319 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4320 ExistingKernelEnvC);
4322 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4327 if (UseStateMachineC->isZero() ||
4331 Changed = ChangeStatus::CHANGED;
4334 setUseGenericStateMachineOfKernelEnvironment(
4341 if (!mayContainParallelRegion()) {
4342 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4345 return OR << "Removing unused state machine from generic-mode kernel.";
4353 if (ReachedUnknownParallelRegions.empty()) {
4354 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4357 return OR << "Rewriting generic-mode kernel with a customized state "
4362 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4365 return OR << "Generic-mode kernel is executed with a customized state "
4366 "machine that requires a fallback.";
4371 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4372 if (!UnknownParallelRegionCB)
4375 return ORA << "Call may contain unknown parallel regions. Use "
4376 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4414 auto &Ctx = getAnchorValue().getContext();
4418 BasicBlock *InitBB = KernelInitCB->getParent();
4420 KernelInitCB->getNextNode(), "thread.user_code.check");
4424 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4426 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4428 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4431 Kernel, UserCodeEntryBB);
4434 Kernel, UserCodeEntryBB);
4436 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4437 A.registerManifestAddedBasicBlock(*InitBB);
4438 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4439 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4440 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4441 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4442 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4443 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4444 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4445 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4447 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4454 "thread.is_worker", InitBB);
4460 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4461 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4463 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4464 M, OMPRTL___kmpc_get_warp_size);
4467 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4471 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4474 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4478 "thread.is_main_or_worker", IsWorkerCheckBB);
4481 IsMainOrWorker, IsWorkerCheckBB);
4485 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4487 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4491 OMPInfoCache.OMPBuilder.updateToLocation(
4494 StateMachineBeginBB->end()),
4497 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4498 Value *GTid = KernelInitCB;
4501 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4502 M, OMPRTL___kmpc_barrier_simple_generic);
4505 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4509 (unsigned int)AddressSpace::Generic) {
4511 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4512 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4517 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4518 M, OMPRTL___kmpc_kernel_parallel);
4520 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4521 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4524 StateMachineBeginBB);
4534 StateMachineBeginBB);
4535 IsDone->setDebugLoc(DLoc);
4537 IsDone, StateMachineBeginBB)
4541 StateMachineDoneBarrierBB, IsActiveWorker,
4542 StateMachineIsActiveCheckBB)
4548 const unsigned int WrapperFunctionArgNo = 6;
4553 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4554 auto *CB = ReachedKnownParallelRegions[I];
4555 auto *ParallelRegion = dyn_cast<Function>(
4556 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4558 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4559 StateMachineEndParallelBB);
4561 ->setDebugLoc(DLoc);
4567 Kernel, StateMachineEndParallelBB);
4568 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4569 A.registerManifestAddedBasicBlock(*PRNextBB);
4574 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4577 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4585 StateMachineIfCascadeCurrentBB)
4587 StateMachineIfCascadeCurrentBB = PRNextBB;
4593 if (!ReachedUnknownParallelRegions.empty()) {
4594 StateMachineIfCascadeCurrentBB->setName(
4595 "worker_state_machine.parallel_region.fallback.execute");
4597 StateMachineIfCascadeCurrentBB)
4598 ->setDebugLoc(DLoc);
4601 StateMachineIfCascadeCurrentBB)
4605 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4606 M, OMPRTL___kmpc_kernel_end_parallel);
4609 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4615 ->setDebugLoc(DLoc);
4625 KernelInfoState StateBefore = getState();
4631 struct UpdateKernelEnvCRAII {
4632 AAKernelInfoFunction &AA;
4634 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4636 ~UpdateKernelEnvCRAII() {
4643 if (!AA.isValidState()) {
4644 AA.KernelEnvC = ExistingKernelEnvC;
4648 if (!AA.ReachedKnownParallelRegions.isValidState())
4649 AA.setUseGenericStateMachineOfKernelEnvironment(
4650 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4651 ExistingKernelEnvC));
4653 if (!AA.SPMDCompatibilityTracker.isValidState())
4654 AA.setExecModeOfKernelEnvironment(
4655 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4658 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4661 MayUseNestedParallelismC->getType(), AA.NestedParallelism);
4662 AA.setMayUseNestedParallelismOfKernelEnvironment(
4663 NewMayUseNestedParallelismC);
4670 if (isa<CallBase>(I))
4673 if (!I.mayWriteToMemory())
4675 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4678 DepClassTy::OPTIONAL);
4681 DepClassTy::OPTIONAL);
4682 if (UnderlyingObjsAA &&
4683 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4684 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4688 auto *CB = dyn_cast<CallBase>(&Obj);
4689 return CB && HS && HS->isAssumedHeapToStack(*CB);
4695 SPMDCompatibilityTracker.insert(&I);
4699 bool UsedAssumedInformationInCheckRWInst = false;
4700 if (!SPMDCompatibilityTracker.isAtFixpoint())
4701 if (!A.checkForAllReadWriteInstructions(
4702 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4705 bool UsedAssumedInformationFromReachingKernels = false;
4706 if (!IsKernelEntry) {
4707 updateParallelLevels(A);
4709 bool AllReachingKernelsKnown = true;
4710 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4711 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4713 if (!SPMDCompatibilityTracker.empty()) {
4714 if (!ParallelLevels.isValidState())
4715 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4716 else if (!ReachingKernelEntries.isValidState())
4717 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 for (auto *Kernel : ReachingKernelEntries) {
4724 auto *CBAA = A.getAAFor<AAKernelInfo>(
4726 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4727 CBAA->SPMDCompatibilityTracker.isAssumed())
4731 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4732 UsedAssumedInformationFromReachingKernels = true;
4734 if (SPMD != 0 && Generic != 0)
4735 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4741 bool AllParallelRegionStatesWereFixed = true;
4742 bool AllSPMDStatesWereFixed = true;
4744 auto &CB = cast<CallBase>(I);
4745 auto *CBAA = A.getAAFor<AAKernelInfo>(
4749 getState() ^= CBAA->getState();
4750 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4751 AllParallelRegionStatesWereFixed &=
4752 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4753 AllParallelRegionStatesWereFixed &=
4754 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4758 bool UsedAssumedInformationInCheckCallInst = false;
4759 if (!A.checkForAllCallLikeInstructions(
4760 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4762 << "Failed to visit all call-like instructions!\n";);
4763 return indicatePessimisticFixpoint();
4768 if (!UsedAssumedInformationInCheckCallInst &&
4769 AllParallelRegionStatesWereFixed) {
4770 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4771 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4776 if (!UsedAssumedInformationInCheckRWInst &&
4777 !UsedAssumedInformationInCheckCallInst &&
4778 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4779 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4781 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4782 : ChangeStatus::CHANGED;
4788 bool &AllReachingKernelsKnown) {
4792 assert(Caller && "Caller is nullptr");
4794 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4796 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4797 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4803 ReachingKernelEntries.indicatePessimisticFixpoint();
4808 if (!A.checkForAllCallSites(PredCallSite, *this,
4810 AllReachingKernelsKnown))
4811 ReachingKernelEntries.indicatePessimisticFixpoint();
4816 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4817 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4818 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4823 assert(Caller && "Caller is nullptr");
4827 if (CAA && CAA->ParallelLevels.isValidState()) {
4833 if (Caller == Parallel51RFI.Declaration) {
4834 ParallelLevels.indicatePessimisticFixpoint();
4838 ParallelLevels ^= CAA->ParallelLevels;
4845 ParallelLevels.indicatePessimisticFixpoint();
4850 bool AllCallSitesKnown = true;
4851 if (!A.checkForAllCallSites(PredCallSite, *this,
4854 ParallelLevels.indicatePessimisticFixpoint();
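// AAKernelInfoCallSite propagates kernel information across call sites: calls
// into known OpenMP runtime functions are modeled directly (e.g.
// __kmpc_parallel_51 records a reached parallel region, the shared-memory
// allocation routines interact with heap-to-stack / heap-to-shared), while
// other calls merge the callee's AAKernelInfo state or pessimize the state
// when the callee is unknown or not IPO-amendable.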
4861 struct AAKernelInfoCallSite : AAKernelInfo {
4863 : AAKernelInfo(IRP, A) {}
4867 AAKernelInfo::initialize(A);
4869 CallBase &CB = cast<CallBase>(getAssociatedValue());
4874 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4875 indicateOptimisticFixpoint();
4883 indicateOptimisticFixpoint();
4892 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4893 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4894 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4896 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4900 if (!AssumptionAA ||
4901 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4902 AssumptionAA->hasAssumption("omp_no_parallelism")))
4903 ReachedUnknownParallelRegions.insert(&CB);
4907 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4908 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4909 SPMDCompatibilityTracker.insert(&CB);
4914 indicateOptimisticFixpoint();
4920 if (NumCallees > 1) {
4921 indicatePessimisticFixpoint();
4903 ReachedUnknownParallelRegions.insert(&CB);
4907 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4908 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4909 SPMDCompatibilityTracker.insert(&CB);
4914 indicateOptimisticFixpoint();
4920 if (NumCallees > 1) {
4921 indicatePessimisticFixpoint();
4928 case OMPRTL___kmpc_is_spmd_exec_mode:
4929 case OMPRTL___kmpc_distribute_static_fini:
4930 case OMPRTL___kmpc_for_static_fini:
4931 case OMPRTL___kmpc_global_thread_num:
4932 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4933 case OMPRTL___kmpc_get_hardware_num_blocks:
4934 case OMPRTL___kmpc_single:
4935 case OMPRTL___kmpc_end_single:
4936 case OMPRTL___kmpc_master:
4937 case OMPRTL___kmpc_end_master:
4938 case OMPRTL___kmpc_barrier:
4939 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4940 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4941 case OMPRTL___kmpc_error:
4942 case OMPRTL___kmpc_flush:
4943 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4944 case OMPRTL___kmpc_get_warp_size:
4945 case OMPRTL_omp_get_thread_num:
4946 case OMPRTL_omp_get_num_threads:
4947 case OMPRTL_omp_get_max_threads:
4948 case OMPRTL_omp_in_parallel:
4949 case OMPRTL_omp_get_dynamic:
4950 case OMPRTL_omp_get_cancellation:
4951 case OMPRTL_omp_get_nested:
4952 case OMPRTL_omp_get_schedule:
4953 case OMPRTL_omp_get_thread_limit:
4954 case OMPRTL_omp_get_supported_active_levels:
4955 case OMPRTL_omp_get_max_active_levels:
4956 case OMPRTL_omp_get_level:
4957 case OMPRTL_omp_get_ancestor_thread_num:
4958 case OMPRTL_omp_get_team_size:
4959 case OMPRTL_omp_get_active_level:
4960 case OMPRTL_omp_in_final:
4961 case OMPRTL_omp_get_proc_bind:
4962 case OMPRTL_omp_get_num_places:
4963 case OMPRTL_omp_get_num_procs:
4964 case OMPRTL_omp_get_place_proc_ids:
4965 case OMPRTL_omp_get_place_num:
4966 case OMPRTL_omp_get_partition_num_places:
4967 case OMPRTL_omp_get_partition_place_nums:
4968 case OMPRTL_omp_get_wtime:
4970 case OMPRTL___kmpc_distribute_static_init_4:
4971 case OMPRTL___kmpc_distribute_static_init_4u:
4972 case OMPRTL___kmpc_distribute_static_init_8:
4973 case OMPRTL___kmpc_distribute_static_init_8u:
4974 case OMPRTL___kmpc_for_static_init_4:
4975 case OMPRTL___kmpc_for_static_init_4u:
4976 case OMPRTL___kmpc_for_static_init_8:
4977 case OMPRTL___kmpc_for_static_init_8u: {
4979 unsigned ScheduleArgOpNo = 2;
4980 auto *ScheduleTypeCI =
4982 unsigned ScheduleTypeVal =
4983 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4985 case OMPScheduleType::UnorderedStatic:
4986 case OMPScheduleType::UnorderedStaticChunked:
4987 case OMPScheduleType::OrderedDistribute:
4988 case OMPScheduleType::OrderedDistributeChunked:
4991 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4992 SPMDCompatibilityTracker.insert(&CB);
4996 case OMPRTL___kmpc_target_init:
4999 case OMPRTL___kmpc_target_deinit:
5000 KernelDeinitCB = &CB;
5002 case OMPRTL___kmpc_parallel_51:
5003 if (!handleParallel51(A, CB))
5004 indicatePessimisticFixpoint();
5006 case OMPRTL___kmpc_omp_task:
5008 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5009 SPMDCompatibilityTracker.insert(&CB);
5010 ReachedUnknownParallelRegions.insert(&CB);
5012 case OMPRTL___kmpc_alloc_shared:
5013 case OMPRTL___kmpc_free_shared:
5019 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5020 SPMDCompatibilityTracker.insert(&CB);
5026 indicateOptimisticFixpoint();
5030 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5031 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5032 CheckCallee(getAssociatedFunction(), 1);
5035 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5036 for (auto *Callee : OptimisticEdges) {
5037 CheckCallee(Callee, OptimisticEdges.size());
5048 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5049 KernelInfoState StateBefore = getState();
5051 auto CheckCallee = [&](Function *F, int NumCallees) {
5052 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5056 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5059 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5061 return indicatePessimisticFixpoint();
5062 if (getState() == FnAA->getState())
5063 return ChangeStatus::UNCHANGED;
5064 getState() = FnAA->getState();
5065 return ChangeStatus::CHANGED;
5068 return indicatePessimisticFixpoint();
5070 CallBase &CB = cast<CallBase>(getAssociatedValue());
5071 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5072 if (!handleParallel51(A, CB))
5073 return indicatePessimisticFixpoint();
5074 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5075 : ChangeStatus::CHANGED;
5081 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5082 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5083 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5087 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 case OMPRTL___kmpc_alloc_shared:
5096 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5097 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5098 SPMDCompatibilityTracker.insert(&CB);
5100 case OMPRTL___kmpc_free_shared:
5101 if ((!HeapToStackAA ||
5102 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5104 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5105 SPMDCompatibilityTracker.insert(&CB);
5108 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5109 SPMDCompatibilityTracker.insert(&CB);
5111 return ChangeStatus::CHANGED;
5115 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5116 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5117 if (Function *F = getAssociatedFunction())
5121 for (auto *Callee : OptimisticEdges) {
5122 CheckCallee(Callee, OptimisticEdges.size());
5128 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5129 : ChangeStatus::CHANGED;
5135 const unsigned int NonWrapperFunctionArgNo = 5;
5136 const unsigned int WrapperFunctionArgNo = 6;
5137 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5138 ? NonWrapperFunctionArgNo
5139 : WrapperFunctionArgNo;
5141 auto *ParallelRegion = dyn_cast<Function>(
5143 if (!ParallelRegion)
5146 ReachedKnownParallelRegions.insert(&CB);
5148 auto *FnAA = A.getAAFor<AAKernelInfo>(
5150 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5151 !FnAA->ReachedKnownParallelRegions.empty() ||
5152 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5153 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5154 !FnAA->ReachedUnknownParallelRegions.empty();
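// AAFoldRuntimeCall folds calls to OpenMP runtime query functions into
// constants once the properties of all reaching kernels are known:
// __kmpc_is_spmd_exec_mode, __kmpc_parallel_level, and the hardware queries
// backed by the "omp_target_thread_limit" / "omp_target_num_teams" kernel
// attributes. As a hedged illustration, if every kernel reaching a call site
// is known to execute in SPMD mode, a call to __kmpc_is_spmd_exec_mode() is
// simplified to the constant 1 (and to 0 if all reaching kernels are
// generic-mode); a mix of SPMD and generic reaching kernels leaves the call
// untouched.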
5159 struct AAFoldRuntimeCall
5160 : public StateWrapper<BooleanState, AbstractAttribute> {
5166 void trackStatistics() const override {}
5169 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5173 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5176 const char *getIdAddr() const override { return &ID; }
5184 static const char ID;
5187 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5189 : AAFoldRuntimeCall(IRP, A) {}
5192 const std::string getAsStr(Attributor *) const override {
5193 if (!isValidState())
5196 std::string Str("simplified value: ");
5198 if (!SimplifiedValue)
5199 return Str + std::string("none");
5201 if (!*SimplifiedValue)
5202 return Str + std::string("nullptr");
5204 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5205 return Str + std::to_string(CI->getSExtValue());
5207 return Str + std::string("unknown");
5212 indicatePessimisticFixpoint();
5216 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5217 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5218 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5219 "Expected a known OpenMP runtime function");
5221 RFKind = It->getSecond();
5223 CallBase &CB = cast<CallBase>(getAssociatedValue());
5224 A.registerSimplificationCallback(
5227 bool &UsedAssumedInformation) -> std::optional<Value *> {
5228 assert((isValidState() ||
5229 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5230 "Unexpected invalid state!");
5232 if (!isAtFixpoint()) {
5233 UsedAssumedInformation = true;
5235 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5237 return SimplifiedValue;
5244 case OMPRTL___kmpc_is_spmd_exec_mode:
5245 Changed |= foldIsSPMDExecMode(A);
5247 case OMPRTL___kmpc_parallel_level:
5248 Changed |= foldParallelLevel(A);
5250 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5251 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5253 case OMPRTL___kmpc_get_hardware_num_blocks:
5254 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5266 if (SimplifiedValue && *SimplifiedValue) {
5269 A.deleteAfterManifest(I);
5273 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5274 return OR << "Replacing OpenMP runtime call "
5276 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5277 return OR << "Replacing OpenMP runtime call "
5285 << **SimplifiedValue << "\n");
5287 Changed = ChangeStatus::CHANGED;
5294 SimplifiedValue = nullptr;
5295 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5303 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5304 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5305 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5308 if (!CallerKernelInfoAA ||
5309 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5310 return indicatePessimisticFixpoint();
5312 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5314 DepClassTy::REQUIRED);
5316 if (!AA || !AA->isValidState()) {
5317 SimplifiedValue = nullptr;
5318 return indicatePessimisticFixpoint();
5321 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5322 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5327 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328 ++KnownNonSPMDCount;
5330 ++AssumedNonSPMDCount;
5334 if ((AssumedSPMDCount + KnownSPMDCount) &&
5335 (AssumedNonSPMDCount + KnownNonSPMDCount))
5336 return indicatePessimisticFixpoint();
5338 auto &Ctx = getAnchorValue().getContext();
5339 if (KnownSPMDCount || AssumedSPMDCount) {
5340 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5341 "Expected only SPMD kernels!");
5345 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5346 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5347 "Expected only non-SPMD kernels!");
5355 assert(!SimplifiedValue && "SimplifiedValue should be none");
5358 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5359 : ChangeStatus::CHANGED;
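// foldParallelLevel (below) applies the same reaching-kernel reasoning to
// __kmpc_parallel_level: the call is folded to a constant only when all
// kernels reaching it agree on the execution mode and no nested parallelism
// is possible; otherwise, as the checks that follow show, the attribute falls
// back to a pessimistic fixpoint.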
5364 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5366 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5369 if (!CallerKernelInfoAA ||
5370 !CallerKernelInfoAA->ParallelLevels.isValidState())
5371 return indicatePessimisticFixpoint();
5373 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5374 return indicatePessimisticFixpoint();
5376 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5377 assert(!SimplifiedValue &&
5378 "SimplifiedValue should keep none at this point");
5379 return ChangeStatus::UNCHANGED;
5382 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5383 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5384 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5386 DepClassTy::REQUIRED);
5387 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5388 return indicatePessimisticFixpoint();