49#include "llvm/IR/IntrinsicsAMDGPU.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
66#define DEBUG_TYPE "openmp-opt"
69 "openmp-opt-disable",
cl::desc(
"Disable OpenMP specific optimizations."),
73 "openmp-opt-enable-merging",
79 cl::desc(
"Disable function internalization."),
90 "openmp-hide-memory-transfer-latency",
91 cl::desc(
"[WIP] Tries to hide the latency of host to device memory"
96 "openmp-opt-disable-deglobalization",
97 cl::desc(
"Disable OpenMP optimizations involving deglobalization."),
101 "openmp-opt-disable-spmdization",
102 cl::desc(
"Disable OpenMP optimizations involving SPMD-ization."),
106 "openmp-opt-disable-folding",
111 "openmp-opt-disable-state-machine-rewrite",
112 cl::desc(
"Disable OpenMP optimizations that replace the state machine."),
116 "openmp-opt-disable-barrier-elimination",
117 cl::desc(
"Disable OpenMP optimizations that eliminate barriers."),
121 "openmp-opt-print-module-after",
122 cl::desc(
"Print the current module after OpenMP optimizations."),
126 "openmp-opt-print-module-before",
127 cl::desc(
"Print the current module before OpenMP optimizations."),
131 "openmp-opt-inline-device",
142 cl::desc(
"Maximal number of attributor iterations."),
147 cl::desc(
"Maximum amount of shared memory to use."),
148 cl::init(std::numeric_limits<unsigned>::max()));
151 "Number of OpenMP runtime calls deduplicated");
153 "Number of OpenMP parallel regions deleted");
155 "Number of OpenMP runtime functions identified");
157 "Number of OpenMP runtime function uses identified");
159 "Number of OpenMP target region entry points (=kernels) identified");
161 "Number of non-OpenMP target region kernels identified");
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "SPMD-mode instead of generic-mode");
165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode without a state machines");
168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
169 "Number of OpenMP target region entry points (=kernels) executed in "
170 "generic-mode with customized state machines with fallback");
171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
172 "Number of OpenMP target region entry points (=kernels) executed in "
173 "generic-mode with customized state machines without fallback");
175 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
176 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178 "Number of OpenMP parallel regions merged");
180 "Amount of memory pushed to shared memory");
181STATISTIC(NumBarriersEliminated,
"Number of redundant barriers eliminated");
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }

#undef KERNEL_ENVIRONMENT_GETTER

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER

GlobalVariable *
getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
  constexpr const int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(
      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
          ->stripPointerCasts());
}
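// As a concrete illustration of the macros above (the expansion is spelled
// out here for readability only): instantiating
// KERNEL_ENVIRONMENT_CONFIGURATION_GETTER for the ExecMode member, which is
// queried further below, produces a helper of roughly this shape:
//
//   ConstantInt *getExecModeFromKernelEnvironment(ConstantStruct *KernelEnvC) {
//     ConstantStruct *ConfigC =
//         getConfigurationFromKernelEnvironment(KernelEnvC);
//     return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(ExecModeIdx));
//   }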
struct AAHeapToShared;

struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator,
                      SetVector<Function *> *CGSCC, bool OpenMPPostLink)
      : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
        OpenMPPostLink(OpenMPPostLink) {

    const Triple T(OMPBuilder.M.getTargetTriple());
    switch (T.getArch()) {
    case llvm::Triple::nvptx:
    case llvm::Triple::nvptx64:
    case llvm::Triple::amdgcn:
      assert(OMPBuilder.Config.IsTargetDevice &&
             "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
      OMPBuilder.Config.IsGPU = true;
      break;
    default:
      OMPBuilder.Config.IsGPU = false;
      break;
    }

    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();
  }
  /// Description of an OpenMP internal control variable (ICV).
  struct InternalControlVarInfo {
    InternalControlVar Kind;
    StringRef Name;
    StringRef EnvVarName;
    ICVInitValue InitKind;
    ConstantInt *InitValue;
    RuntimeFunction Setter;
    RuntimeFunction Getter;
  };

  /// Description of a tracked OpenMP runtime function and the uses we found.
  struct RuntimeFunctionInfo {
    RuntimeFunction Kind;
    StringRef Name;
    bool IsVarArg;
    Type *ReturnType;
    SmallVector<Type *, 8> ArgumentTypes;
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear the map of all registered uses.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if none exist.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the callback returns true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);
      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }
      // Remove the to-be-deleted indices in reverse order so earlier removals
      // do not invalidate the remaining ones.
      while (!ToBeDeleted.empty()) {
        unsigned DelIdx = ToBeDeleted.pop_back_val();
        UV[DelIdx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

  public:
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from internal control variable kind to its description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;
  /// Initialize all internal control variables and their getter/setter
  /// runtime functions.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }
  /// Returns true if the function declaration \p F matches the runtime
  /// function signature, that is, return type \p RTFRetType and argument
  /// types \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto *RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;
      ++RTFTyIt;
    }
    return true;
  }
  /// Collect uses of the runtime function \p RFI (restricted to the CGSCC if
  /// one is given) and return their number.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (!CGSCC || CGSCC->empty() ||
            CGSCC->contains(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a single runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats=*/false);
  }
  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
    for (const auto Fn : Fns) {
      RuntimeFunctionInfo &RFI = RFIs[Fn];
      if (RFI.Declaration && RFI.Declaration->isDeclaration())
        return false;
    }
    return true;
  }
  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present.
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().starts_with(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }
  }
  /// Whether the OpenMP device runtime has already been linked in, i.e. we
  /// are past the link step of the device compilation.
  bool OpenMPPostLink = false;
};
template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

private:
  /// A set to keep track of elements.
  SetVector<Ty> Set;

public:
  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
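// Usage sketch (illustrative only, the call-site names are invented): merging
// two such states keeps the union of the tracked elements while the boolean
// validity is clamped like any other Attributor state, e.g.
//
//   BooleanStateWithPtrSetVector<CallBase, /*InsertInvalidates=*/false> A, B;
//   A.insert(ParallelCall0);   // hypothetical parallel call sites
//   B.insert(ParallelCall1);
//   A ^= B;                    // A now tracks both call sites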
struct KernelInfoState : AbstractState {
  /// Flag to track if we are at a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// are known to be reached from this kernel.
  BooleanStateWithPtrSetVector<CallBase, /*InsertInvalidates=*/false>
      ReachedKnownParallelRegions;

  /// State to track what parallel regions we might reach but do not know.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call, the __kmpc_target_deinit call, and the
  /// constant kernel environment associated with this kernel, if any.
  CallBase *KernelInitCB = nullptr;
  CallBase *KernelDeinitCB = nullptr;
  ConstantStruct *KernelEnvC = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to track the parallel levels reaching the associated function.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }
  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state.
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }
  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }
};
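// Conceptually (illustrative example, not code from this pass): for a kernel
// such as
//
//   #pragma omp target teams
//   {
//     foo();                   // callee with unknown body
//     #pragma omp parallel     // statically known parallel region
//     { work(); }
//   }
//
// ReachedKnownParallelRegions collects the __kmpc_parallel_51 call sites that
// are statically known, ReachedUnknownParallelRegions collects calls like
// foo() above that might transitively reach a parallel region, and
// SPMDCompatibilityTracker records the instructions that currently prevent
// executing the kernel in SPMD-mode.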
857 OffloadArray() =
default;
864 if (!
Array.getAllocatedType()->isArrayTy())
867 if (!getValues(Array,
Before))
870 this->Array = &
Array;
874 static const unsigned DeviceIDArgNum = 1;
875 static const unsigned BasePtrsArgNum = 3;
876 static const unsigned PtrsArgNum = 4;
877 static const unsigned SizesArgNum = 5;
885 const uint64_t NumValues =
Array.getAllocatedType()->getArrayNumElements();
886 StoredValues.
assign(NumValues,
nullptr);
887 LastAccesses.
assign(NumValues,
nullptr);
892 if (BB !=
Before.getParent())
902 if (!isa<StoreInst>(&
I))
905 auto *S = cast<StoreInst>(&
I);
912 LastAccesses[
Idx] = S;
922 const unsigned NumValues = StoredValues.
size();
923 for (
unsigned I = 0;
I < NumValues; ++
I) {
924 if (!StoredValues[
I] || !LastAccesses[
I])
934 using OptimizationRemarkGetter =
938 OptimizationRemarkGetter OREGetter,
939 OMPInformationCache &OMPInfoCache,
Attributor &A)
941 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache),
A(
A) {}
  /// Check if any remarks are enabled for openmp-opt.
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    if (IsModulePass) {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (mergeParallelRegions()) {
        deduplicateRuntimeCalls();
        Changed = true;
      }
    }

    if (OMPInfoCache.OpenMPPostLink)
      Changed |= removeRuntimeSymbols();

    return Changed;
  }
1002 void printICVs()
const {
1007 for (
auto ICV : ICVs) {
1008 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1010 return ORA <<
"OpenMP ICV " <<
ore::NV(
"OpenMPICV", ICVInfo.Name)
1012 << (ICVInfo.InitValue
1013 ?
toString(ICVInfo.InitValue->getValue(), 10,
true)
1014 :
"IMPLEMENTATION_DEFINED");
1017 emitRemark<OptimizationRemarkAnalysis>(
F,
"OpenMPICVTracker",
Remark);
1023 void printKernels()
const {
1029 return ORA <<
"OpenMP GPU kernel "
1030 <<
ore::NV(
"OpenMPGPUKernel",
F->getName()) <<
"\n";
1033 emitRemark<OptimizationRemarkAnalysis>(
F,
"OpenMPGPU",
Remark);
  /// Return the call if \p U is a callee use in a regular call, restricted to
  /// \p RFI if given.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call, restricted to \p RFI if given.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }
1063 bool mergeParallelRegions() {
1064 const unsigned CallbackCalleeOperand = 2;
1065 const unsigned CallbackFirstArgOperand = 3;
1069 OMPInformationCache::RuntimeFunctionInfo &RFI =
1070 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1072 if (!RFI.Declaration)
1076 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1077 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1078 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1081 bool Changed =
false;
1087 BasicBlock *StartBB =
nullptr, *EndBB =
nullptr;
1088 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1089 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1091 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1092 assert(StartBB !=
nullptr &&
"StartBB should not be null");
1094 assert(EndBB !=
nullptr &&
"EndBB should not be null");
1095 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1099 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
Value &,
1100 Value &Inner,
Value *&ReplacementValue) -> InsertPointTy {
1101 ReplacementValue = &Inner;
1105 auto FiniCB = [&](InsertPointTy CodeGenIP) {
return Error::success(); };
1109 auto CreateSequentialRegion = [&](
Function *OuterFn,
1117 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1121 SplitBlock(ParentBB, SeqStartI, DT, LI,
nullptr,
"seq.par.merged");
1124 "Expected a different CFG");
1128 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1129 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1131 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1132 assert(SeqStartBB !=
nullptr &&
"SeqStartBB should not be null");
1134 assert(SeqEndBB !=
nullptr &&
"SeqEndBB should not be null");
1138 auto FiniCB = [&](InsertPointTy CodeGenIP) {
return Error::success(); };
1144 for (
User *Usr :
I.users()) {
1152 OutsideUsers.
insert(&UsrI);
1155 if (OutsideUsers.
empty())
1162 I.getType(),
DL.getAllocaAddrSpace(),
nullptr,
1163 I.getName() +
".seq.output.alloc", OuterFn->
front().
begin());
1167 new StoreInst(&
I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1173 I.getName() +
".seq.output.load",
1180 InsertPointTy(ParentBB, ParentBB->
end()),
DL);
1182 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1183 assert(SeqAfterIP &&
"Unexpected error creating master");
1186 OMPInfoCache.OMPBuilder.createBarrier(*SeqAfterIP, OMPD_parallel);
1187 assert(BarrierAfterIP &&
"Unexpected error creating barrier");
1206 assert(MergableCIs.
size() > 1 &&
"Assumed multiple mergable CIs");
1209 OR <<
"Parallel region merged with parallel region"
1210 << (MergableCIs.
size() > 2 ?
"s" :
"") <<
" at ";
1213 if (CI != MergableCIs.
back())
1219 emitRemark<OptimizationRemark>(MergableCIs.
front(),
"OMP150",
Remark);
1223 <<
" parallel regions in " << OriginalFn->
getName()
1227 EndBB =
SplitBlock(BB, MergableCIs.
back()->getNextNode(), DT, LI);
1229 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1233 assert(BB->getUniqueSuccessor() == StartBB &&
"Expected a different CFG");
1234 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1239 for (
auto *It = MergableCIs.
begin(), *
End = MergableCIs.
end() - 1;
1248 CreateSequentialRegion(OriginalFn, BB, ForkCI->
getNextNode(),
1260 OMPInfoCache.OMPBuilder.createParallel(
1261 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB,
nullptr,
nullptr,
1262 OMP_PROC_BIND_default,
false);
1263 assert(AfterIP &&
"Unexpected error creating parallel");
1267 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1274 for (
auto *CI : MergableCIs) {
1276 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1280 for (
unsigned U = CallbackFirstArgOperand, E = CI->
arg_size(); U < E;
1290 for (
unsigned U = CallbackFirstArgOperand, E = CI->
arg_size(); U < E;
1294 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1297 if (CI != MergableCIs.back()) {
1301 OMPInfoCache.OMPBuilder.createBarrier(
1305 assert(AfterIP &&
"Unexpected error creating barrier");
1311 assert(OutlinedFn != OriginalFn &&
"Outlining failed");
1312 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1313 CGUpdater.reanalyzeFunction(*OriginalFn);
1315 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1323 CallInst *CI = getCallIfRegularCall(U, &RFI);
1330 RFI.foreachUse(SCC, DetectPRsCB);
1336 for (
auto &It : BB2PRMap) {
1337 auto &CIs = It.getSecond();
1352 auto IsMergable = [&](
Instruction &
I,
bool IsBeforeMergableRegion) {
1355 if (
I.isTerminator())
1358 if (!isa<CallInst>(&
I))
1362 if (IsBeforeMergableRegion) {
1364 if (!CalledFunction)
1371 for (
const auto &RFI : UnmergableCallsInfo) {
1372 if (CalledFunction == RFI.Declaration)
1380 if (!isa<IntrinsicInst>(CI))
1391 if (CIs.count(&
I)) {
1397 if (IsMergable(
I, MergableCIs.
empty()))
1402 for (; It !=
End; ++It) {
1404 if (CIs.count(&SkipI)) {
1406 <<
" due to " <<
I <<
"\n");
1413 if (MergableCIs.
size() > 1) {
1414 MergableCIsVector.
push_back(MergableCIs);
1416 <<
" parallel regions in block " << BB->
getName()
1421 MergableCIs.
clear();
1424 if (!MergableCIsVector.
empty()) {
1427 for (
auto &MergableCIs : MergableCIsVector)
1428 Merge(MergableCIs, BB);
1429 MergableCIsVector.clear();
1436 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1437 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1438 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1439 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
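// A source level view of what mergeParallelRegions does (illustrative only;
// the function names are made up):
//
//   #pragma omp parallel   // region 1
//   { work1(); }
//   seq();                 // sequential code between the regions
//   #pragma omp parallel   // region 2
//   { work2(); }
//
// is turned into a single fork into one merged region, in which the original
// sequential part is guarded so only one thread executes it and all threads
// synchronize before continuing:
//
//   #pragma omp parallel
//   {
//     work1();
//   #pragma omp master
//     seq();
//   #pragma omp barrier
//     work2();
//   }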
  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing parallel region with no side-effects.";
      };
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);
    return Changed;
  }
  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }
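// The effect of the deduplication above, sketched on simplified IR (value
// names invented for the example):
//
//   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
//   ...
//   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
//   use(%tid1)
//
// becomes
//
//   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
//   ...
//   use(%tid0)
//
// i.e. later calls to a deduplicable, side-effect free runtime function are
// replaced by the first (possibly moved) call, or by a global thread ID
// argument if one is already available in the caller.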
1536 bool removeRuntimeSymbols() {
1542 if (
GV->getNumUses() >= 1)
1546 GV->eraseFromParent();
  /// Tries to hide the latency of host to device memory transfers by splitting
  /// __tgt_target_data_begin_mapper calls into an "issue" that starts the
  /// transfer asynchronously and a "wait" on the returned handle that is moved
  /// downwards as far as possible.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if the issue can also be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }
1589 void analysisGlobalization() {
1590 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1592 auto CheckGlobalization = [&](
Use &
U,
Function &Decl) {
1593 if (
CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1596 <<
"Found thread data sharing on the GPU. "
1597 <<
"Expect degraded performance due to data globalization.";
1599 emitRemark<OptimizationRemarkMissed>(CI,
"OMP112",
Remark);
1605 RFI.foreachUse(SCC, CheckGlobalization);
1610 bool getValuesInOffloadArrays(
CallInst &RuntimeCall,
1612 assert(OAs.
size() == 3 &&
"Need space for three offload arrays!");
1622 Value *BasePtrsArg =
1631 if (!isa<AllocaInst>(V))
1633 auto *BasePtrsArray = cast<AllocaInst>(V);
1634 if (!OAs[0].
initialize(*BasePtrsArray, RuntimeCall))
1639 if (!isa<AllocaInst>(V))
1641 auto *PtrsArray = cast<AllocaInst>(V);
1642 if (!OAs[1].
initialize(*PtrsArray, RuntimeCall))
1648 if (isa<GlobalValue>(V))
1649 return isa<Constant>(V);
1650 if (!isa<AllocaInst>(V))
1653 auto *SizesArray = cast<AllocaInst>(V);
1654 if (!OAs[2].
initialize(*SizesArray, RuntimeCall))
1665 assert(OAs.
size() == 3 &&
"There are three offload arrays to debug!");
1668 std::string ValuesStr;
1670 std::string Separator =
" --- ";
1672 for (
auto *BP : OAs[0].StoredValues) {
1676 LLVM_DEBUG(
dbgs() <<
"\t\toffload_baseptrs: " << ValuesStr <<
"\n");
1679 for (
auto *
P : OAs[1].StoredValues) {
1686 for (
auto *S : OAs[2].StoredValues) {
1690 LLVM_DEBUG(
dbgs() <<
"\t\toffload_sizes: " << ValuesStr <<
"\n");
1700 bool IsWorthIt =
false;
1719 return RuntimeCall.
getParent()->getTerminator();
1723 bool splitTargetDataBeginRTC(
CallInst &RuntimeCall,
1728 auto &
IRBuilder = OMPInfoCache.OMPBuilder;
1732 Entry.getFirstNonPHIOrDbgOrAlloca());
1734 IRBuilder.AsyncInfo,
nullptr,
"handle");
1742 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1746 for (
auto &Arg : RuntimeCall.
args())
1747 Args.push_back(Arg.get());
1748 Args.push_back(Handle);
1752 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1758 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1760 Value *WaitParams[2] = {
1762 OffloadArray::DeviceIDArgNum),
1766 WaitDecl, WaitParams,
"", WaitMovementPoint.
getIterator());
1767 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
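// Schematically (simplified IR, names invented for the example), the split
// performed by splitTargetDataBeginRTC turns
//
//   call void @__tgt_target_data_begin_mapper(...)
//   ... unrelated code ...
//   use of the mapped data
//
// into
//
//   %handle = alloca <async info struct>
//   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
//   ... unrelated code ...
//   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)
//   use of the mapped data
//
// so the memory transfer can overlap with the code between the issue call and
// the wait movement point.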
1772 static Value *combinedIdentStruct(
Value *CurrentIdent,
Value *NextIdent,
1773 bool GlobalOnly,
bool &SingleChoice) {
1774 if (CurrentIdent == NextIdent)
1775 return CurrentIdent;
1779 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1780 SingleChoice = !CurrentIdent;
1792 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1794 bool SingleChoice =
true;
1795 Value *Ident =
nullptr;
1797 CallInst *CI = getCallIfRegularCall(U, &RFI);
1798 if (!CI || &
F != &Caller)
1801 true, SingleChoice);
1804 RFI.foreachUse(SCC, CombineIdentStruct);
1806 if (!Ident || !SingleChoice) {
1809 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1811 &
F.getEntryBlock(),
F.getEntryBlock().begin()));
1816 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1817 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1824 bool deduplicateRuntimeCalls(
Function &
F,
1825 OMPInformationCache::RuntimeFunctionInfo &RFI,
1826 Value *ReplVal =
nullptr) {
1827 auto *UV = RFI.getUseVector(
F);
1828 if (!UV || UV->size() + (ReplVal !=
nullptr) < 2)
1832 dbgs() <<
TAG <<
"Deduplicate " << UV->size() <<
" uses of " << RFI.Name
1833 << (ReplVal ?
" with an existing value\n" :
"\n") <<
"\n");
1835 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1836 cast<Argument>(ReplVal)->
getParent() == &
F)) &&
1837 "Unexpected replacement value!");
1840 auto CanBeMoved = [
this](
CallBase &CB) {
1841 unsigned NumArgs = CB.arg_size();
1844 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1846 for (
unsigned U = 1;
U < NumArgs; ++
U)
1847 if (isa<Instruction>(CB.getArgOperand(U)))
1858 for (
Use *U : *UV) {
1859 if (
CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1864 if (!CanBeMoved(*CI))
1872 assert(IP &&
"Expected insertion point!");
1873 cast<Instruction>(ReplVal)->moveBefore(IP);
1879 if (
CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1882 Value *Ident = getCombinedIdentFromCallUsesIn(RFI,
F,
1888 bool Changed =
false;
1890 CallInst *CI = getCallIfRegularCall(U, &RFI);
1891 if (!CI || CI == ReplVal || &
F != &Caller)
1896 return OR <<
"OpenMP runtime call "
1897 <<
ore::NV(
"OpenMPOptRuntime", RFI.Name) <<
" deduplicated.";
1900 emitRemark<OptimizationRemark>(CI,
"OMP170",
Remark);
1902 emitRemark<OptimizationRemark>(&
F,
"OMP170",
Remark);
1906 ++NumOpenMPRuntimeCallsDeduplicated;
1910 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1924 if (!
F.hasLocalLinkage())
1926 for (
Use &U :
F.uses()) {
1927 if (
CallInst *CI = getCallIfRegularCall(U)) {
1929 if (CI == &RefCI || GTIdArgs.
count(ArgOp) ||
1930 getCallIfRegularCall(
1931 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1940 auto AddUserArgs = [&](
Value >Id) {
1941 for (
Use &U : GTId.uses())
1942 if (
CallInst *CI = dyn_cast<CallInst>(
U.getUser()))
1945 if (CallArgOpIsGTId(*Callee,
U.getOperandNo(), *CI))
1950 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1951 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1953 GlobThreadNumRFI.foreachUse(SCC, [&](
Use &U,
Function &
F) {
1954 if (
CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1962 for (
unsigned U = 0;
U < GTIdArgs.
size(); ++
U)
1963 AddUserArgs(*GTIdArgs[U]);
1978 return getUniqueKernelFor(*
I.getFunction());
1983 bool rewriteDeviceCodeStateMachine();
1999 template <
typename RemarkKind,
typename RemarkCallBack>
2001 RemarkCallBack &&RemarkCB)
const {
2003 auto &ORE = OREGetter(
F);
2007 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I))
2008 <<
" [" << RemarkName <<
"]";
2012 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
I)); });
2016 template <
typename RemarkKind,
typename RemarkCallBack>
2018 RemarkCallBack &&RemarkCB)
const {
2019 auto &ORE = OREGetter(
F);
2023 return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F))
2024 <<
" [" << RemarkName <<
"]";
2028 [&]() {
return RemarkCB(RemarkKind(
DEBUG_TYPE, RemarkName,
F)); });
2042 OptimizationRemarkGetter OREGetter;
2045 OMPInformationCache &OMPInfoCache;
2051 bool runAttributor(
bool IsModulePass) {
2055 registerAAs(IsModulePass);
2060 <<
" functions, result: " << Changed <<
".\n");
2062 if (Changed == ChangeStatus::CHANGED)
2063 OMPInfoCache.invalidateAnalyses();
2065 return Changed == ChangeStatus::CHANGED;
2072 void registerAAs(
bool IsModulePass);
2081 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2082 !OMPInfoCache.CGSCC->contains(&
F))
2087 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&
F];
2089 return *CachedKernel;
2096 return *CachedKernel;
2099 CachedKernel =
nullptr;
2100 if (!
F.hasLocalLinkage()) {
2104 return ORA <<
"Potentially unknown OpenMP target region caller.";
2106 emitRemark<OptimizationRemarkAnalysis>(&
F,
"OMP100",
Remark);
2112 auto GetUniqueKernelForUse = [&](
const Use &
U) ->
Kernel {
2113 if (
auto *Cmp = dyn_cast<ICmpInst>(
U.getUser())) {
2115 if (
Cmp->isEquality())
2116 return getUniqueKernelFor(*Cmp);
2119 if (
auto *CB = dyn_cast<CallBase>(
U.getUser())) {
2121 if (CB->isCallee(&U))
2122 return getUniqueKernelFor(*CB);
2124 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2125 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2127 if (OpenMPOpt::getCallIfRegularCall(*
U.getUser(), &KernelParallelRFI))
2128 return getUniqueKernelFor(*CB);
2137 OMPInformationCache::foreachUse(
F, [&](
const Use &U) {
2138 PotentialKernels.
insert(GetUniqueKernelForUse(U));
2142 if (PotentialKernels.
size() == 1)
2143 K = *PotentialKernels.
begin();
2146 UniqueKernelMap[&
F] =
K;
2151bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2152 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2153 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2155 bool Changed =
false;
2156 if (!KernelParallelRFI)
2167 bool UnknownUse =
false;
2168 bool KernelParallelUse =
false;
2169 unsigned NumDirectCalls = 0;
2172 OMPInformationCache::foreachUse(*
F, [&](
Use &U) {
2173 if (
auto *CB = dyn_cast<CallBase>(
U.getUser()))
2174 if (CB->isCallee(&U)) {
2179 if (isa<ICmpInst>(
U.getUser())) {
2180 ToBeReplacedStateMachineUses.push_back(&U);
2186 OpenMPOpt::getCallIfRegularCall(*
U.getUser(), &KernelParallelRFI);
2187 const unsigned int WrapperFunctionArgNo = 6;
2188 if (!KernelParallelUse && CI &&
2190 KernelParallelUse = true;
2191 ToBeReplacedStateMachineUses.push_back(&U);
2199 if (!KernelParallelUse)
2205 if (UnknownUse || NumDirectCalls != 1 ||
2206 ToBeReplacedStateMachineUses.
size() > 2) {
2208 return ORA <<
"Parallel region is used in "
2209 << (UnknownUse ?
"unknown" :
"unexpected")
2210 <<
" ways. Will not attempt to rewrite the state machine.";
2212 emitRemark<OptimizationRemarkAnalysis>(
F,
"OMP101",
Remark);
2221 return ORA <<
"Parallel region is not called from a unique kernel. "
2222 "Will not attempt to rewrite the state machine.";
2224 emitRemark<OptimizationRemarkAnalysis>(
F,
"OMP102",
Remark);
2240 for (
Use *U : ToBeReplacedStateMachineUses)
2242 ID,
U->get()->getType()));
2244 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
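// What the rewrite above achieves, in rough terms (illustrative, simplified):
// in generic-mode the worker threads of a kernel run a state machine that
// receives the parallel region to execute as a function pointer,
//
//   call void @__kmpc_parallel_51(..., ptr @__omp_outlined__wrapper, ...)
//
// and must be prepared to call any such pointer indirectly. Once the pass has
// proven the region is only reachable from a single kernel and only used in
// the expected ways, the function-pointer uses are replaced by a dedicated
// global that acts as an ID, so the kernel's customized state machine can
// compare against that ID and call the region directly instead of performing
// an indirect call.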
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Returns true if the value is assumed to be tracked.
  bool isAssumedTracked() const { return getAssumed(); }

  /// Returns true if the value is known to be tracked.
  bool isKnownTracked() const { return getAssumed(); }

  /// Create an abstract attribute view for the position \p IRP.
  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);

  /// Return the value with which \p I can be replaced for the ICV \p ICV.
  virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
                                                     const Instruction *I,
                                                     Attributor &A) const {
    return std::nullopt;
  }

  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return a nullptr. If it is not clear yet, return
  /// std::nullopt.
  virtual std::optional<Value *>
  getUniqueReplacementValue(InternalControlVar ICV) const = 0;

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAICVTracker"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// Unique ID (due to the unique address).
  static const char ID;
};
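// The ICV tracking attributes below answer questions of the form (the example
// is illustrative): given
//
//   omp_set_num_threads(4);      // setter of the nthreads-var ICV
//   ...                          // no call that may change the ICV
//   n = omp_get_max_threads();   // getter of the same ICV
//
// can the getter call be folded to the value 4? AAICVTrackerFunction records
// the values written by setter calls per instruction, and the returned and
// call-site variants propagate that information across calls so getter calls
// with a unique reaching value can be replaced.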
2297struct AAICVTrackerFunction :
public AAICVTracker {
2299 : AAICVTracker(IRP,
A) {}
2302 const std::string getAsStr(
Attributor *)
const override {
2303 return "ICVTrackerFunction";
2307 void trackStatistics()
const override {}
2311 return ChangeStatus::UNCHANGED;
2316 InternalControlVar::ICV___last>
2317 ICVReplacementValuesMap;
2324 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2327 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2329 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2331 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2337 if (ValuesMap.insert(std::make_pair(CI, CI->
getArgOperand(0))).second)
2338 HasChanged = ChangeStatus::CHANGED;
2344 std::optional<Value *> ReplVal = getValueForCall(
A,
I, ICV);
2345 if (ReplVal && ValuesMap.insert(std::make_pair(&
I, *ReplVal)).second)
2346 HasChanged = ChangeStatus::CHANGED;
2352 SetterRFI.foreachUse(TrackValues,
F);
2354 bool UsedAssumedInformation =
false;
2355 A.checkForAllInstructions(CallCheck, *
this, {Instruction::Call},
2356 UsedAssumedInformation,
2362 if (HasChanged == ChangeStatus::CHANGED)
2363 ValuesMap.try_emplace(Entry);
2374 const auto *CB = dyn_cast<CallBase>(&
I);
2375 if (!CB || CB->hasFnAttr(
"no_openmp") ||
2376 CB->hasFnAttr(
"no_openmp_routines"))
2377 return std::nullopt;
2379 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2380 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2381 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2382 Function *CalledFunction = CB->getCalledFunction();
2385 if (CalledFunction ==
nullptr)
2387 if (CalledFunction == GetterRFI.Declaration)
2388 return std::nullopt;
2389 if (CalledFunction == SetterRFI.Declaration) {
2390 if (ICVReplacementValuesMap[ICV].
count(&
I))
2391 return ICVReplacementValuesMap[ICV].
lookup(&
I);
2400 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2403 if (ICVTrackingAA->isAssumedTracked()) {
2404 std::optional<Value *> URV =
2405 ICVTrackingAA->getUniqueReplacementValue(ICV);
2416 std::optional<Value *>
2418 return std::nullopt;
2425 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2426 if (ValuesMap.count(
I))
2427 return ValuesMap.lookup(
I);
2433 std::optional<Value *> ReplVal;
2435 while (!Worklist.
empty()) {
2437 if (!Visited.
insert(CurrInst).second)
2445 if (ValuesMap.count(CurrInst)) {
2446 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2449 ReplVal = NewReplVal;
2455 if (ReplVal != NewReplVal)
2461 std::optional<Value *> NewReplVal = getValueForCall(
A, *CurrInst, ICV);
2467 ReplVal = NewReplVal;
2473 if (ReplVal != NewReplVal)
2478 if (CurrBB ==
I->getParent() && ReplVal)
2483 if (
const Instruction *Terminator = Pred->getTerminator())
2491struct AAICVTrackerFunctionReturned : AAICVTracker {
2493 : AAICVTracker(IRP,
A) {}
2496 const std::string getAsStr(
Attributor *)
const override {
2497 return "ICVTrackerFunctionReturned";
2501 void trackStatistics()
const override {}
2505 return ChangeStatus::UNCHANGED;
2510 InternalControlVar::ICV___last>
2511 ICVReplacementValuesMap;
2514 std::optional<Value *>
2516 return ICVReplacementValuesMap[ICV];
2521 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2524 if (!ICVTrackingAA->isAssumedTracked())
2525 return indicatePessimisticFixpoint();
2528 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2529 std::optional<Value *> UniqueICVValue;
2532 std::optional<Value *> NewReplVal =
2533 ICVTrackingAA->getReplacementValue(ICV, &
I,
A);
2536 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2539 UniqueICVValue = NewReplVal;
2544 bool UsedAssumedInformation =
false;
2545 if (!
A.checkForAllInstructions(CheckReturnInst, *
this, {Instruction::Ret},
2546 UsedAssumedInformation,
2548 UniqueICVValue =
nullptr;
2550 if (UniqueICVValue == ReplVal)
2553 ReplVal = UniqueICVValue;
2554 Changed = ChangeStatus::CHANGED;
2561struct AAICVTrackerCallSite : AAICVTracker {
2563 : AAICVTracker(IRP,
A) {}
2566 assert(getAnchorScope() &&
"Expected anchor function");
2570 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2572 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2573 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2574 if (Getter.Declaration == getAssociatedFunction()) {
2575 AssociatedICV = ICVInfo.Kind;
2581 indicatePessimisticFixpoint();
2585 if (!ReplVal || !*ReplVal)
2586 return ChangeStatus::UNCHANGED;
2589 A.deleteAfterManifest(*getCtxI());
2591 return ChangeStatus::CHANGED;
2595 const std::string getAsStr(
Attributor *)
const override {
2596 return "ICVTrackerCallSite";
2600 void trackStatistics()
const override {}
2603 std::optional<Value *> ReplVal;
2606 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2610 if (!ICVTrackingAA->isAssumedTracked())
2611 return indicatePessimisticFixpoint();
2613 std::optional<Value *> NewReplVal =
2614 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(),
A);
2616 if (ReplVal == NewReplVal)
2617 return ChangeStatus::UNCHANGED;
2619 ReplVal = NewReplVal;
2620 return ChangeStatus::CHANGED;
2625 std::optional<Value *>
2631struct AAICVTrackerCallSiteReturned : AAICVTracker {
2633 : AAICVTracker(IRP,
A) {}
2636 const std::string getAsStr(
Attributor *)
const override {
2637 return "ICVTrackerCallSiteReturned";
2641 void trackStatistics()
const override {}
2645 return ChangeStatus::UNCHANGED;
2650 InternalControlVar::ICV___last>
2651 ICVReplacementValuesMap;
2655 std::optional<Value *>
2657 return ICVReplacementValuesMap[ICV];
2662 const auto *ICVTrackingAA =
A.getAAFor<AAICVTracker>(
2664 DepClassTy::REQUIRED);
2667 if (!ICVTrackingAA->isAssumedTracked())
2668 return indicatePessimisticFixpoint();
2671 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2672 std::optional<Value *> NewReplVal =
2673 ICVTrackingAA->getUniqueReplacementValue(ICV);
2675 if (ReplVal == NewReplVal)
2678 ReplVal = NewReplVal;
2679 Changed = ChangeStatus::CHANGED;
static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
  if (succ_empty(BB))
    return true;
  const BasicBlock *const Successor = BB->getUniqueSuccessor();
  if (!Successor)
    return false;
  return hasFunctionEndAsUniqueSuccessor(Successor);
}
2700 ~AAExecutionDomainFunction() {
delete RPOT; }
2704 assert(
F &&
"Expected anchor function");
2709 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2710 for (
auto &It : BEDMap) {
2714 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2715 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2716 It.getSecond().IsReachingAlignedBarrierOnly;
2718 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) +
"/" +
2719 std::to_string(AlignedBlocks) +
" of " +
2720 std::to_string(TotalBlocks) +
2721 " executed by initial thread / aligned";
2733 << BB.
getName() <<
" is executed by a single thread.\n";
2743 auto HandleAlignedBarrier = [&](
CallBase *CB) {
2744 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[
nullptr];
2745 if (!ED.IsReachedFromAlignedBarrierOnly ||
2746 ED.EncounteredNonLocalSideEffect)
2748 if (!ED.EncounteredAssumes.empty() && !
A.isModulePass())
2759 DeletedBarriers.
insert(CB);
2760 A.deleteAfterManifest(*CB);
2761 ++NumBarriersEliminated;
2763 }
else if (!ED.AlignedBarriers.empty()) {
2766 ED.AlignedBarriers.end());
2768 while (!Worklist.
empty()) {
2770 if (!Visited.
insert(LastCB))
2774 if (!hasFunctionEndAsUniqueSuccessor(LastCB->
getParent()))
2776 if (!DeletedBarriers.
count(LastCB)) {
2777 ++NumBarriersEliminated;
2778 A.deleteAfterManifest(*LastCB);
2784 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2785 Worklist.
append(LastED.AlignedBarriers.begin(),
2786 LastED.AlignedBarriers.end());
2792 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2793 for (
auto *AssumeCB : ED.EncounteredAssumes)
2794 A.deleteAfterManifest(*AssumeCB);
2797 for (
auto *CB : AlignedBarriers)
2798 HandleAlignedBarrier(CB);
2802 HandleAlignedBarrier(
nullptr);
2814 mergeInPredecessorBarriersAndAssumptions(
Attributor &
A, ExecutionDomainTy &ED,
2815 const ExecutionDomainTy &PredED);
2820 bool mergeInPredecessor(
Attributor &
A, ExecutionDomainTy &ED,
2821 const ExecutionDomainTy &PredED,
2822 bool InitialEdgeOnly =
false);
2825 bool handleCallees(
Attributor &
A, ExecutionDomainTy &EntryBBED);
2835 assert(BB.
getParent() == getAnchorScope() &&
"Block is out of scope!");
2836 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2841 assert(
I.getFunction() == getAnchorScope() &&
2842 "Instruction is out of scope!");
2846 bool ForwardIsOk =
true;
2852 auto *CB = dyn_cast<CallBase>(CurI);
2855 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB)))
2857 const auto &It = CEDMap.find({CB, PRE});
2858 if (It == CEDMap.end())
2860 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2861 ForwardIsOk =
false;
2865 if (!CurI && !BEDMap.lookup(
I.getParent()).IsReachingAlignedBarrierOnly)
2866 ForwardIsOk =
false;
2871 auto *CB = dyn_cast<CallBase>(CurI);
2874 if (CB != &
I && AlignedBarriers.contains(
const_cast<CallBase *
>(CB)))
2876 const auto &It = CEDMap.find({CB, POST});
2877 if (It == CEDMap.end())
2879 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2892 return BEDMap.lookup(
nullptr).IsReachedFromAlignedBarrierOnly;
2894 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2906 "No request should be made against an invalid state!");
2907 return BEDMap.lookup(&BB);
2909 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2912 "No request should be made against an invalid state!");
2913 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2917 "No request should be made against an invalid state!");
2918 return InterProceduralED;
2932 if (!Cmp || !
Cmp->isTrueWhenEqual() || !
Cmp->isEquality())
2940 if (
C->isAllOnesValue()) {
2941 auto *CB = dyn_cast<CallBase>(
Cmp->getOperand(0));
2942 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
2943 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2944 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2950 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2956 if (
auto *
II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2957 if (
II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2961 if (
auto *
II = dyn_cast<IntrinsicInst>(
Cmp->getOperand(0)))
2962 if (
II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2970 ExecutionDomainTy InterProceduralED;
2982 static bool setAndRecord(
bool &R,
bool V) {
2993void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2994 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED) {
2995 for (
auto *EA : PredED.EncounteredAssumes)
2996 ED.addAssumeInst(
A, *EA);
2998 for (
auto *AB : PredED.AlignedBarriers)
2999 ED.addAlignedBarrier(
A, *AB);
3002bool AAExecutionDomainFunction::mergeInPredecessor(
3003 Attributor &
A, ExecutionDomainTy &ED,
const ExecutionDomainTy &PredED,
3004 bool InitialEdgeOnly) {
3006 bool Changed =
false;
3008 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3009 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3010 ED.IsExecutedByInitialThreadOnly));
3012 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3013 ED.IsReachedFromAlignedBarrierOnly &&
3014 PredED.IsReachedFromAlignedBarrierOnly);
3015 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3016 ED.EncounteredNonLocalSideEffect |
3017 PredED.EncounteredNonLocalSideEffect);
3019 if (ED.IsReachedFromAlignedBarrierOnly)
3020 mergeInPredecessorBarriersAndAssumptions(
A, ED, PredED);
3022 ED.clearAssumeInstAndAlignedBarriers();
3026bool AAExecutionDomainFunction::handleCallees(
Attributor &
A,
3027 ExecutionDomainTy &EntryBBED) {
3032 DepClassTy::OPTIONAL);
3033 if (!EDAA || !EDAA->getState().isValidState())
3036 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3040 ExecutionDomainTy ExitED;
3041 bool AllCallSitesKnown;
3042 if (
A.checkForAllCallSites(PredForCallSite, *
this,
3044 AllCallSitesKnown)) {
3045 for (
const auto &[CSInED, CSOutED] : CallSiteEDs) {
3046 mergeInPredecessor(
A, EntryBBED, CSInED);
3047 ExitED.IsReachingAlignedBarrierOnly &=
3048 CSOutED.IsReachingAlignedBarrierOnly;
3055 EntryBBED.IsExecutedByInitialThreadOnly =
false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly =
true;
3057 EntryBBED.EncounteredNonLocalSideEffect =
false;
3058 ExitED.IsReachingAlignedBarrierOnly =
false;
3060 EntryBBED.IsExecutedByInitialThreadOnly =
false;
3061 EntryBBED.IsReachedFromAlignedBarrierOnly =
false;
3062 EntryBBED.EncounteredNonLocalSideEffect =
true;
3063 ExitED.IsReachingAlignedBarrierOnly =
false;
3067 bool Changed =
false;
3068 auto &FnED = BEDMap[
nullptr];
3069 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3070 FnED.IsReachedFromAlignedBarrierOnly &
3071 EntryBBED.IsReachedFromAlignedBarrierOnly);
3072 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3073 FnED.IsReachingAlignedBarrierOnly &
3074 ExitED.IsReachingAlignedBarrierOnly);
3075 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3076 EntryBBED.IsExecutedByInitialThreadOnly);
3082 bool Changed =
false;
3087 auto HandleAlignedBarrier = [&](
CallBase &CB, ExecutionDomainTy &ED) {
3088 Changed |= AlignedBarriers.insert(&CB);
3090 auto &CallInED = CEDMap[{&CB, PRE}];
3091 Changed |= mergeInPredecessor(
A, CallInED, ED);
3092 CallInED.IsReachingAlignedBarrierOnly =
true;
3094 ED.EncounteredNonLocalSideEffect =
false;
3095 ED.IsReachedFromAlignedBarrierOnly =
true;
3097 ED.clearAssumeInstAndAlignedBarriers();
3098 ED.addAlignedBarrier(
A, CB);
3099 auto &CallOutED = CEDMap[{&CB, POST}];
3100 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3104 A.getAAFor<
AAIsDead>(*
this, getIRPosition(), DepClassTy::OPTIONAL);
3111 for (
auto &RIt : *RPOT) {
3114 bool IsEntryBB = &BB == &EntryBB;
3117 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3118 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3119 ExecutionDomainTy ED;
3122 Changed |= handleCallees(
A, ED);
3126 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3130 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3132 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3133 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3134 mergeInPredecessor(
A, ED, BEDMap[PredBB], InitialEdgeOnly);
3141 bool UsedAssumedInformation;
3142 if (
A.isAssumedDead(
I, *
this, LivenessAA, UsedAssumedInformation,
3143 false, DepClassTy::OPTIONAL,
3149 if (
auto *
II = dyn_cast<IntrinsicInst>(&
I)) {
3150 if (
auto *AI = dyn_cast_or_null<AssumeInst>(
II)) {
3151 ED.addAssumeInst(
A, *AI);
3155 if (
II->isAssumeLikeIntrinsic())
3159 if (
auto *FI = dyn_cast<FenceInst>(&
I)) {
3160 if (!ED.EncounteredNonLocalSideEffect) {
3162 if (ED.IsReachedFromAlignedBarrierOnly)
3167 case AtomicOrdering::NotAtomic:
3169 case AtomicOrdering::Unordered:
3171 case AtomicOrdering::Monotonic:
3173 case AtomicOrdering::Acquire:
3175 case AtomicOrdering::Release:
3177 case AtomicOrdering::AcquireRelease:
3179 case AtomicOrdering::SequentiallyConsistent:
3183 NonNoOpFences.insert(FI);
3186 auto *CB = dyn_cast<CallBase>(&
I);
3188 bool IsAlignedBarrier =
3192 AlignedBarrierLastInBlock &= IsNoSync;
3193 IsExplicitlyAligned &= IsNoSync;
3199 if (IsAlignedBarrier) {
3200 HandleAlignedBarrier(*CB, ED);
3201 AlignedBarrierLastInBlock =
true;
3202 IsExplicitlyAligned =
true;
3207 if (isa<MemIntrinsic>(&
I)) {
3208 if (!ED.EncounteredNonLocalSideEffect &&
3210 ED.EncounteredNonLocalSideEffect =
true;
3212 ED.IsReachedFromAlignedBarrierOnly =
false;
3220 auto &CallInED = CEDMap[{CB, PRE}];
3221 Changed |= mergeInPredecessor(
A, CallInED, ED);
3227 if (!IsNoSync && Callee && !
Callee->isDeclaration()) {
3230 if (EDAA && EDAA->getState().isValidState()) {
3233 CalleeED.IsReachedFromAlignedBarrierOnly;
3234 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3235 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3236 ED.EncounteredNonLocalSideEffect |=
3237 CalleeED.EncounteredNonLocalSideEffect;
3239 ED.EncounteredNonLocalSideEffect =
3240 CalleeED.EncounteredNonLocalSideEffect;
3241 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3243 setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3246 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3247 mergeInPredecessorBarriersAndAssumptions(
A, ED, CalleeED);
3248 auto &CallOutED = CEDMap[{CB, POST}];
3249 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3254 ED.IsReachedFromAlignedBarrierOnly =
false;
3255 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3258 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3260 auto &CallOutED = CEDMap[{CB, POST}];
3261 Changed |= mergeInPredecessor(
A, CallOutED, ED);
3264 if (!
I.mayHaveSideEffects() && !
I.mayReadFromMemory())
3278 if (MemAA && MemAA->getState().isValidState() &&
3279 MemAA->checkForAllAccessesToMemoryKind(
3284 auto &InfoCache =
A.getInfoCache();
3285 if (!
I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(
I))
3288 if (
auto *LI = dyn_cast<LoadInst>(&
I))
3289 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3292 if (!ED.EncounteredNonLocalSideEffect &&
3294 ED.EncounteredNonLocalSideEffect =
true;
3297 bool IsEndAndNotReachingAlignedBarriersOnly =
false;
3298 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3299 !BB.getTerminator()->getNumSuccessors()) {
3301 Changed |= mergeInPredecessor(
A, InterProceduralED, ED);
3303 auto &FnED = BEDMap[
nullptr];
3304 if (IsKernel && !IsExplicitlyAligned)
3305 FnED.IsReachingAlignedBarrierOnly =
false;
3306 Changed |= mergeInPredecessor(
A, FnED, ED);
3308 if (!FnED.IsReachingAlignedBarrierOnly) {
3309 IsEndAndNotReachingAlignedBarriersOnly =
true;
3310 SyncInstWorklist.
push_back(BB.getTerminator());
3311 auto &BBED = BEDMap[&BB];
3312 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly,
false);
3316 ExecutionDomainTy &StoredED = BEDMap[&BB];
3317 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3318 !IsEndAndNotReachingAlignedBarriersOnly;
3324 if (ED.IsExecutedByInitialThreadOnly !=
3325 StoredED.IsExecutedByInitialThreadOnly ||
3326 ED.IsReachedFromAlignedBarrierOnly !=
3327 StoredED.IsReachedFromAlignedBarrierOnly ||
3328 ED.EncounteredNonLocalSideEffect !=
3329 StoredED.EncounteredNonLocalSideEffect)
3333 StoredED = std::move(ED);
3339 while (!SyncInstWorklist.
empty()) {
3342 bool HitAlignedBarrierOrKnownEnd =
false;
3344 auto *CB = dyn_cast<CallBase>(CurInst);
3347 auto &CallOutED = CEDMap[{CB, POST}];
3348 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly,
false);
3349 auto &CallInED = CEDMap[{CB, PRE}];
3350 HitAlignedBarrierOrKnownEnd =
3351 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3352 if (HitAlignedBarrierOrKnownEnd)
3354 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly,
false);
3356 if (HitAlignedBarrierOrKnownEnd)
3360 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3362 if (!Visited.
insert(PredBB))
3364 auto &PredED = BEDMap[PredBB];
3365 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly,
false)) {
3367 SyncInstWorklist.
push_back(PredBB->getTerminator());
3370 if (SyncBB != &EntryBB)
3373 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly,
false);
3376 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
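// The overall effect of the execution-domain analysis above (illustrative
// device-code example, simplified IR):
//
//   call void @__kmpc_barrier_simple_spmd(...)   ; aligned barrier A
//   %v = add i32 %a, %b                          ; thread-local work only
//   call void @__kmpc_barrier_simple_spmd(...)   ; aligned barrier B
//
// Barrier B is reached only from an aligned barrier, with no intervening
// non-local side effect, so it is redundant and is deleted during manifest
// (NumBarriersEliminated counts these). The same reasoning removes aligned
// barriers whose only successors lead straight to the kernel end.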
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;
  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAHeapToShared &createForPosition(const IRPosition &IRP,
                                           Attributor &A);

  /// Returns true if HeapToShared conversion is assumed to be possible.
  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  /// Returns true if HeapToShared conversion is assumed and the call \p CB is
  /// a callsite to a free operation to be removed.
  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  /// See AbstractAttribute::getName().
  const std::string getName() const override { return "AAHeapToShared"; }

  /// See AbstractAttribute::getIdAddr().
  const char *getIdAddr() const override { return &ID; }

  /// Unique ID (due to the unique address).
  static const char ID;
};
3412struct AAHeapToSharedFunction :
public AAHeapToShared {
3414 : AAHeapToShared(IRP,
A) {}
3416 const std::string getAsStr(
Attributor *)
const override {
3417 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3418 " malloc calls eligible.";
3422 void trackStatistics()
const override {}
3426 void findPotentialRemovedFreeCalls(
Attributor &
A) {
3427 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3428 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3430 PotentialRemovedFreeCalls.clear();
3434 for (
auto *U : CB->
users()) {
3436 if (
C &&
C->getCalledFunction() == FreeRFI.Declaration)
3440 if (FreeCalls.
size() != 1)
3443 PotentialRemovedFreeCalls.insert(FreeCalls.
front());
3449 indicatePessimisticFixpoint();
3453 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3454 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3455 if (!RFI.Declaration)
3460 bool &) -> std::optional<Value *> {
return nullptr; };
3463 for (
User *U : RFI.Declaration->
users())
3464 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3467 MallocCalls.insert(CB);
3472 findPotentialRemovedFreeCalls(
A);
3475 bool isAssumedHeapToShared(
CallBase &CB)
const override {
3476 return isValidState() && MallocCalls.count(&CB);
3479 bool isAssumedHeapToSharedRemovedFree(
CallBase &CB)
const override {
3480 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3484 if (MallocCalls.empty())
3485 return ChangeStatus::UNCHANGED;
3487 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3488 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3492 DepClassTy::OPTIONAL);
3497 if (HS &&
HS->isAssumedHeapToStack(*CB))
3502 for (
auto *U : CB->
users()) {
3504 if (
C &&
C->getCalledFunction() == FreeCall.Declaration)
3507 if (FreeCalls.
size() != 1)
3514 <<
" with shared memory."
3515 <<
" Shared memory usage is limited to "
3521 <<
" with " << AllocSize->getZExtValue()
3522 <<
" bytes of shared memory\n");
3528 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3533 static_cast<unsigned>(AddressSpace::Shared));
3535 SharedMem, PointerType::getUnqual(
M->getContext()));
3538 return OR <<
"Replaced globalized variable with "
3539 <<
ore::NV(
"SharedMemory", AllocSize->getZExtValue())
3540 << (AllocSize->isOne() ?
" byte " :
" bytes ")
3541 <<
"of shared memory.";
3547 "HeapToShared on allocation without alignment attribute");
3548 SharedMem->setAlignment(*Alignment);
3551 A.deleteAfterManifest(*CB);
3552 A.deleteAfterManifest(*FreeCalls.
front());
3554 SharedMemoryUsed += AllocSize->getZExtValue();
3555 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3556 Changed = ChangeStatus::CHANGED;
3563 if (MallocCalls.empty())
3564 return indicatePessimisticFixpoint();
3565 auto &OMPInfoCache =
static_cast<OMPInformationCache &
>(
A.getInfoCache());
3566 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3567 if (!RFI.Declaration)
3568 return ChangeStatus::UNCHANGED;
3572 auto NumMallocCalls = MallocCalls.size();
3575 for (
User *U : RFI.Declaration->
users()) {
3576 if (
CallBase *CB = dyn_cast<CallBase>(U)) {
3577 if (CB->getCaller() !=
F)
3579 if (!MallocCalls.count(CB))
3581 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3582 MallocCalls.remove(CB);
3587 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3588 MallocCalls.remove(CB);
3592 findPotentialRemovedFreeCalls(
A);
3594 if (NumMallocCalls != MallocCalls.size())
3595 return ChangeStatus::CHANGED;
3597 return ChangeStatus::UNCHANGED;
3605 unsigned SharedMemoryUsed = 0;
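// The HeapToShared transformation implemented above, sketched at the IR level
// (simplified, names invented):
//
//   %p = call ptr @__kmpc_alloc_shared(i64 16)
//   ...
//   call void @__kmpc_free_shared(ptr %p, i64 16)
//
// becomes, when the allocation size is constant, the allocation is executed
// only by the initial thread, and there is exactly one matching free:
//
//   @x_shared = internal addrspace(3) global [16 x i8] undef
//
// with the uses of %p rewritten to (an address-space cast of) @x_shared, both
// runtime calls deleted, and SharedMemoryUsed/NumBytesMovedToSharedMemory
// bumped by the allocation size.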
3608struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3614 static bool requiresCalleeForCallBase() { return false; }
3617 void trackStatistics() const override {}
3620 const std::string getAsStr(Attributor *) const override {
3621 if (!isValidState())
3623 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3625 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3627 std::string(" #PRs: ") +
3628 (ReachedKnownParallelRegions.isValidState()
3629 ? std::to_string(ReachedKnownParallelRegions.size())
3631 ", #Unknown PRs: " +
3632 (ReachedUnknownParallelRegions.isValidState()
3635 ", #Reaching Kernels: " +
3636 (ReachingKernelEntries.isValidState()
3640 (ParallelLevels.isValidState()
3643 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3650 const std::string getName() const override { return "AAKernelInfo"; }
3653 const char *getIdAddr() const override { return &ID; }
3660 static const char ID;
3665struct AAKernelInfoFunction : AAKernelInfo {
3667 : AAKernelInfo(IRP, A) {}
3672 return GuardedInstructions;
3675 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3677 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3678 assert(NewKernelEnvC && "Failed to create new kernel environment");
3679 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3682#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3683 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3684 ConstantStruct *ConfigC = \
3685 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3686 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3687 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3688 assert(NewConfigC && "Failed to create new configuration environment"); \
3689 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3700#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
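// initialize() for a kernel entry locates the unique __kmpc_target_init and
// __kmpc_target_deinit calls, registers a simplification callback for the
// kernel environment global, and seeds optimistic assumptions (SPMD
// execution, no generic state machine, no nested parallelism) that later
// updates may have to revert.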
3707 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3711 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3712 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3713 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3714 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3718 auto StoreCallBase = [](Use &U,
3719 OMPInformationCache::RuntimeFunctionInfo &RFI,
3721 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3723 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3725 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3731 StoreCallBase(U, InitRFI, KernelInitCB);
3735 DeinitRFI.foreachUse(
3737 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3743 if (!KernelInitCB || !KernelDeinitCB)
3747 ReachingKernelEntries.insert(Fn);
3748 IsKernelEntry = true;
3756 KernelConfigurationSimplifyCB =
3758 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3759 if (!isAtFixpoint()) {
3762 UsedAssumedInformation = true;
3763 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3768 A.registerGlobalVariableSimplificationCallback(
3769 *KernelEnvGV, KernelConfigurationSimplifyCB);
3772 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3773 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3774 OMPRTL___kmpc_barrier_simple_spmd});
3778 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3783 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3787 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3789 setExecModeOfKernelEnvironment(AssumedExecModeC);
3796 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3798 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3799 auto [MinTeams, MaxTeams] =
3802 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3804 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3807 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3808 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3810 setMayUseNestedParallelismOfKernelEnvironment(
3811 AssumedMayUseNestedParallelismC);
3815 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3818 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3819 setUseGenericStateMachineOfKernelEnvironment(
3820 AssumedUseGenericStateMachineC);
3826 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3828 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3832 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3835 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3849 if (SPMDCompatibilityTracker.isValidState())
3850 return AddDependence(A, this, QueryingAA);
3852 if (!ReachedKnownParallelRegions.isValidState())
3853 return AddDependence(A, this, QueryingAA);
3859 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3860 CustomStateMachineUseCB);
3861 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3862 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3863 CustomStateMachineUseCB);
3864 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3865 CustomStateMachineUseCB);
3866 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3867 CustomStateMachineUseCB);
3871 if (SPMDCompatibilityTracker.isAtFixpoint())
3878 if (!SPMDCompatibilityTracker.isValidState())
3879 return AddDependence(A, this, QueryingAA);
3882 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3891 if (!SPMDCompatibilityTracker.isValidState())
3892 return AddDependence(A, this, QueryingAA);
3893 if (SPMDCompatibilityTracker.empty())
3894 return AddDependence(A, this, QueryingAA);
3895 if (!mayContainParallelRegion())
3896 return AddDependence(A, this, QueryingAA);
3899 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3903 static std::string sanitizeForGlobalName(std::string S) {
3907 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3908 (C >= '0' && C <= '9') || C == '_');
3919 if (!KernelInitCB || !KernelDeinitCB)
3920 return ChangeStatus::UNCHANGED;
3924 bool HasBuiltStateMachine = true;
3925 if (!changeToSPMDMode(A, Changed)) {
3927 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3929 HasBuiltStateMachine = false;
3936 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3937 ExistingKernelEnvC);
3938 if (!HasBuiltStateMachine)
3939 setUseGenericStateMachineOfKernelEnvironment(
3940 OldUseGenericStateMachineVal);
3947 Changed = ChangeStatus::CHANGED;
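// insertInstructionGuardsHelper: when a generic-mode kernel is converted to
// SPMD, instructions that must only run in the sequential part are wrapped in
// guarded regions. Each region is executed by thread 0 only, bracketed by
// simple SPMD barriers, and values that escape the region are broadcast to
// the other threads through shared memory.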
3953 void insertInstructionGuardsHelper(Attributor &A) {
3954 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3956 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3990 DT, LI, MSU, "region.guarded.end");
3993 MSU, "region.barrier");
3996 DT, LI, MSU, "region.exit");
3998 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
4001 "Expected a different CFG");
4004 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4007 A.registerManifestAddedBasicBlock(*RegionEndBB);
4008 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4009 A.registerManifestAddedBasicBlock(*RegionExitBB);
4010 A.registerManifestAddedBasicBlock(*RegionStartBB);
4011 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4013 bool HasBroadcastValues = false;
4018 for (Use &U : I.uses()) {
4024 if (OutsideUses.empty())
4027 HasBroadcastValues = true;
4032 M, I.getType(), false,
4034 sanitizeForGlobalName(
4035 (I.getName() + ".guarded.output.alloc").str()),
4037 static_cast<unsigned>(AddressSpace::Shared));
4044 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4048 for (Use *U : OutsideUses)
4049 A.changeUseAfterManifest(*U, *LoadI);
4052 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4058 InsertPointTy(ParentBB, ParentBB->end()), DL);
4059 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4062 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4064 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4070 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4071 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4073 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4074 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4076 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4078 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4079 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4080 OMPInfoCache.OMPBuilder.Builder
4081 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4087 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4088 M, OMPRTL___kmpc_barrier_simple_spmd);
4089 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4092 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4094 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4097 if (HasBroadcastValues) {
4102 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4106 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4108 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4110 if (!Visited.insert(BB).second)
4116 while (++IP != IPEnd) {
4117 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4120 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4122 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4123 LastEffect = nullptr;
4130 for (auto &Reorder : Reorders)
4136 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4138 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4141 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4142 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4144 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4147 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4151 if (SPMDCompatibilityTracker.contains(&I)) {
4152 CalleeAAFunction.getGuardedInstructions().insert(&I);
4153 if (GuardedRegionStart)
4154 GuardedRegionEnd = &I;
4156 GuardedRegionStart = GuardedRegionEnd = &I;
4163 if (GuardedRegionStart) {
4165 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4166 GuardedRegionStart = nullptr;
4167 GuardedRegionEnd = nullptr;
4172 for (auto &GR : GuardedRegions)
4173 CreateGuardedRegion(GR.first, GR.second);
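// forceSingleThreadPerWorkgroupHelper: used when the SPMD-ized kernel has no
// parallel region to exploit; instead of guarding individual regions, all
// threads except the main one branch straight past the user code, so each
// workgroup effectively runs single-threaded.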
4176 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4185 auto &Ctx = getAnchorValue().getContext();
4192 KernelInitCB->getNextNode(), "main.thread.user_code");
4197 A.registerManifestAddedBasicBlock(*InitBB);
4198 A.registerManifestAddedBasicBlock(*UserCodeBB);
4199 A.registerManifestAddedBasicBlock(*ReturnBB);
4208 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4210 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4211 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4216 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4222 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4223 "thread.is_main", InitBB);
4229 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4231 if (!SPMDCompatibilityTracker.isAssumed()) {
4232 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4233 if (!NonCompatibleI)
4237 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4238 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4242 ORA << "Value has potential side effects preventing SPMD-mode "
4244 if (isa<CallBase>(NonCompatibleI)) {
4245 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4246 "the called function to override";
4254 << *NonCompatibleI << "\n");
4266 Kernel = CB->getCaller();
4274 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4280 Changed = ChangeStatus::CHANGED;
4284 if (mayContainParallelRegion())
4285 insertInstructionGuardsHelper(A);
4287 forceSingleThreadPerWorkgroupHelper(A);
4292 "Initially non-SPMD kernel has SPMD exec mode!");
4293 setExecModeOfKernelEnvironment(
4297 ++NumOpenMPTargetRegionKernelsSPMD;
4300 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4312 if (!ReachedKnownParallelRegions.isValidState())
4315 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4316 if (!OMPInfoCache.runtimeFnsAvailable(
4317 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4318 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4319 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4330 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4331 ExistingKernelEnvC);
4333 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4338 if (UseStateMachineC->isZero() ||
4342 Changed = ChangeStatus::CHANGED;
4345 setUseGenericStateMachineOfKernelEnvironment(
4352 if (!mayContainParallelRegion()) {
4353 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4356 return OR << "Removing unused state machine from generic-mode kernel.";
4364 if (ReachedUnknownParallelRegions.empty()) {
4365 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4368 return OR << "Rewriting generic-mode kernel with a customized state "
4373 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4376 return OR << "Generic-mode kernel is executed with a customized state "
4377 "machine that requires a fallback.";
4382 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4383 if (!UnknownParallelRegionCB)
4386 return ORA << "Call may contain unknown parallel regions. Use "
4387 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4425 auto &Ctx = getAnchorValue().getContext();
4429 BasicBlock *InitBB = KernelInitCB->getParent();
4431 KernelInitCB->getNextNode(), "thread.user_code.check");
4435 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4437 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4439 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4442 Kernel, UserCodeEntryBB);
4445 Kernel, UserCodeEntryBB);
4447 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4448 A.registerManifestAddedBasicBlock(*InitBB);
4449 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4450 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4453 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4454 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4455 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4456 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
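// The generated worker loop: workers wait at a generic barrier, read the next
// work function published by __kmpc_kernel_parallel, run it through an
// if-cascade over the known parallel regions (with an indirect-call fallback
// when unknown regions are possible), signal completion via
// __kmpc_kernel_end_parallel, and loop until the kernel is done.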
4458 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4464 ConstantInt::get(KernelInitCB->getType(), -1),
4465 "thread.is_worker", InitBB);
4471 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4472 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4474 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4475 M, OMPRTL___kmpc_get_warp_size);
4478 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4482 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4485 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4489 "thread.is_main_or_worker", IsWorkerCheckBB);
4492 IsMainOrWorker, IsWorkerCheckBB);
4496 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4498 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4502 OMPInfoCache.OMPBuilder.updateToLocation(
4505 StateMachineBeginBB->end()), DL);
4509 Value *GTid = KernelInitCB;
4512 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4513 M, OMPRTL___kmpc_barrier_simple_generic);
4516 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4520 (unsigned int)AddressSpace::Generic) {
4522 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4523 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4528 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4529 M, OMPRTL___kmpc_kernel_parallel);
4531 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4532 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4535 StateMachineBeginBB);
4545 StateMachineBeginBB);
4546 IsDone->setDebugLoc(DLoc);
4548 IsDone, StateMachineBeginBB)
4552 StateMachineDoneBarrierBB, IsActiveWorker,
4553 StateMachineIsActiveCheckBB)
4559 const unsigned int WrapperFunctionArgNo = 6;
4564 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4565 auto *CB = ReachedKnownParallelRegions[I];
4566 auto *ParallelRegion = dyn_cast<Function>(
4567 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4569 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4570 StateMachineEndParallelBB);
4572 ->setDebugLoc(DLoc);
4578 Kernel, StateMachineEndParallelBB);
4579 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4580 A.registerManifestAddedBasicBlock(*PRNextBB);
4585 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4588 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4596 StateMachineIfCascadeCurrentBB)
4598 StateMachineIfCascadeCurrentBB = PRNextBB;
4604 if (!ReachedUnknownParallelRegions.empty()) {
4605 StateMachineIfCascadeCurrentBB->setName(
4606 "worker_state_machine.parallel_region.fallback.execute");
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4612 StateMachineIfCascadeCurrentBB)
4616 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4617 M, OMPRTL___kmpc_kernel_end_parallel);
4620 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4626 ->setDebugLoc(DLoc);
4636 KernelInfoState StateBefore = getState();
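// updateImpl: the RAII helper below writes the (possibly reverted) kernel
// environment constant back when the update finishes, so invalid or
// non-fixpoint assumptions never leak into the IR.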
4642 struct UpdateKernelEnvCRAII {
4643 AAKernelInfoFunction &AA;
4645 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4647 ~UpdateKernelEnvCRAII() {
4654 if (!AA.isValidState()) {
4655 AA.KernelEnvC = ExistingKernelEnvC;
4659 if (!AA.ReachedKnownParallelRegions.isValidState())
4660 AA.setUseGenericStateMachineOfKernelEnvironment(
4661 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4662 ExistingKernelEnvC));
4664 if (!AA.SPMDCompatibilityTracker.isValidState())
4665 AA.setExecModeOfKernelEnvironment(
4666 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4669 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4671 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4672 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4673 AA.setMayUseNestedParallelismOfKernelEnvironment(
4674 NewMayUseNestedParallelismC);
4681 if (isa<CallBase>(I))
4684 if (!I.mayWriteToMemory())
4686 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4689 DepClassTy::OPTIONAL);
4692 DepClassTy::OPTIONAL);
4693 if (UnderlyingObjsAA &&
4694 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4695 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4699 auto *CB = dyn_cast<CallBase>(&Obj);
4700 return CB && HS && HS->isAssumedHeapToStack(*CB);
4706 SPMDCompatibilityTracker.insert(&I);
4710 bool UsedAssumedInformationInCheckRWInst = false;
4711 if (!SPMDCompatibilityTracker.isAtFixpoint())
4712 if (!A.checkForAllReadWriteInstructions(
4713 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4716 bool UsedAssumedInformationFromReachingKernels = false;
4717 if (!IsKernelEntry) {
4718 updateParallelLevels(A);
4720 bool AllReachingKernelsKnown = true;
4721 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4722 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4724 if (!SPMDCompatibilityTracker.empty()) {
4725 if (!ParallelLevels.isValidState())
4726 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4727 else if (!ReachingKernelEntries.isValidState())
4728 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4734 for (auto *Kernel : ReachingKernelEntries) {
4735 auto *CBAA = A.getAAFor<AAKernelInfo>(
4737 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4738 CBAA->SPMDCompatibilityTracker.isAssumed())
4742 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4743 UsedAssumedInformationFromReachingKernels = true;
4745 if (SPMD != 0 && Generic != 0)
4746 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4752 bool AllParallelRegionStatesWereFixed = true;
4753 bool AllSPMDStatesWereFixed = true;
4755 auto &CB = cast<CallBase>(I);
4756 auto *CBAA = A.getAAFor<AAKernelInfo>(
4760 getState() ^= CBAA->getState();
4761 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4762 AllParallelRegionStatesWereFixed &=
4763 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4764 AllParallelRegionStatesWereFixed &=
4765 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4769 bool UsedAssumedInformationInCheckCallInst = false;
4770 if (!A.checkForAllCallLikeInstructions(
4771 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4773 << "Failed to visit all call-like instructions!\n";);
4774 return indicatePessimisticFixpoint();
4779 if (!UsedAssumedInformationInCheckCallInst &&
4780 AllParallelRegionStatesWereFixed) {
4781 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4782 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4787 if (!UsedAssumedInformationInCheckRWInst &&
4788 !UsedAssumedInformationInCheckCallInst &&
4789 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4790 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4792 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4793 : ChangeStatus::CHANGED;
4799 bool &AllReachingKernelsKnown) {
4803 assert(Caller && "Caller is nullptr");
4805 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4807 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4808 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4814 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 if (!A.checkForAllCallSites(PredCallSite, *this,
4821 AllReachingKernelsKnown))
4822 ReachingKernelEntries.indicatePessimisticFixpoint();
4827 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4828 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4829 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4834 assert(Caller && "Caller is nullptr");
4838 if (CAA && CAA->ParallelLevels.isValidState()) {
4844 if (Caller == Parallel51RFI.Declaration) {
4845 ParallelLevels.indicatePessimisticFixpoint();
4849 ParallelLevels ^= CAA->ParallelLevels;
4856 ParallelLevels.indicatePessimisticFixpoint();
4861 bool AllCallSitesKnown = true;
4862 if (!A.checkForAllCallSites(PredCallSite, *this,
4865 ParallelLevels.indicatePessimisticFixpoint();
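// AAKernelInfoCallSite: the call-site variant classifies calls to known
// OpenMP runtime functions (target init/deinit, __kmpc_parallel_51,
// allocation helpers, tasking, ...) and folds their effect into the caller's
// KernelInfoState; calls to unknown or non-IPO-amendable callees are treated
// pessimistically.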
4872struct AAKernelInfoCallSite : AAKernelInfo {
4874 : AAKernelInfo(IRP, A) {}
4878 AAKernelInfo::initialize(A);
4880 CallBase &CB = cast<CallBase>(getAssociatedValue());
4885 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4886 indicateOptimisticFixpoint();
4894 indicateOptimisticFixpoint();
4903 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4904 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4905 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4907 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4911 if (!AssumptionAA ||
4912 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4913 AssumptionAA->hasAssumption("omp_no_parallelism")))
4914 ReachedUnknownParallelRegions.insert(&CB);
4918 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4919 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4920 SPMDCompatibilityTracker.insert(&CB);
4925 indicateOptimisticFixpoint();
4931 if (NumCallees > 1) {
4932 indicatePessimisticFixpoint();
4939 case OMPRTL___kmpc_is_spmd_exec_mode:
4940 case OMPRTL___kmpc_distribute_static_fini:
4941 case OMPRTL___kmpc_for_static_fini:
4942 case OMPRTL___kmpc_global_thread_num:
4943 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4944 case OMPRTL___kmpc_get_hardware_num_blocks:
4945 case OMPRTL___kmpc_single:
4946 case OMPRTL___kmpc_end_single:
4947 case OMPRTL___kmpc_master:
4948 case OMPRTL___kmpc_end_master:
4949 case OMPRTL___kmpc_barrier:
4950 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4951 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4952 case OMPRTL___kmpc_error:
4953 case OMPRTL___kmpc_flush:
4954 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4955 case OMPRTL___kmpc_get_warp_size:
4956 case OMPRTL_omp_get_thread_num:
4957 case OMPRTL_omp_get_num_threads:
4958 case OMPRTL_omp_get_max_threads:
4959 case OMPRTL_omp_in_parallel:
4960 case OMPRTL_omp_get_dynamic:
4961 case OMPRTL_omp_get_cancellation:
4962 case OMPRTL_omp_get_nested:
4963 case OMPRTL_omp_get_schedule:
4964 case OMPRTL_omp_get_thread_limit:
4965 case OMPRTL_omp_get_supported_active_levels:
4966 case OMPRTL_omp_get_max_active_levels:
4967 case OMPRTL_omp_get_level:
4968 case OMPRTL_omp_get_ancestor_thread_num:
4969 case OMPRTL_omp_get_team_size:
4970 case OMPRTL_omp_get_active_level:
4971 case OMPRTL_omp_in_final:
4972 case OMPRTL_omp_get_proc_bind:
4973 case OMPRTL_omp_get_num_places:
4974 case OMPRTL_omp_get_num_procs:
4975 case OMPRTL_omp_get_place_proc_ids:
4976 case OMPRTL_omp_get_place_num:
4977 case OMPRTL_omp_get_partition_num_places:
4978 case OMPRTL_omp_get_partition_place_nums:
4979 case OMPRTL_omp_get_wtime:
4981 case OMPRTL___kmpc_distribute_static_init_4:
4982 case OMPRTL___kmpc_distribute_static_init_4u:
4983 case OMPRTL___kmpc_distribute_static_init_8:
4984 case OMPRTL___kmpc_distribute_static_init_8u:
4985 case OMPRTL___kmpc_for_static_init_4:
4986 case OMPRTL___kmpc_for_static_init_4u:
4987 case OMPRTL___kmpc_for_static_init_8:
4988 case OMPRTL___kmpc_for_static_init_8u: {
4990 unsigned ScheduleArgOpNo = 2;
4991 auto *ScheduleTypeCI =
4993 unsigned ScheduleTypeVal =
4994 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4996 case OMPScheduleType::UnorderedStatic:
4997 case OMPScheduleType::UnorderedStaticChunked:
4998 case OMPScheduleType::OrderedDistribute:
4999 case OMPScheduleType::OrderedDistributeChunked:
5002 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5003 SPMDCompatibilityTracker.insert(&CB);
5007 case OMPRTL___kmpc_target_init:
5010 case OMPRTL___kmpc_target_deinit:
5011 KernelDeinitCB = &CB;
5013 case OMPRTL___kmpc_parallel_51:
5014 if (!handleParallel51(A, CB))
5015 indicatePessimisticFixpoint();
5017 case OMPRTL___kmpc_omp_task:
5019 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5020 SPMDCompatibilityTracker.insert(&CB);
5021 ReachedUnknownParallelRegions.insert(&CB);
5023 case OMPRTL___kmpc_alloc_shared:
5024 case OMPRTL___kmpc_free_shared:
5030 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5031 SPMDCompatibilityTracker.insert(&CB);
5037 indicateOptimisticFixpoint();
5041 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5042 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5043 CheckCallee(getAssociatedFunction(), 1);
5046 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5047 for (auto *Callee : OptimisticEdges) {
5048 CheckCallee(Callee, OptimisticEdges.size());
5059 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5060 KernelInfoState StateBefore = getState();
5062 auto CheckCallee = [&](Function *F, int NumCallees) {
5063 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5067 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5070 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5072 return indicatePessimisticFixpoint();
5073 if (getState() == FnAA->getState())
5074 return ChangeStatus::UNCHANGED;
5075 getState() = FnAA->getState();
5076 return ChangeStatus::CHANGED;
5079 return indicatePessimisticFixpoint();
5081 CallBase &CB = cast<CallBase>(getAssociatedValue());
5082 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5083 if (!handleParallel51(A, CB))
5084 return indicatePessimisticFixpoint();
5085 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5086 : ChangeStatus::CHANGED;
5092 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5093 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5094 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5098 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5106 case OMPRTL___kmpc_alloc_shared:
5107 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5108 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5109 SPMDCompatibilityTracker.insert(&CB);
5111 case OMPRTL___kmpc_free_shared:
5112 if ((!HeapToStackAA ||
5113 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5115 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5116 SPMDCompatibilityTracker.insert(&CB);
5119 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5120 SPMDCompatibilityTracker.insert(&CB);
5122 return ChangeStatus::CHANGED;
5126 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5127 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5128 if (Function *F = getAssociatedFunction())
5132 for (auto *Callee : OptimisticEdges) {
5133 CheckCallee(Callee, OptimisticEdges.size());
5139 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5140 : ChangeStatus::CHANGED;
5146 const unsigned int NonWrapperFunctionArgNo = 5;
5147 const unsigned int WrapperFunctionArgNo = 6;
5148 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5149 ? NonWrapperFunctionArgNo
5150 : WrapperFunctionArgNo;
5152 auto *ParallelRegion = dyn_cast<Function>(
5154 if (!ParallelRegion)
5157 ReachedKnownParallelRegions.insert(&CB);
5159 auto *FnAA = A.getAAFor<AAKernelInfo>(
5161 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5162 !FnAA->ReachedKnownParallelRegions.empty() ||
5163 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5164 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5165 !FnAA->ReachedUnknownParallelRegions.empty();
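// AAFoldRuntimeCall: folds OpenMP runtime queries such as
// __kmpc_is_spmd_exec_mode, __kmpc_parallel_level, and the hardware-size
// getters to constants when every kernel that can reach the call site agrees
// on the answer.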
5170struct AAFoldRuntimeCall
5171 : public StateWrapper<BooleanState, AbstractAttribute> {
5177 void trackStatistics() const override {}
5180 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5184 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5187 const char *getIdAddr() const override { return &ID; }
5195 static const char ID;
5198struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5200 : AAFoldRuntimeCall(IRP, A) {}
5203 const std::string getAsStr(Attributor *) const override {
5204 if (!isValidState())
5207 std::string Str("simplified value: ");
5209 if (!SimplifiedValue)
5210 return Str + std::string("none");
5212 if (!*SimplifiedValue)
5213 return Str + std::string("nullptr");
5215 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5216 return Str + std::to_string(CI->getSExtValue());
5218 return Str + std::string("unknown");
5223 indicatePessimisticFixpoint();
5227 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5228 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5229 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5230 "Expected a known OpenMP runtime function");
5232 RFKind = It->getSecond();
5234 CallBase &CB = cast<CallBase>(getAssociatedValue());
5235 A.registerSimplificationCallback(
5238 bool &UsedAssumedInformation) -> std::optional<Value *> {
5239 assert((isValidState() ||
5240 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5241 "Unexpected invalid state!");
5243 if (!isAtFixpoint()) {
5244 UsedAssumedInformation = true;
5246 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5248 return SimplifiedValue;
5255 case OMPRTL___kmpc_is_spmd_exec_mode:
5256 Changed |= foldIsSPMDExecMode(A);
5258 case OMPRTL___kmpc_parallel_level:
5259 Changed |= foldParallelLevel(A);
5261 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5262 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5264 case OMPRTL___kmpc_get_hardware_num_blocks:
5265 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5277 if (SimplifiedValue && *SimplifiedValue) {
5280 A.deleteAfterManifest(I);
5284 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5285 return OR << "Replacing OpenMP runtime call "
5287 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5288 return OR << "Replacing OpenMP runtime call "
5296 << **SimplifiedValue << "\n");
5298 Changed = ChangeStatus::CHANGED;
5305 SimplifiedValue = nullptr;
5306 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
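// foldIsSPMDExecMode: counts how many reaching kernels are (assumed or known)
// SPMD versus generic; folding only happens when the two sets do not mix,
// otherwise the attribute gives up and keeps the runtime call.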
5312 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5314 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5315 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5316 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5319 if (!CallerKernelInfoAA ||
5320 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5321 return indicatePessimisticFixpoint();
5323 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5325 DepClassTy::REQUIRED);
5327 if (!AA || !AA->isValidState()) {
5328 SimplifiedValue = nullptr;
5329 return indicatePessimisticFixpoint();
5332 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5333 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5338 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5339 ++KnownNonSPMDCount;
5341 ++AssumedNonSPMDCount;
5345 if ((AssumedSPMDCount + KnownSPMDCount) &&
5346 (AssumedNonSPMDCount + KnownNonSPMDCount))
5347 return indicatePessimisticFixpoint();
5349 auto &Ctx = getAnchorValue().getContext();
5350 if (KnownSPMDCount || AssumedSPMDCount) {
5351 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5352 "Expected only SPMD kernels!");
5356 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5357 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5358 "Expected only non-SPMD kernels!");
5366 assert(!SimplifiedValue && "SimplifiedValue should be none");
5369 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5370 : ChangeStatus::CHANGED;
5375 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5377 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(