49#include "llvm/IR/IntrinsicsAMDGPU.h"
50#include "llvm/IR/IntrinsicsNVPTX.h"
66#define DEBUG_TYPE "openmp-opt"
69 "openmp-opt-disable",
cl::desc(
"Disable OpenMP specific optimizations."),
73 "openmp-opt-enable-merging",
79 cl::desc(
"Disable function internalization."),
90 "openmp-hide-memory-transfer-latency",
91 cl::desc(
"[WIP] Tries to hide the latency of host to device memory"
96 "openmp-opt-disable-deglobalization",
97 cl::desc(
"Disable OpenMP optimizations involving deglobalization."),
101 "openmp-opt-disable-spmdization",
102 cl::desc(
"Disable OpenMP optimizations involving SPMD-ization."),
106 "openmp-opt-disable-folding",
111 "openmp-opt-disable-state-machine-rewrite",
112 cl::desc(
"Disable OpenMP optimizations that replace the state machine."),
116 "openmp-opt-disable-barrier-elimination",
117 cl::desc(
"Disable OpenMP optimizations that eliminate barriers."),
121 "openmp-opt-print-module-after",
122 cl::desc(
"Print the current module after OpenMP optimizations."),
126 "openmp-opt-print-module-before",
127 cl::desc(
"Print the current module before OpenMP optimizations."),
131 "openmp-opt-inline-device",
142 cl::desc(
"Maximal number of attributor iterations."),
147 cl::desc(
"Maximum amount of shared memory to use."),
148 cl::init(std::numeric_limits<unsigned>::max()));
151 "Number of OpenMP runtime calls deduplicated");
153 "Number of OpenMP parallel regions deleted");
155 "Number of OpenMP runtime functions identified");
157 "Number of OpenMP runtime function uses identified");
159 "Number of OpenMP target region entry points (=kernels) identified");
161 "Number of non-OpenMP target region kernels identified");
163 "Number of OpenMP target region entry points (=kernels) executed in "
164 "SPMD-mode instead of generic-mode");
165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
166 "Number of OpenMP target region entry points (=kernels) executed in "
167 "generic-mode without a state machines");
168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
169 "Number of OpenMP target region entry points (=kernels) executed in "
170 "generic-mode with customized state machines with fallback");
171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
172 "Number of OpenMP target region entry points (=kernels) executed in "
173 "generic-mode with customized state machines without fallback");
175 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
176 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178 "Number of OpenMP parallel regions merged");
180 "Amount of memory pushed to shared memory");
181STATISTIC(NumBarriersEliminated,
"Number of redundant barriers eliminated");
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr const unsigned MEMBER##Idx = IDX;

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }

#undef KERNEL_ENVIRONMENT_GETTER

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
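// For illustration, an assumed expansion of the configuration getter macro for
// the ExecMode member (getExecModeFromKernelEnvironment is used further below)
// would read roughly:
//   ConstantInt *getExecModeFromKernelEnvironment(ConstantStruct *KernelEnvC) {
//     ConstantStruct *ConfigC =
//         getConfigurationFromKernelEnvironment(KernelEnvC);
//     return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(ExecModeIdx));
//   }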
  constexpr const int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(

struct AAHeapToShared;

        OpenMPPostLink(OpenMPPostLink) {
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();

  struct InternalControlVarInfo {

  struct RuntimeFunctionInfo {
    void clearUsesMap() { UsesMap.clear(); }

    operator bool() const { return Declaration; }

    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
        UV = std::make_shared<UseVector>();

    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();

    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    size_t getNumArgs() const { return ArgumentTypes.size(); }

      UseVector &UV = getOrCreateUseVector(F);

      while (!ToBeDeleted.empty()) {

    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }

                  RuntimeFunction::OMPRTL___last>

                  InternalControlVar::ICV___last>
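  // The two enumerated arrays above are indexed by the RuntimeFunction and
  // InternalControlVar enums, with OMPRTL___last and ICV___last as sentinel
  // sizes, so every known runtime function and ICV gets its own
  // RuntimeFunctionInfo / InternalControlVarInfo slot.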
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
    auto &ICV = ICVs[_Name];                                                   \
#define ICV_RT_GET(Name, RTL)                                                  \
    auto &ICV = ICVs[Name];                                                    \
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,

    if (F->getReturnType() != RTFRetType)

    if (F->arg_size() != RTFArgTypes.size())

    auto *RTFTyIt = RTFArgTypes.begin();
      if (Arg.getType() != *RTFTyIt)

  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)

      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();

    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);

        RFI.getOrCreateUseVector(nullptr).push_back(&U);

      auto &RFI = RFIs[RTF];
      collectUses(RFI, false);
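    // Note: uses that cannot be attributed to a function of the current SCC
    // are collected under the nullptr key of the use map (the fallback above),
    // so they remain visible without being charged to an SCC member.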
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)

    RuntimeFunctionInfo &RFI = RFIs[Fn];
    if (RFI.Declaration && RFI.Declaration->isDeclaration())

  void initializeRuntimeFunctions(Module &M) {

#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")             \
      if (RFI.Declaration)                                                     \
        dbgs() << TAG << "-> got " << NumUses << " uses in "                   \
               << RFI.getNumFunctionsWithUses()                                \
               << " different functions.\n";                                   \
#include "llvm/Frontend/OpenMP/OMPKinds.def"

      for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
        if (F.hasFnAttribute(Attribute::NoInline) &&
            F.getName().starts_with(Prefix) &&
            !F.hasFnAttribute(Attribute::OptimizeNone))
          F.removeFnAttr(Attribute::NoInline);

  bool OpenMPPostLink = false;
template <typename Ty, bool InsertInvalidates = true>

  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
    return Set.insert(Elem);

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());

  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
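// Note on the merge semantics: operator^= above combines the wrapped
// BooleanState and unions the two sets, so joining two states (see
// KernelInfoState::operator^= below) accumulates, e.g., all reached parallel
// regions rather than intersecting them.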
  bool IsAtFixpoint = false;

  BooleanStateWithPtrSetVector<CallBase, false>
      ReachedKnownParallelRegions;

  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  bool IsKernelEntry = false;

  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  bool NestedParallelism = false;

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {

  bool isAtFixpoint() const override { return IsAtFixpoint; }

    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;

    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();

  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
    if (ParallelLevels != RHS.ParallelLevels)
    if (NestedParallelism != RHS.NestedParallelism)

  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();

  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();

  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  KernelInfoState operator^=(const KernelInfoState &KIS) {
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
      KernelInitCB = KIS.KernelInitCB;
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
      KernelDeinitCB = KIS.KernelDeinitCB;
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
      KernelEnvC = KIS.KernelEnvC;
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  OffloadArray() = default;

    if (!Array.getAllocatedType()->isArrayTy())

    if (!getValues(Array, Before))

    this->Array = &Array;

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

      if (BB != Before.getParent())

      if (!isa<StoreInst>(&I))
      auto *S = cast<StoreInst>(&I);
      LastAccesses[Idx] = S;

    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
  using OptimizationRemarkGetter =

            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);

  bool run(bool IsModulePass) {

    bool Changed = false;

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();

      if (mergeParallelRegions()) {
        deduplicateRuntimeCalls();

      if (OMPInfoCache.OpenMPPostLink)
        Changed |= removeRuntimeSymbols();

  void printICVs() const {

    for (auto ICV : ICVs) {
      auto ICVInfo = OMPInfoCache.ICVs[ICV];
        return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                   << (ICVInfo.InitValue
                           ? toString(ICVInfo.InitValue->getValue(), 10, true)
                           : "IMPLEMENTATION_DEFINED");

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);

  void printKernels() const {

        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);

  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());

  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)

    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],

    bool Changed = false;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    auto CreateSequentialRegion = [&](Function *OuterFn,
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
             "Expected a different CFG");

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");

      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

        for (User *Usr : I.users()) {
            OutsideUsers.insert(&UsrI);

        if (OutsideUsers.empty())

            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", OuterFn->front().begin());

        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());

              I.getName() + ".seq.output.load",

          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

          OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
          if (CI != MergableCIs.back())

        emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

                        << " parallel regions in " << OriginalFn->getName()

      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();

      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),

      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, false);

      OMPInfoCache.OMPBuilder.finalize(OriginalFn);

      for (auto *CI : MergableCIs) {
        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;

        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;

        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
              U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        if (CI != MergableCIs.back()) {
          OMPInfoCache.OMPBuilder.createBarrier(

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();
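      // In short, Merge (1) splits the block around the group of mergable
      // __kmpc_fork_call sites, (2) wraps the code between two neighboring
      // calls in a sequential master + barrier region via
      // CreateSequentialRegion, (3) emits one combined parallel region with
      // createParallel, and (4) rewires the original callback arguments before
      // the old calls are removed.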
      CallInst *CI = getCallIfRegularCall(U, &RFI);

    RFI.foreachUse(SCC, DetectPRsCB);

    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();

      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {

        if (I.isTerminator())

        if (!isa<CallInst>(&I))

        if (IsBeforeMergableRegion) {
          if (!CalledFunction)
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)

        if (!isa<IntrinsicInst>(CI))

        if (CIs.count(&I)) {

        if (IsMergable(I, MergableCIs.empty()))

        for (; It != End; ++It) {
          if (CIs.count(&SkipI)) {
                              << " due to " << I << "\n");

        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
                            << " parallel regions in block " << BB->getName()

        MergableCIs.clear();

      if (!MergableCIsVector.empty()) {

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();

      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);

  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)

    bool Changed = false;
      CallInst *CI = getCallIfRegularCall(U);
      auto *Fn = dyn_cast<Function>(
      if (!Fn->onlyReadsMemory())
      if (!Fn->hasFnAttribute(Attribute::WillReturn))

        return OR << "Removing parallel region with no side-effects.";
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      ++NumOpenMPParallelRegionsDeleted;

    RFI.foreachUse(SCC, DeleteCallCB);
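    // A __kmpc_fork_call is only deleted when its outlined callback (operand
    // CallbackCalleeOperand) is a known function that only reads memory and is
    // marked willreturn, i.e. the parallel region provably has no effects.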
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    collectGlobalThreadIdArguments(GTIdArgs);
                      << " global thread ID arguments\n");

      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      Value *GTIdArg = nullptr;
        if (GTIdArgs.count(&Arg)) {
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);

  bool removeRuntimeSymbols() {
      if (!GV->getType()->isPointerTy())

        GlobalVariable *Client =
            dyn_cast<GlobalVariable>(C->stripPointerCasts());

      GV->eraseFromParent();
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
      auto *RTCall = getCallIfRegularCall(U, &RFI);

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;

    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);
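    // The latency hiding only triggers when both asynchronous entry points,
    // __tgt_target_data_begin_mapper_issue and ..._wait, are available in the
    // module; splitTargetDataBeginRTC below performs the actual split.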
  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
             << "Found thread data sharing on the GPU. "
             << "Expect degraded performance due to data globalization.";
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);

    RFI.foreachUse(SCC, CheckGlobalization);

  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    Value *BasePtrsArg =

    if (!isa<AllocaInst>(V))
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))

    if (!isa<AllocaInst>(V))
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))

    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))

    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    std::string ValuesStr;
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");

    for (auto *P : OAs[1].StoredValues) {

    for (auto *S : OAs[2].StoredValues) {
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
    bool IsWorthIt = false;

    return RuntimeCall.getParent()->getTerminator();

  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,

    auto &IRBuilder = OMPInfoCache.OMPBuilder;
                         Entry.getFirstNonPHIOrDbgOrAlloca());
        IRBuilder.AsyncInfo, nullptr, "handle");

        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);

        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
                                        OffloadArray::DeviceIDArgNum),
        WaitDecl, WaitParams, "", WaitMovementPoint.getIterator());
    OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
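    // Rough sketch of the rewrite (assumed IR shape): a blocking call
    //   call void @__tgt_target_data_begin_mapper(%ident, %device_id, ...)
    // becomes
    //   %handle = alloca %struct.__tgt_async_info        ; in the entry block
    //   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
    //   ...                                              ; independent code
    //   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)
    // with the wait placed at WaitMovementPoint.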
  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;

  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
    bool SingleChoice = true;
    Value *Ident = nullptr;
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
                                   true, SingleChoice);
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
            &F.getEntryBlock(), F.getEntryBlock().begin()));
          OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);

  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)

        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.arg_size();
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
      for (unsigned U = 1; U < NumArgs; ++U)
        if (isa<Instruction>(CB.getArgOperand(U)))

      for (Use *U : *UV) {
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (!CanBeMoved(*CI))

      assert(IP && "Expected insertion point!");
      cast<Instruction>(ReplVal)->moveBefore(IP);

    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,

    bool Changed = false;
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)

        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
        emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
        emitRemark<OptimizationRemark>(&F, "OMP170", Remark);

      ++NumOpenMPRuntimeCallsDeduplicated;

    RFI.foreachUse(SCC, ReplaceAndDeleteCB);
      if (!F.hasLocalLinkage())
      for (Use &U : F.uses()) {
        if (CallInst *CI = getCallIfRegularCall(U)) {
          if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
              getCallIfRegularCall(
                  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))

    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
        if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
            if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))

    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))

    for (unsigned U = 0; U < GTIdArgs.size(); ++U)
      AddUserArgs(*GTIdArgs[U]);

    return getUniqueKernelFor(*I.getFunction());

  bool rewriteDeviceCodeStateMachine();

  template <typename RemarkKind, typename RemarkCallBack>
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
                 << " [" << RemarkName << "]";
          [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });

  template <typename RemarkKind, typename RemarkCallBack>
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
                 << " [" << RemarkName << "]";
          [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });

  OptimizationRemarkGetter OREGetter;

  OMPInformationCache &OMPInfoCache;

  bool runAttributor(bool IsModulePass) {
    registerAAs(IsModulePass);

                      << " functions, result: " << Changed << ".\n");

    if (Changed == ChangeStatus::CHANGED)
      OMPInfoCache.invalidateAnalyses();

    return Changed == ChangeStatus::CHANGED;

  void registerAAs(bool IsModulePass);
  if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
      !OMPInfoCache.CGSCC->contains(&F))

  std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    return *CachedKernel;

    return *CachedKernel;

  CachedKernel = nullptr;
  if (!F.hasLocalLinkage()) {

      return ORA << "Potentially unknown OpenMP target region caller.";
    emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);

    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);

  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));

  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  UniqueKernelMap[&F] = K;
bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)

    bool UnknownUse = false;
    bool KernelParallelUse = false;
    unsigned NumDirectCalls = 0;

    OMPInformationCache::foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {

      if (isa<ICmpInst>(U.getUser())) {
        ToBeReplacedStateMachineUses.push_back(&U);

          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
      const unsigned int WrapperFunctionArgNo = 6;
      if (!KernelParallelUse && CI &&
        KernelParallelUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);

    if (!KernelParallelUse)

    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() > 2) {
        return ORA << "Parallel region is used in "
                   << (UnknownUse ? "unknown" : "unexpected")
                   << " ways. Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);

        return ORA << "Parallel region is not called from a unique kernel. "
                      "Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);

    for (Use *U : ToBeReplacedStateMachineUses)
          ID, U->get()->getType()));

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
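    // At this point the wrapper argument of __kmpc_parallel_51 and the
    // comparisons in the kernel's generic-mode state machine are redirected to
    // a unique ID value (cast to the original use's type), so the state
    // machine no longer needs the address of the outlined parallel region
    // (see the OMP101/OMP102 bail-outs above).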
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {

  bool isAssumedTracked() const { return getAssumed(); }

  bool isKnownTracked() const { return getAssumed(); }

    return std::nullopt;

  virtual std::optional<Value *>

  const std::string getName() const override { return "AAICVTracker"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;

struct AAICVTrackerFunction : public AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr(Attributor *) const override {
    return "ICVTrackerFunction";

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

            InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];

      auto &ValuesMap = ICVReplacementValuesMap[ICV];
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
        if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
          HasChanged = ChangeStatus::CHANGED;

        std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
        if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
          HasChanged = ChangeStatus::CHANGED;

      SetterRFI.foreachUse(TrackValues, F);

      bool UsedAssumedInformation = false;
      A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
                                UsedAssumedInformation,

      if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
        ValuesMap.insert(std::make_pair(Entry, nullptr));
    const auto *CB = dyn_cast<CallBase>(&I);
    if (!CB || CB->hasFnAttr("no_openmp") ||
        CB->hasFnAttr("no_openmp_routines"))
      return std::nullopt;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
    auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
    Function *CalledFunction = CB->getCalledFunction();

    if (CalledFunction == nullptr)
    if (CalledFunction == GetterRFI.Declaration)
      return std::nullopt;
    if (CalledFunction == SetterRFI.Declaration) {
      if (ICVReplacementValuesMap[ICV].count(&I))
        return ICVReplacementValuesMap[ICV].lookup(&I);

    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (ICVTrackingAA->isAssumedTracked()) {
      std::optional<Value *> URV =
          ICVTrackingAA->getUniqueReplacementValue(ICV);

  std::optional<Value *>
    return std::nullopt;

    const auto &ValuesMap = ICVReplacementValuesMap[ICV];
    if (ValuesMap.count(I))
      return ValuesMap.lookup(I);

    std::optional<Value *> ReplVal;

    while (!Worklist.empty()) {
      if (!Visited.insert(CurrInst).second)

      if (ValuesMap.count(CurrInst)) {
        std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
          ReplVal = NewReplVal;
        if (ReplVal != NewReplVal)

      std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
        ReplVal = NewReplVal;
      if (ReplVal != NewReplVal)

      if (CurrBB == I->getParent() && ReplVal)

        if (const Instruction *Terminator = Pred->getTerminator())
struct AAICVTrackerFunctionReturned : AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr(Attributor *) const override {
    return "ICVTrackerFunctionReturned";

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

            InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  std::optional<Value *>
    return ICVReplacementValuesMap[ICV];

    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (!ICVTrackingAA->isAssumedTracked())
      return indicatePessimisticFixpoint();

      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      std::optional<Value *> UniqueICVValue;

        std::optional<Value *> NewReplVal =
            ICVTrackingAA->getReplacementValue(ICV, &I, A);

        if (UniqueICVValue && UniqueICVValue != NewReplVal)

        UniqueICVValue = NewReplVal;

      bool UsedAssumedInformation = false;
      if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
                                     UsedAssumedInformation,
        UniqueICVValue = nullptr;

      if (UniqueICVValue == ReplVal)

      ReplVal = UniqueICVValue;
      Changed = ChangeStatus::CHANGED;

struct AAICVTrackerCallSite : AAICVTracker {
      : AAICVTracker(IRP, A) {}

    assert(getAnchorScope() && "Expected anchor function");

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto ICVInfo = OMPInfoCache.ICVs[ICV];
      auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
      if (Getter.Declaration == getAssociatedFunction()) {
        AssociatedICV = ICVInfo.Kind;

    indicatePessimisticFixpoint();

    if (!ReplVal || !*ReplVal)
      return ChangeStatus::UNCHANGED;

    A.deleteAfterManifest(*getCtxI());

    return ChangeStatus::CHANGED;

  const std::string getAsStr(Attributor *) const override {
    return "ICVTrackerCallSite";

  void trackStatistics() const override {}

  std::optional<Value *> ReplVal;

    const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (!ICVTrackingAA->isAssumedTracked())
      return indicatePessimisticFixpoint();

    std::optional<Value *> NewReplVal =
        ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);

    if (ReplVal == NewReplVal)
      return ChangeStatus::UNCHANGED;

    ReplVal = NewReplVal;
    return ChangeStatus::CHANGED;

  std::optional<Value *>

struct AAICVTrackerCallSiteReturned : AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr(Attributor *) const override {
    return "ICVTrackerCallSiteReturned";

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

            InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  std::optional<Value *>
    return ICVReplacementValuesMap[ICV];

      const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
          DepClassTy::REQUIRED);

      if (!ICVTrackingAA->isAssumedTracked())
        return indicatePessimisticFixpoint();

      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      std::optional<Value *> NewReplVal =
          ICVTrackingAA->getUniqueReplacementValue(ICV);

      if (ReplVal == NewReplVal)

      ReplVal = NewReplVal;
      Changed = ChangeStatus::CHANGED;
static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
  return hasFunctionEndAsUniqueSuccessor(Successor);

  ~AAExecutionDomainFunction() { delete RPOT; }

    assert(F && "Expected anchor function");

    unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
    for (auto &It : BEDMap) {
      InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
      AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
                       It.getSecond().IsReachingAlignedBarrierOnly;
    return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
           std::to_string(AlignedBlocks) + " of " +
           std::to_string(TotalBlocks) +
           " executed by initial thread / aligned";

             << BB.getName() << " is executed by a single thread.\n";

    auto HandleAlignedBarrier = [&](CallBase *CB) {
      const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
      if (!ED.IsReachedFromAlignedBarrierOnly ||
          ED.EncounteredNonLocalSideEffect)
      if (!ED.EncounteredAssumes.empty() && !A.isModulePass())

        DeletedBarriers.insert(CB);
        A.deleteAfterManifest(*CB);
        ++NumBarriersEliminated;
      } else if (!ED.AlignedBarriers.empty()) {
                        ED.AlignedBarriers.end());
        while (!Worklist.empty()) {
          if (!Visited.insert(LastCB))
          if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
          if (!DeletedBarriers.count(LastCB)) {
            ++NumBarriersEliminated;
            A.deleteAfterManifest(*LastCB);
          const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
          Worklist.append(LastED.AlignedBarriers.begin(),
                          LastED.AlignedBarriers.end());

      if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
        for (auto *AssumeCB : ED.EncounteredAssumes)
          A.deleteAfterManifest(*AssumeCB);

    for (auto *CB : AlignedBarriers)
      HandleAlignedBarrier(CB);

    HandleAlignedBarrier(nullptr);
  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
                                           const ExecutionDomainTy &PredED);

  bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
                          const ExecutionDomainTy &PredED,
                          bool InitialEdgeOnly = false);

  bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);

    assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
    return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;

    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");

    bool ForwardIsOk = true;
      auto *CB = dyn_cast<CallBase>(CurI);
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
      const auto &It = CEDMap.find({CB, PRE});
      if (It == CEDMap.end())
      if (!It->getSecond().IsReachingAlignedBarrierOnly)
        ForwardIsOk = false;

    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
      ForwardIsOk = false;

      auto *CB = dyn_cast<CallBase>(CurI);
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
      const auto &It = CEDMap.find({CB, POST});
      if (It == CEDMap.end())
      if (It->getSecond().IsReachedFromAlignedBarrierOnly)

      return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
    return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;

           "No request should be made against an invalid state!");
    return BEDMap.lookup(&BB);

  std::pair<ExecutionDomainTy, ExecutionDomainTy>
           "No request should be made against an invalid state!");
    return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};

           "No request should be made against an invalid state!");
    return InterProceduralED;

    if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())

    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;

          KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);

    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
      if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)

    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
      if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)

  ExecutionDomainTy InterProceduralED;

  static bool setAndRecord(bool &R, bool V) {
void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
  for (auto *EA : PredED.EncounteredAssumes)
    ED.addAssumeInst(A, *EA);

  for (auto *AB : PredED.AlignedBarriers)
    ED.addAlignedBarrier(A, *AB);

bool AAExecutionDomainFunction::mergeInPredecessor(
    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
    bool InitialEdgeOnly) {

  bool Changed = false;
      setAndRecord(ED.IsExecutedByInitialThreadOnly,
                   InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
                                       ED.IsExecutedByInitialThreadOnly));

  Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
                          ED.IsReachedFromAlignedBarrierOnly &&
                              PredED.IsReachedFromAlignedBarrierOnly);
  Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
                          ED.EncounteredNonLocalSideEffect |
                              PredED.EncounteredNonLocalSideEffect);

  if (ED.IsReachedFromAlignedBarrierOnly)
    mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
    ED.clearAssumeInstAndAlignedBarriers();

bool AAExecutionDomainFunction::handleCallees(Attributor &A,
                                              ExecutionDomainTy &EntryBBED) {
                                        DepClassTy::OPTIONAL);
    if (!EDAA || !EDAA->getState().isValidState())
        EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));

  ExecutionDomainTy ExitED;
  bool AllCallSitesKnown;
  if (A.checkForAllCallSites(PredForCallSite, *this,
                             AllCallSitesKnown)) {
    for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
      mergeInPredecessor(A, EntryBBED, CSInED);
      ExitED.IsReachingAlignedBarrierOnly &=
          CSOutED.IsReachingAlignedBarrierOnly;

      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = true;
      EntryBBED.EncounteredNonLocalSideEffect = false;
      ExitED.IsReachingAlignedBarrierOnly = false;

      EntryBBED.IsExecutedByInitialThreadOnly = false;
      EntryBBED.IsReachedFromAlignedBarrierOnly = false;
      EntryBBED.EncounteredNonLocalSideEffect = true;
      ExitED.IsReachingAlignedBarrierOnly = false;

  bool Changed = false;
  auto &FnED = BEDMap[nullptr];
  Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
                          FnED.IsReachedFromAlignedBarrierOnly &
                              EntryBBED.IsReachedFromAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
                          FnED.IsReachingAlignedBarrierOnly &
                              ExitED.IsReachingAlignedBarrierOnly);
  Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
                          EntryBBED.IsExecutedByInitialThreadOnly);
  bool Changed = false;

  auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
    Changed |= AlignedBarriers.insert(&CB);
    auto &CallInED = CEDMap[{&CB, PRE}];
    Changed |= mergeInPredecessor(A, CallInED, ED);
    CallInED.IsReachingAlignedBarrierOnly = true;
    ED.EncounteredNonLocalSideEffect = false;
    ED.IsReachedFromAlignedBarrierOnly = true;
    ED.clearAssumeInstAndAlignedBarriers();
    ED.addAlignedBarrier(A, CB);
    auto &CallOutED = CEDMap[{&CB, POST}];
    Changed |= mergeInPredecessor(A, CallOutED, ED);

      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);

  for (auto &RIt : *RPOT) {
    bool IsEntryBB = &BB == &EntryBB;
    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
    bool IsExplicitlyAligned = IsEntryBB && IsKernel;
    ExecutionDomainTy ED;
      Changed |= handleCallees(A, ED);
      if (LivenessAA && LivenessAA->isAssumedDead(&BB))
        if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
            A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);

      bool UsedAssumedInformation;
      if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
                          false, DepClassTy::OPTIONAL,

      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
          ED.addAssumeInst(A, *AI);
        if (II->isAssumeLikeIntrinsic())

      if (auto *FI = dyn_cast<FenceInst>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect) {
          if (ED.IsReachedFromAlignedBarrierOnly)
          case AtomicOrdering::NotAtomic:
          case AtomicOrdering::Unordered:
          case AtomicOrdering::Monotonic:
          case AtomicOrdering::Acquire:
          case AtomicOrdering::Release:
          case AtomicOrdering::AcquireRelease:
          case AtomicOrdering::SequentiallyConsistent:
        NonNoOpFences.insert(FI);

      auto *CB = dyn_cast<CallBase>(&I);
      bool IsAlignedBarrier =

      AlignedBarrierLastInBlock &= IsNoSync;
      IsExplicitlyAligned &= IsNoSync;

      if (IsAlignedBarrier) {
        HandleAlignedBarrier(*CB, ED);
        AlignedBarrierLastInBlock = true;
        IsExplicitlyAligned = true;

      if (isa<MemIntrinsic>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect &&
          ED.EncounteredNonLocalSideEffect = true;
          ED.IsReachedFromAlignedBarrierOnly = false;
      auto &CallInED = CEDMap[{CB, PRE}];
      Changed |= mergeInPredecessor(A, CallInED, ED);

      if (!IsNoSync && Callee && !Callee->isDeclaration()) {
        if (EDAA && EDAA->getState().isValidState()) {
              CalleeED.IsReachedFromAlignedBarrierOnly;
          AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
          if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
            ED.EncounteredNonLocalSideEffect |=
                CalleeED.EncounteredNonLocalSideEffect;
            ED.EncounteredNonLocalSideEffect =
                CalleeED.EncounteredNonLocalSideEffect;
          if (!CalleeED.IsReachingAlignedBarrierOnly) {
            setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
          if (CalleeED.IsReachedFromAlignedBarrierOnly)
            mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
          auto &CallOutED = CEDMap[{CB, POST}];
          Changed |= mergeInPredecessor(A, CallOutED, ED);

        ED.IsReachedFromAlignedBarrierOnly = false;
        Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);

      AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;

      auto &CallOutED = CEDMap[{CB, POST}];
      Changed |= mergeInPredecessor(A, CallOutED, ED);

      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())

      if (MemAA && MemAA->getState().isValidState() &&
          MemAA->checkForAllAccessesToMemoryKind(

      auto &InfoCache = A.getInfoCache();
      if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))

      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->hasMetadata(LLVMContext::MD_invariant_load))

      if (!ED.EncounteredNonLocalSideEffect &&
        ED.EncounteredNonLocalSideEffect = true;

    bool IsEndAndNotReachingAlignedBarriersOnly = false;
    if (!isa<UnreachableInst>(BB.getTerminator()) &&
        !BB.getTerminator()->getNumSuccessors()) {

      Changed |= mergeInPredecessor(A, InterProceduralED, ED);

      auto &FnED = BEDMap[nullptr];
      if (IsKernel && !IsExplicitlyAligned)
        FnED.IsReachingAlignedBarrierOnly = false;
      Changed |= mergeInPredecessor(A, FnED, ED);

      if (!FnED.IsReachingAlignedBarrierOnly) {
        IsEndAndNotReachingAlignedBarriersOnly = true;
        SyncInstWorklist.push_back(BB.getTerminator());
        auto &BBED = BEDMap[&BB];
        Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);

    ExecutionDomainTy &StoredED = BEDMap[&BB];
    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
                                      !IsEndAndNotReachingAlignedBarriersOnly;

    if (ED.IsExecutedByInitialThreadOnly !=
            StoredED.IsExecutedByInitialThreadOnly ||
        ED.IsReachedFromAlignedBarrierOnly !=
            StoredED.IsReachedFromAlignedBarrierOnly ||
        ED.EncounteredNonLocalSideEffect !=
            StoredED.EncounteredNonLocalSideEffect)

    StoredED = std::move(ED);

  while (!SyncInstWorklist.empty()) {
    bool HitAlignedBarrierOrKnownEnd = false;
      auto *CB = dyn_cast<CallBase>(CurInst);
      auto &CallOutED = CEDMap[{CB, POST}];
      Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
      auto &CallInED = CEDMap[{CB, PRE}];
      HitAlignedBarrierOrKnownEnd =
          AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
      if (HitAlignedBarrierOrKnownEnd)
      Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
    if (HitAlignedBarrierOrKnownEnd)
      if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
      if (!Visited.insert(PredBB))
      auto &PredED = BEDMap[PredBB];
      if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
        SyncInstWorklist.push_back(PredBB->getTerminator());
    if (SyncBB != &EntryBB)
        setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);

  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {

  static AAHeapToShared &createForPosition(const IRPosition &IRP,

  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  const std::string getName() const override { return "AAHeapToShared"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;

struct AAHeapToSharedFunction : public AAHeapToShared {
      : AAHeapToShared(IRP, A) {}

  const std::string getAsStr(Attributor *) const override {
    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
           " malloc calls eligible.";

  void trackStatistics() const override {}

  void findPotentialRemovedFreeCalls(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    PotentialRemovedFreeCalls.clear();
      for (auto *U : CB->users()) {
        if (C && C->getCalledFunction() == FreeRFI.Declaration)

      if (FreeCalls.size() != 1)

      PotentialRemovedFreeCalls.insert(FreeCalls.front());

      indicatePessimisticFixpoint();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)

                       bool &) -> std::optional<Value *> { return nullptr; };

    for (User *U : RFI.Declaration->users())
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        MallocCalls.insert(CB);

    findPotentialRemovedFreeCalls(A);

  bool isAssumedHeapToShared(CallBase &CB) const override {
    return isValidState() && MallocCalls.count(&CB);

  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
    return isValidState() && PotentialRemovedFreeCalls.count(&CB);
    if (MallocCalls.empty())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

                                           DepClassTy::OPTIONAL);
      if (HS && HS->isAssumedHeapToStack(*CB))

      for (auto *U : CB->users()) {
        if (C && C->getCalledFunction() == FreeCall.Declaration)
      if (FreeCalls.size() != 1)

                 << " with shared memory."
                 << " Shared memory usage is limited to "

                        << " with " << AllocSize->getZExtValue()
                        << " bytes of shared memory\n");

      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
          static_cast<unsigned>(AddressSpace::Shared));

        return OR << "Replaced globalized variable with "
                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
                  << (AllocSize->isOne() ? " byte " : " bytes ")
                  << "of shared memory.";

             "HeapToShared on allocation without alignment attribute");
      SharedMem->setAlignment(*Alignment);

      A.deleteAfterManifest(*CB);
      A.deleteAfterManifest(*FreeCalls.front());

      SharedMemoryUsed += AllocSize->getZExtValue();
      NumBytesMovedToSharedMemory = SharedMemoryUsed;
      Changed = ChangeStatus::CHANGED;
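      // Each surviving __kmpc_alloc_shared call with a constant size and a
      // single matching __kmpc_free_shared is thus rewritten into an [N x i8]
      // global in the Shared address space, the alignment is copied from the
      // allocation, and both runtime calls are deleted afterwards.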
    if (MallocCalls.empty())
      return indicatePessimisticFixpoint();
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return ChangeStatus::UNCHANGED;

    auto NumMallocCalls = MallocCalls.size();

    for (User *U : RFI.Declaration->users()) {
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCaller() != F)
        if (!MallocCalls.count(CB))
        if (!isa<ConstantInt>(CB->getArgOperand(0))) {
          MallocCalls.remove(CB);
        if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
          MallocCalls.remove(CB);

    findPotentialRemovedFreeCalls(A);

    if (NumMallocCalls != MallocCalls.size())
      return ChangeStatus::CHANGED;

    return ChangeStatus::UNCHANGED;

  unsigned SharedMemoryUsed = 0;

struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {

  static bool requiresCalleeForCallBase() { return false; }

  void trackStatistics() const override {}

  const std::string getAsStr(Attributor *) const override {
    if (!isValidState())
    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
           std::string(" #PRs: ") +
           (ReachedKnownParallelRegions.isValidState()
                ? std::to_string(ReachedKnownParallelRegions.size())
           ", #Unknown PRs: " +
           (ReachedUnknownParallelRegions.isValidState()
           ", #Reaching Kernels: " +
           (ReachingKernelEntries.isValidState()
           (ParallelLevels.isValidState()
           ", NestedPar: " + (NestedParallelism ? "yes" : "no");

  const std::string getName() const override { return "AAKernelInfo"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;
3657struct AAKernelInfoFunction : AAKernelInfo {
3659 : AAKernelInfo(IRP,
A) {}
3664 return GuardedInstructions;
3667 void setConfigurationOfKernelEnvironment(
ConstantStruct *ConfigC) {
3669 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3670 assert(NewKernelEnvC &&
"Failed to create new kernel environment");
3671 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3674#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3675 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3676 ConstantStruct *ConfigC = \
3677 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3678 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3679 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3680 assert(NewConfigC && "Failed to create new configuration environment"); \
3681 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3692#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
3699 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3703 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3704 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3705 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3706 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3710 auto StoreCallBase = [](Use &U,
3711 OMPInformationCache::RuntimeFunctionInfo &RFI,
3713 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3715 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3717 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3723 StoreCallBase(U, InitRFI, KernelInitCB);
3727 DeinitRFI.foreachUse(
3729 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3735 if (!KernelInitCB || !KernelDeinitCB)
3739 ReachingKernelEntries.insert(Fn);
3740 IsKernelEntry = true;
3748 KernelConfigurationSimplifyCB =
3750 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3751 if (!isAtFixpoint()) {
3754 UsedAssumedInformation = true;
3755 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3760 A.registerGlobalVariableSimplificationCallback(
3761 *KernelEnvGV, KernelConfigurationSimplifyCB);
3765 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3770 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3774 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3776 setExecModeOfKernelEnvironment(AssumedExecModeC);
3783 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3785 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3786 auto [MinTeams, MaxTeams] =
3789 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3791 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3794 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3795 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3797 setMayUseNestedParallelismOfKernelEnvironment(
3798 AssumedMayUseNestedParallelismC);
3802 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3805 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3806 setUseGenericStateMachineOfKernelEnvironment(
3807 AssumedUseGenericStateMachineC);
3813 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3815 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3819 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3822 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3836 if (SPMDCompatibilityTracker.isValidState())
3837 return AddDependence(A, this, QueryingAA);
3839 if (!ReachedKnownParallelRegions.isValidState())
3840 return AddDependence(A, this, QueryingAA);
3846 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3847 CustomStateMachineUseCB);
3848 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3849 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3850 CustomStateMachineUseCB);
3851 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3852 CustomStateMachineUseCB);
3853 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3854 CustomStateMachineUseCB);
3858 if (SPMDCompatibilityTracker.isAtFixpoint())
3865 if (!SPMDCompatibilityTracker.isValidState())
3866 return AddDependence(A, this, QueryingAA);
3869 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3878 if (!SPMDCompatibilityTracker.isValidState())
3879 return AddDependence(A, this, QueryingAA);
3880 if (SPMDCompatibilityTracker.empty())
3881 return AddDependence(A, this, QueryingAA);
3882 if (!mayContainParallelRegion())
3883 return AddDependence(A, this, QueryingAA);
3886 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3890 static std::string sanitizeForGlobalName(std::string S) {
3894 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3895 (C >= '0' && C <= '9') || C == '_');
3906 if (!KernelInitCB || !KernelDeinitCB)
3907 return ChangeStatus::UNCHANGED;
3911 bool HasBuiltStateMachine = true;
3912 if (!changeToSPMDMode(A, Changed)) {
3914 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3916 HasBuiltStateMachine = false;
3923 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3924 ExistingKernelEnvC);
3925 if (!HasBuiltStateMachine)
3926 setUseGenericStateMachineOfKernelEnvironment(
3927 OldUseGenericStateMachineVal);
3934 Changed = ChangeStatus::CHANGED;
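// Guard instructions that are not SPMD-compatible so that only the main thread
// executes them; results are broadcast through shared memory and all threads
// synchronize at a simple SPMD barrier afterwards.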
3940 void insertInstructionGuardsHelper(Attributor &A) {
3941 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3943 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3977 DT, LI, MSU, "region.guarded.end");
3980 MSU, "region.barrier");
3983 DT, LI, MSU, "region.exit");
3985 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3988 "Expected a different CFG");
3991 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3994 A.registerManifestAddedBasicBlock(*RegionEndBB);
3995 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3996 A.registerManifestAddedBasicBlock(*RegionExitBB);
3997 A.registerManifestAddedBasicBlock(*RegionStartBB);
3998 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4000 bool HasBroadcastValues = false;
4005 for (Use &U : I.uses()) {
4011 if (OutsideUses.empty())
4014 HasBroadcastValues = true;
4019 M, I.getType(), false,
4021 sanitizeForGlobalName(
4022 (I.getName() + ".guarded.output.alloc").str()),
4024 static_cast<unsigned>(AddressSpace::Shared));
4031 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4035 for (Use *U : OutsideUses)
4036 A.changeUseAfterManifest(*U, *LoadI);
4039 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4045 InsertPointTy(ParentBB, ParentBB->end()), DL);
4046 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4049 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4057 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4058 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4061 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4066 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4067 OMPInfoCache.OMPBuilder.Builder
4068 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4074 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4075 M, OMPRTL___kmpc_barrier_simple_spmd);
4076 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4079 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4084 if (HasBroadcastValues) {
4089 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4093 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4095 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4097 if (!Visited.insert(BB).second)
4103 while (++IP != IPEnd) {
4104 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4107 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4109 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4110 LastEffect = nullptr;
4117 for (auto &Reorder : Reorders)
4123 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4125 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4128 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4129 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4131 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4134 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4138 if (SPMDCompatibilityTracker.contains(&I)) {
4139 CalleeAAFunction.getGuardedInstructions().insert(&I);
4140 if (GuardedRegionStart)
4141 GuardedRegionEnd = &I;
4143 GuardedRegionStart = GuardedRegionEnd = &I;
4150 if (GuardedRegionStart) {
4152 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4153 GuardedRegionStart = nullptr;
4154 GuardedRegionEnd = nullptr;
4159 for (auto &GR : GuardedRegions)
4160 CreateGuardedRegion(GR.first, GR.second);
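// Alternative to guarding: let only the main thread (hardware thread id 0)
// run the user code and branch all other threads around it.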
4163 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4172 auto &Ctx = getAnchorValue().getContext();
4179 KernelInitCB->getNextNode(), "main.thread.user_code");
4184 A.registerManifestAddedBasicBlock(*InitBB);
4185 A.registerManifestAddedBasicBlock(*UserCodeBB);
4186 A.registerManifestAddedBasicBlock(*ReturnBB);
4195 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4197 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4198 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4203 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4209 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4210 "thread.is_main", InitBB);
4216 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4219 if (!OMPInfoCache.runtimeFnsAvailable(
4220 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4221 OMPRTL___kmpc_barrier_simple_spmd}))
4224 if (!SPMDCompatibilityTracker.isAssumed()) {
4225 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4226 if (!NonCompatibleI)
4230 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4231 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4235 ORA << "Value has potential side effects preventing SPMD-mode "
4237 if (isa<CallBase>(NonCompatibleI)) {
4238 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4239 "the called function to override";
4247 << *NonCompatibleI << "\n");
4259 Kernel = CB->getCaller();
4267 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4273 Changed = ChangeStatus::CHANGED;
4277 if (mayContainParallelRegion())
4278 insertInstructionGuardsHelper(A);
4280 forceSingleThreadPerWorkgroupHelper(A);
4285 "Initially non-SPMD kernel has SPMD exec mode!");
4286 setExecModeOfKernelEnvironment(
4290 ++NumOpenMPTargetRegionKernelsSPMD;
4293 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4305 if (!ReachedKnownParallelRegions.isValidState())
4308 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4309 if (!OMPInfoCache.runtimeFnsAvailable(
4310 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4311 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4312 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4323 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4324 ExistingKernelEnvC);
4326 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4331 if (UseStateMachineC->isZero() ||
4335 Changed = ChangeStatus::CHANGED;
4338 setUseGenericStateMachineOfKernelEnvironment(
4345 if (!mayContainParallelRegion()) {
4346 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4349 return OR << "Removing unused state machine from generic-mode kernel.";
4357 if (ReachedUnknownParallelRegions.empty()) {
4358 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4361 return OR << "Rewriting generic-mode kernel with a customized state "
4366 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4369 return OR << "Generic-mode kernel is executed with a customized state "
4370 "machine that requires a fallback.";
4375 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4376 if (!UnknownParallelRegionCB)
4379 return ORA << "Call may contain unknown parallel regions. Use "
4380 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4418 auto &Ctx = getAnchorValue().getContext();
4422 BasicBlock *InitBB = KernelInitCB->getParent();
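// Construct the worker state machine CFG: workers wait at a barrier, fetch the
// next parallel-region work function, execute it via an if-cascade over the
// known regions (with a fallback for unknown ones), and loop until done.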
4424 KernelInitCB->getNextNode(), "thread.user_code.check");
4428 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4430 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4432 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4435 Kernel, UserCodeEntryBB);
4438 Kernel, UserCodeEntryBB);
4440 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4441 A.registerManifestAddedBasicBlock(*InitBB);
4442 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4443 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4444 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4445 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4446 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4451 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4457 ConstantInt::get(KernelInitCB->getType(), -1),
4458 "thread.is_worker", InitBB);
4464 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4465 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, OMPRTL___kmpc_get_warp_size);
4471 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4475 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4478 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4482 "thread.is_main_or_worker", IsWorkerCheckBB);
4485 IsMainOrWorker, IsWorkerCheckBB);
4489 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4491 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 OMPInfoCache.OMPBuilder.updateToLocation(
4498 StateMachineBeginBB->end()),
4501 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4502 Value *GTid = KernelInitCB;
4505 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4506 M, OMPRTL___kmpc_barrier_simple_generic);
4509 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4513 (unsigned int)AddressSpace::Generic) {
4515 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4516 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4521 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4522 M, OMPRTL___kmpc_kernel_parallel);
4524 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4525 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4528 StateMachineBeginBB);
4538 StateMachineBeginBB);
4539 IsDone->setDebugLoc(DLoc);
4541 IsDone, StateMachineBeginBB)
4545 StateMachineDoneBarrierBB, IsActiveWorker,
4546 StateMachineIsActiveCheckBB)
4552 const unsigned int WrapperFunctionArgNo = 6;
4557 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4558 auto *CB = ReachedKnownParallelRegions[I];
4559 auto *ParallelRegion = dyn_cast<Function>(
4560 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4562 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4563 StateMachineEndParallelBB);
4565 ->setDebugLoc(DLoc);
4571 Kernel, StateMachineEndParallelBB);
4572 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4573 A.registerManifestAddedBasicBlock(*PRNextBB);
4578 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4581 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4589 StateMachineIfCascadeCurrentBB)
4591 StateMachineIfCascadeCurrentBB = PRNextBB;
4597 if (!ReachedUnknownParallelRegions.empty()) {
4598 StateMachineIfCascadeCurrentBB->setName(
4599 "worker_state_machine.parallel_region.fallback.execute");
4601 StateMachineIfCascadeCurrentBB)
4602 ->setDebugLoc(DLoc);
4605 StateMachineIfCascadeCurrentBB)
4609 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4610 M, OMPRTL___kmpc_kernel_end_parallel);
4613 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4619 ->setDebugLoc(DLoc);
4629 KernelInfoState StateBefore = getState();
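// RAII helper: on destruction it writes the updated state back into the kernel
// environment constant, restoring the original conservative values if the
// corresponding state became invalid during the update.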
4635 struct UpdateKernelEnvCRAII {
4636 AAKernelInfoFunction &AA;
4638 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4640 ~UpdateKernelEnvCRAII() {
4647 if (!AA.isValidState()) {
4648 AA.KernelEnvC = ExistingKernelEnvC;
4652 if (!AA.ReachedKnownParallelRegions.isValidState())
4653 AA.setUseGenericStateMachineOfKernelEnvironment(
4654 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4655 ExistingKernelEnvC));
4657 if (!AA.SPMDCompatibilityTracker.isValidState())
4658 AA.setExecModeOfKernelEnvironment(
4659 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4662 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4664 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4665 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4666 AA.setMayUseNestedParallelismOfKernelEnvironment(
4667 NewMayUseNestedParallelismC);
4674 if (isa<CallBase>(I))
4677 if (!I.mayWriteToMemory())
4679 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4682 DepClassTy::OPTIONAL);
4685 DepClassTy::OPTIONAL);
4686 if (UnderlyingObjsAA &&
4687 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4688 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4692 auto *CB = dyn_cast<CallBase>(&Obj);
4693 return CB && HS && HS->isAssumedHeapToStack(*CB);
4699 SPMDCompatibilityTracker.insert(&I);
4703 bool UsedAssumedInformationInCheckRWInst = false;
4704 if (!SPMDCompatibilityTracker.isAtFixpoint())
4705 if (!A.checkForAllReadWriteInstructions(
4706 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4709 bool UsedAssumedInformationFromReachingKernels = false;
4710 if (!IsKernelEntry) {
4711 updateParallelLevels(A);
4713 bool AllReachingKernelsKnown = true;
4714 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4715 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4717 if (!SPMDCompatibilityTracker.empty()) {
4718 if (!ParallelLevels.isValidState())
4719 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4720 else if (!ReachingKernelEntries.isValidState())
4721 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4727 for (auto *Kernel : ReachingKernelEntries) {
4728 auto *CBAA = A.getAAFor<AAKernelInfo>(
4730 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4731 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4736 UsedAssumedInformationFromReachingKernels = true;
4738 if (SPMD != 0 && Generic != 0)
4739 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4745 bool AllParallelRegionStatesWereFixed = true;
4746 bool AllSPMDStatesWereFixed = true;
4748 auto &CB = cast<CallBase>(I);
4749 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 getState() ^= CBAA->getState();
4754 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4755 AllParallelRegionStatesWereFixed &=
4756 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4757 AllParallelRegionStatesWereFixed &=
4758 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 bool UsedAssumedInformationInCheckCallInst = false;
4763 if (!A.checkForAllCallLikeInstructions(
4764 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4766 << "Failed to visit all call-like instructions!\n";);
4767 return indicatePessimisticFixpoint();
4772 if (!UsedAssumedInformationInCheckCallInst &&
4773 AllParallelRegionStatesWereFixed) {
4774 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4775 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4780 if (!UsedAssumedInformationInCheckRWInst &&
4781 !UsedAssumedInformationInCheckCallInst &&
4782 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4783 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4785 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4786 : ChangeStatus::CHANGED;
4792 bool &AllReachingKernelsKnown) {
4796 assert(Caller && "Caller is nullptr");
4798 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4800 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4801 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4807 ReachingKernelEntries.indicatePessimisticFixpoint();
4812 if (!A.checkForAllCallSites(PredCallSite, *this,
4814 AllReachingKernelsKnown))
4815 ReachingKernelEntries.indicatePessimisticFixpoint();
4820 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4821 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4822 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4827 assert(Caller && "Caller is nullptr");
4831 if (CAA && CAA->ParallelLevels.isValidState()) {
4837 if (Caller == Parallel51RFI.Declaration) {
4838 ParallelLevels.indicatePessimisticFixpoint();
4842 ParallelLevels ^= CAA->ParallelLevels;
4849 ParallelLevels.indicatePessimisticFixpoint();
4854 bool AllCallSitesKnown = true;
4855 if (!A.checkForAllCallSites(PredCallSite, *this,
4858 ParallelLevels.indicatePessimisticFixpoint();
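// Call-site variant: maps known OpenMP runtime calls (e.g. __kmpc_parallel_51,
// __kmpc_alloc_shared) onto kernel-info state and otherwise defers to the
// callee's AAKernelInfo.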
4865struct AAKernelInfoCallSite : AAKernelInfo {
4867 : AAKernelInfo(IRP, A) {}
4871 AAKernelInfo::initialize(A);
4873 CallBase &CB = cast<CallBase>(getAssociatedValue());
4878 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4879 indicateOptimisticFixpoint();
4887 indicateOptimisticFixpoint();
4896 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4897 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4898 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4900 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4904 if (!AssumptionAA ||
4905 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4906 AssumptionAA->hasAssumption("omp_no_parallelism")))
4907 ReachedUnknownParallelRegions.insert(&CB);
4911 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4912 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4913 SPMDCompatibilityTracker.insert(&CB);
4918 indicateOptimisticFixpoint();
4924 if (NumCallees > 1) {
4925 indicatePessimisticFixpoint();
4932 case OMPRTL___kmpc_is_spmd_exec_mode:
4933 case OMPRTL___kmpc_distribute_static_fini:
4934 case OMPRTL___kmpc_for_static_fini:
4935 case OMPRTL___kmpc_global_thread_num:
4936 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4937 case OMPRTL___kmpc_get_hardware_num_blocks:
4938 case OMPRTL___kmpc_single:
4939 case OMPRTL___kmpc_end_single:
4940 case OMPRTL___kmpc_master:
4941 case OMPRTL___kmpc_end_master:
4942 case OMPRTL___kmpc_barrier:
4943 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4944 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4945 case OMPRTL___kmpc_error:
4946 case OMPRTL___kmpc_flush:
4947 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4948 case OMPRTL___kmpc_get_warp_size:
4949 case OMPRTL_omp_get_thread_num:
4950 case OMPRTL_omp_get_num_threads:
4951 case OMPRTL_omp_get_max_threads:
4952 case OMPRTL_omp_in_parallel:
4953 case OMPRTL_omp_get_dynamic:
4954 case OMPRTL_omp_get_cancellation:
4955 case OMPRTL_omp_get_nested:
4956 case OMPRTL_omp_get_schedule:
4957 case OMPRTL_omp_get_thread_limit:
4958 case OMPRTL_omp_get_supported_active_levels:
4959 case OMPRTL_omp_get_max_active_levels:
4960 case OMPRTL_omp_get_level:
4961 case OMPRTL_omp_get_ancestor_thread_num:
4962 case OMPRTL_omp_get_team_size:
4963 case OMPRTL_omp_get_active_level:
4964 case OMPRTL_omp_in_final:
4965 case OMPRTL_omp_get_proc_bind:
4966 case OMPRTL_omp_get_num_places:
4967 case OMPRTL_omp_get_num_procs:
4968 case OMPRTL_omp_get_place_proc_ids:
4969 case OMPRTL_omp_get_place_num:
4970 case OMPRTL_omp_get_partition_num_places:
4971 case OMPRTL_omp_get_partition_place_nums:
4972 case OMPRTL_omp_get_wtime:
4974 case OMPRTL___kmpc_distribute_static_init_4:
4975 case OMPRTL___kmpc_distribute_static_init_4u:
4976 case OMPRTL___kmpc_distribute_static_init_8:
4977 case OMPRTL___kmpc_distribute_static_init_8u:
4978 case OMPRTL___kmpc_for_static_init_4:
4979 case OMPRTL___kmpc_for_static_init_4u:
4980 case OMPRTL___kmpc_for_static_init_8:
4981 case OMPRTL___kmpc_for_static_init_8u: {
4983 unsigned ScheduleArgOpNo = 2;
4984 auto *ScheduleTypeCI =
4986 unsigned ScheduleTypeVal =
4987 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4989 case OMPScheduleType::UnorderedStatic:
4990 case OMPScheduleType::UnorderedStaticChunked:
4991 case OMPScheduleType::OrderedDistribute:
4992 case OMPScheduleType::OrderedDistributeChunked:
4995 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4996 SPMDCompatibilityTracker.insert(&CB);
5000 case OMPRTL___kmpc_target_init:
5003 case OMPRTL___kmpc_target_deinit:
5004 KernelDeinitCB = &CB;
5006 case OMPRTL___kmpc_parallel_51:
5007 if (!handleParallel51(A, CB))
5008 indicatePessimisticFixpoint();
5010 case OMPRTL___kmpc_omp_task:
5012 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5013 SPMDCompatibilityTracker.insert(&CB);
5014 ReachedUnknownParallelRegions.insert(&CB);
5016 case OMPRTL___kmpc_alloc_shared:
5017 case OMPRTL___kmpc_free_shared:
5023 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5024 SPMDCompatibilityTracker.insert(&CB);
5030 indicateOptimisticFixpoint();
5034 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5035 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5036 CheckCallee(getAssociatedFunction(), 1);
5039 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5040 for (auto *Callee : OptimisticEdges) {
5041 CheckCallee(Callee, OptimisticEdges.size());
5052 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5053 KernelInfoState StateBefore = getState();
5055 auto CheckCallee = [&](Function *F, int NumCallees) {
5056 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5060 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5063 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5065 return indicatePessimisticFixpoint();
5066 if (getState() == FnAA->getState())
5067 return ChangeStatus::UNCHANGED;
5068 getState() = FnAA->getState();
5069 return ChangeStatus::CHANGED;
5072 return indicatePessimisticFixpoint();
5074 CallBase &CB = cast<CallBase>(getAssociatedValue());
5075 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5076 if (!handleParallel51(A, CB))
5077 return indicatePessimisticFixpoint();
5078 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5079 : ChangeStatus::CHANGED;
5085 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5086 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5087 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5099 case OMPRTL___kmpc_alloc_shared:
5100 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5101 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5102 SPMDCompatibilityTracker.insert(&CB);
5104 case OMPRTL___kmpc_free_shared:
5105 if ((!HeapToStackAA ||
5106 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5108 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5109 SPMDCompatibilityTracker.insert(&CB);
5112 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5113 SPMDCompatibilityTracker.insert(&CB);
5115 return ChangeStatus::CHANGED;
5119 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5120 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5121 if (Function *F = getAssociatedFunction())
5125 for (auto *Callee : OptimisticEdges) {
5126 CheckCallee(Callee, OptimisticEdges.size());
5132 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5133 : ChangeStatus::CHANGED;
5139 const unsigned int NonWrapperFunctionArgNo = 5;
5140 const unsigned int WrapperFunctionArgNo = 6;
5141 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5142 ? NonWrapperFunctionArgNo
5143 : WrapperFunctionArgNo;
5145 auto *ParallelRegion = dyn_cast<Function>(
5147 if (!ParallelRegion)
5150 ReachedKnownParallelRegions.insert(&CB);
5152 auto *FnAA = A.getAAFor<AAKernelInfo>(
5154 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5155 !FnAA->ReachedKnownParallelRegions.empty() ||
5156 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5157 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5158 !FnAA->ReachedUnknownParallelRegions.empty();
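// Attribute that folds certain OpenMP runtime calls (e.g. __kmpc_is_spmd_exec_mode,
// __kmpc_parallel_level) to constants once the execution mode of all reaching
// kernels is known.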
5163struct AAFoldRuntimeCall
5164 : public StateWrapper<BooleanState, AbstractAttribute> {
5170 void trackStatistics() const override {}
5173 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5180 const char *getIdAddr() const override { return &ID; }
5188 static const char ID;
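// Call-site-returned variant: registers a simplification callback so the call's
// return value is replaced with SimplifiedValue once a fixpoint is reached.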
5191struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5193 : AAFoldRuntimeCall(IRP, A) {}
5196 const std::string getAsStr(Attributor *) const override {
5197 if (!isValidState())
5200 std::string Str("simplified value: ");
5202 if (!SimplifiedValue)
5203 return Str + std::string("none");
5205 if (!*SimplifiedValue)
5206 return Str + std::string("nullptr");
5208 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5209 return Str + std::to_string(CI->getSExtValue());
5211 return Str + std::string("unknown");
5216 indicatePessimisticFixpoint();
5220 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5221 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5222 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5223 "Expected a known OpenMP runtime function");
5225 RFKind = It->getSecond();
5227 CallBase &CB = cast<CallBase>(getAssociatedValue());
5228 A.registerSimplificationCallback(
5231 bool &UsedAssumedInformation) -> std::optional<Value *> {
5232 assert((isValidState() ||
5233 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5234 "Unexpected invalid state!");
5236 if (!isAtFixpoint()) {
5237 UsedAssumedInformation = true;
5239 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241 return SimplifiedValue;
5248 case OMPRTL___kmpc_is_spmd_exec_mode:
5249 Changed |= foldIsSPMDExecMode(A);
5251 case OMPRTL___kmpc_parallel_level:
5252 Changed |= foldParallelLevel(A);
5254 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5255 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5257 case OMPRTL___kmpc_get_hardware_num_blocks:
5258 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5270 if (SimplifiedValue && *SimplifiedValue) {
5273 A.deleteAfterManifest(I);
5277 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5278 return OR << "Replacing OpenMP runtime call "
5280 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5281 return OR << "Replacing OpenMP runtime call "
5289 << **SimplifiedValue << "\n");
5291 Changed = ChangeStatus::CHANGED;
5298 SimplifiedValue = nullptr;
5299 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5305 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5308 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5309 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5312 if (!CallerKernelInfoAA ||
5313 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5314 return indicatePessimisticFixpoint();
5316 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318 DepClassTy::REQUIRED);
5320 if (!AA || !AA->isValidState()) {
5321 SimplifiedValue = nullptr;
5322 return indicatePessimisticFixpoint();
5325 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5326 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5331 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5332 ++KnownNonSPMDCount;
5334 ++AssumedNonSPMDCount;
5338 if ((AssumedSPMDCount + KnownSPMDCount) &&
5339 (AssumedNonSPMDCount + KnownNonSPMDCount))
5340 return indicatePessimisticFixpoint();
5342 auto &Ctx = getAnchorValue().getContext();
5343 if (KnownSPMDCount || AssumedSPMDCount) {
5344 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5345 "Expected only SPMD kernels!");
5349 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5350 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5351 "Expected only non-SPMD kernels!");
5359 assert(!SimplifiedValue && "SimplifiedValue should be none");
5362 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5363 : ChangeStatus::CHANGED;
5368 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5373 if (!CallerKernelInfoAA ||
5374 !CallerKernelInfoAA->ParallelLevels.isValidState())
5375 return indicatePessimisticFixpoint();
5377 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5378 return indicatePessimisticFixpoint();
5380 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5381 assert(!SimplifiedValue &&
5382 "SimplifiedValue should keep none at this point");
5383 return ChangeStatus::UNCHANGED;