#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"

#define DEBUG_TYPE "openmp-opt"
    "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),

    "openmp-opt-enable-merging",

    cl::desc("Disable function internalization."),

    "openmp-hide-memory-transfer-latency",
    cl::desc("[WIP] Tries to hide the latency of host to device memory"

    "openmp-opt-disable-deglobalization",
    cl::desc("Disable OpenMP optimizations involving deglobalization."),

    "openmp-opt-disable-spmdization",
    cl::desc("Disable OpenMP optimizations involving SPMD-ization."),

    "openmp-opt-disable-folding",

    "openmp-opt-disable-state-machine-rewrite",
    cl::desc("Disable OpenMP optimizations that replace the state machine."),

    "openmp-opt-disable-barrier-elimination",
    cl::desc("Disable OpenMP optimizations that eliminate barriers."),

    "openmp-opt-print-module-after",
    cl::desc("Print the current module after OpenMP optimizations."),

    "openmp-opt-print-module-before",
    cl::desc("Print the current module before OpenMP optimizations."),

    "openmp-opt-inline-device",

    cl::desc("Maximal number of attributor iterations."),

    cl::desc("Maximum amount of shared memory to use."),
    cl::init(std::numeric_limits<unsigned>::max()));
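// The STATISTIC counters that follow are only emitted in builds with
// statistics enabled (and with -stats); the transformations below bump the
// matching counter whenever they fire, e.g. ++NumOpenMPParallelRegionsDeleted
// after a dead fork call is erased.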
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
struct AAHeapToShared;

        Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();

  struct InternalControlVarInfo {

  struct RuntimeFunctionInfo {
    void clearUsesMap() { UsesMap.clear(); }

    operator bool() const { return Declaration; }

    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
        UV = std::make_shared<UseVector>();

    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();

    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    size_t getNumArgs() const { return ArgumentTypes.size(); }

      UseVector &UV = getOrCreateUseVector(F);

      while (!ToBeDeleted.empty()) {

    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
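    // RuntimeFunctionInfo, in short: one record per known OpenMP runtime
    // function, holding its declaration and expected signature plus a map
    // from user function to the uses found in it. collectUses below fills
    // that map and the foreachUse-style callbacks iterate it.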
                  RuntimeFunction::OMPRTL___last>

                  InternalControlVar::ICV___last>

  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL) \
    auto &ICV = ICVs[_Name]; \
#define ICV_RT_GET(Name, RTL) \
    auto &ICV = ICVs[Name]; \
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
    auto &ICV = ICVs[Enum]; \
    ICV.InitKind = Init; \
    ICV.EnvVarName = _EnvVarName; \
    switch (ICV.InitKind) { \
    case ICV_IMPLEMENTATION_DEFINED: \
      ICV.InitValue = nullptr; \
      ICV.InitValue = ConstantInt::get( \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
#include "llvm/Frontend/OpenMP/OMPKinds.def"
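  // The ICV_* macros are consumed as X-macros: including OMPKinds.def expands
  // one entry per internal control variable and populates ICVs[] with its
  // name, environment variable, getter/setter runtime functions, and initial
  // value. An entry has roughly the shape (illustrative, not verbatim):
  //   ICV_DATA_ENV(ICV_nthreads, "nthreads-var", "OMP_NUM_THREADS",
  //                ICV_IMPLEMENTATION_DEFINED)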
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
    if (F->getReturnType() != RTFRetType)
    if (F->arg_size() != RTFArgTypes.size())
    auto *RTFTyIt = RTFArgTypes.begin();
      if (Arg.getType() != *RTFTyIt)

  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();

    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
        RFI.getOrCreateUseVector(nullptr).push_back(&U);

    auto &RFI = RFIs[RTF];
    collectUses(RFI, false);

  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)

    RuntimeFunctionInfo &RFI = RFIs[Fn];
    if (RFI.Declaration && RFI.Declaration->isDeclaration())
  void initializeRuntimeFunctions(Module &M) {
#define OMP_TYPE(VarName, ...) \
  Type *VarName = OMPBuilder.VarName; \
#define OMP_ARRAY_TYPE(VarName, ...) \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
  (void)VarName##PtrTy;
#define OMP_FUNCTION_TYPE(VarName, ...) \
  FunctionType *VarName = OMPBuilder.VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
#define OMP_STRUCT_TYPE(VarName, ...) \
  StructType *VarName = OMPBuilder.VarName; \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
    Function *F = M.getFunction(_Name); \
    RTLFunctions.insert(F); \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
      RuntimeFunctionIDMap[F] = _Enum; \
      auto &RFI = RFIs[_Enum]; \
      RFI.IsVarArg = _IsVarArg; \
      RFI.ReturnType = OMPBuilder._ReturnType; \
      RFI.ArgumentTypes = std::move(ArgsTypes); \
      RFI.Declaration = F; \
      unsigned NumUses = collectUses(RFI); \
      dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
      if (RFI.Declaration) \
        dbgs() << TAG << "-> got " << NumUses << " uses in " \
               << RFI.getNumFunctionsWithUses() \
               << " different functions.\n"; \
#include "llvm/Frontend/OpenMP/OMPKinds.def"
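  // The same X-macro pattern drives runtime-function discovery: each OMP_RTL
  // entry in OMPKinds.def names a runtime call together with its return and
  // argument types. When the module contains a declaration with a matching
  // signature, the corresponding RuntimeFunctionInfo is filled in and its
  // uses are collected; otherwise the RFI stays empty and later queries
  // simply see no uses.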
      for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
        if (F.hasFnAttribute(Attribute::NoInline) &&
            F.getName().startswith(Prefix) &&
            !F.hasFnAttribute(Attribute::OptimizeNone))
          F.removeFnAttr(Attribute::NoInline);

  bool OpenMPPostLink = false;
template <typename Ty, bool InsertInvalidates = true>

  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
    return Set.insert(Elem);

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());

  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;
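// BooleanStateWithSetVector pairs an Attributor BooleanState with a set of
// witness values (e.g. the parallel regions reached so far); when the
// InsertInvalidates parameter is true, inserting a new element also drops the
// boolean state to its pessimistic fixpoint. The pointer-element alias above
// is the form used by KernelInfoState below.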
  bool IsAtFixpoint = false;

  BooleanStateWithPtrSetVector<Function, false>
      ReachedKnownParallelRegions;

  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  bool IsKernelEntry = false;

  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  bool NestedParallelism = false;

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {

  bool isAtFixpoint() const override { return IsAtFixpoint; }

    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();

    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();

  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
    if (ParallelLevels != RHS.ParallelLevels)

  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();

  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();

  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  KernelInfoState operator^=(const KernelInfoState &KIS) {
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
      KernelInitCB = KIS.KernelInitCB;
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
      KernelDeinitCB = KIS.KernelDeinitCB;
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  OffloadArray() = default;

    if (!Array.getAllocatedType()->isArrayTy())
    if (!getValues(Array, Before))
    this->Array = &Array;

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;
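  // Argument positions of __tgt_target_data_begin_mapper that this helper
  // mirrors: the device id plus the base-pointer, pointer, and size arrays.
  // Callers use these indices to pull the matching operands straight off the
  // runtime call (see getValuesInOffloadArrays below).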
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

      if (!isa<StoreInst>(&I))
      auto *S = cast<StoreInst>(&I);
      LastAccesses[Idx] = S;

    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])

  using OptimizationRemarkGetter =

            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);

  bool run(bool IsModulePass) {
    bool Changed = false;

                      << " functions in a slice with "
                      << OMPInfoCache.ModuleSlice.size() << " functions\n");

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();

      Changed |= runAttributor(IsModulePass);

      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();

      if (mergeParallelRegions()) {
        deduplicateRuntimeCalls();
  void printICVs() const {

    for (auto ICV : ICVs) {
      auto ICVInfo = OMPInfoCache.ICVs[ICV];
        return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                   << (ICVInfo.InitValue
                           ? toString(ICVInfo.InitValue->getValue(), 10, true)
                           : "IMPLEMENTATION_DEFINED");

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);

  void printKernels() const {

      if (!OMPInfoCache.Kernels.count(F))

        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);

  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {

  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
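  /// Try to merge adjacent __kmpc_fork_call regions in the same basic block.
  /// Only "mergable" code may sit between two regions (no unmergable runtime
  /// calls such as __kmpc_push_num_threads); that in-between code is kept
  /// sequential by wrapping it in a master region followed by a barrier
  /// inside the newly created, combined parallel region.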
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)

    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],

    bool Changed = false;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    auto CreateSequentialRegion = [&](Function *OuterFn,
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
             "Expected a different CFG");

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");

      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

        for (User *Usr : I.users()) {
            OutsideUsers.insert(&UsrI);

        if (OutsideUsers.empty())

            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", &OuterFn->front().front());

              I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
          UsrI->replaceUsesOfWith(&I, LoadI);

          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())

      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

                        << " parallel regions in " << OriginalFn->getName()

      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();

      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),

      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, false);

      OMPInfoCache.OMPBuilder.finalize(OriginalFn);

      for (auto *CI : MergableCIs) {
        Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
          Args.push_back(CI->getArgOperand(U));

        if (CI->getDebugLoc())

        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        if (CI != MergableCIs.back()) {
          OMPInfoCache.OMPBuilder.createBarrier(

        CI->eraseFromParent();

      assert(OutlinedFn != OriginalFn && "Outlining failed");

      NumOpenMPParallelRegionsMerged += MergableCIs.size();
      CallInst *CI = getCallIfRegularCall(U, &RFI);

    RFI.foreachUse(SCC, DetectPRsCB);

    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();

      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        if (I.isTerminator())
        if (!isa<CallInst>(&I))
        if (IsBeforeMergableRegion) {
          if (!CalledFunction)
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
        if (!isa<IntrinsicInst>(CI))

      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        if (CIs.count(&I)) {
        if (IsMergable(I, MergableCIs.empty()))

        for (; It != End; ++It) {
          if (CIs.count(&SkipI)) {
                              << " due to " << I << "\n");

        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
                            << " parallel regions in block " << BB->getName()

        MergableCIs.clear();

      if (!MergableCIsVector.empty()) {

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();

      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
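  /// Delete __kmpc_fork_call invocations whose outlined callback provably has
  /// no effect: if the callback only reads memory and is known to return, the
  /// whole parallel region is dead and the call can be erased.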
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)

    bool Changed = false;
      CallInst *CI = getCallIfRegularCall(U);
      auto *Fn = dyn_cast<Function>(
      if (!Fn->onlyReadsMemory())
      if (!Fn->hasFnAttribute(Attribute::WillReturn))

        return OR << "Removing parallel region with no side-effects.";
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      ++NumOpenMPParallelRegionsDeleted;

    RFI.foreachUse(SCC, DeleteCallCB);

  bool deduplicateRuntimeCalls() {
    bool Changed = false;

        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_thread_limit,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    collectGlobalThreadIdArguments(GTIdArgs);
                      << " global thread ID arguments\n");

      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      Value *GTIdArg = nullptr;
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
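  /// Split __tgt_target_data_begin_mapper calls into an "issue" part and a
  /// "wait" part so the host-to-device transfer can overlap with other work;
  /// this only runs when the corresponding _issue/_wait runtime entry points
  /// are available in the module.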
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;

    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
            << "Found thread data sharing on the GPU. "
            << "Expect degraded performance due to data globalization.";
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);

    RFI.foreachUse(SCC, CheckGlobalization);

  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    Value *BasePtrsArg =

    if (!isa<AllocaInst>(V))
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))

    if (!isa<AllocaInst>(V))
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))

    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))

    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    std::string ValuesStr;
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
    for (auto *P : OAs[1].StoredValues) {
    for (auto *S : OAs[2].StoredValues) {

    bool IsWorthIt = false;
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
                       Entry.getFirstNonPHIOrDbgOrAlloca());
        IRBuilder.AsyncInfo, nullptr, "handle");
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);

        M, OMPRTL___tgt_target_data_begin_mapper_wait);
    Value *WaitParams[2] = {
                             OffloadArray::DeviceIDArgNum),
        WaitDecl, WaitParams, "", &WaitMovementPoint);
    OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;

  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
                                   true, SingleChoice);
    RFI.foreachUse(SCC, CombineIdentStruct);
    if (!Ident || !SingleChoice) {
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
            &F.getEntryBlock(), F.getEntryBlock().begin()));
          OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)

        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.arg_size();
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
      for (unsigned U = 1; U < NumArgs; ++U)
        if (isa<Instruction>(CB.getArgOperand(U)))

      if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
        if (!CanBeMoved(*CI))

      auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      auto *KernelInitUV = KernelInitRFI.getUseVector(F);

      if (KernelInitUV->empty())

      assert(KernelInitUV->size() == 1 &&
             "Expected a single __kmpc_target_init in kernel\n");

          getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
             "Expected a call to __kmpc_target_init in kernel\n");

        CI->moveAfter(KernelInitCI);
        CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());

    if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
      if (!CI->arg_empty() &&
          CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
        Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
        CI->setArgOperand(0, Ident);

    bool Changed = false;
    auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || CI == ReplVal || &F != &Caller)

        return OR << "OpenMP runtime call "
                  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
        emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
        emitRemark<OptimizationRemark>(&F, "OMP170", Remark);

      ++NumOpenMPRuntimeCallsDeduplicated;

    RFI.foreachUse(SCC, ReplaceAndDeleteCB);
    if (!F.hasLocalLinkage())
    for (Use &U : F.uses()) {
      if (CallInst *CI = getCallIfRegularCall(U)) {
        Value *ArgOp = CI->getArgOperand(ArgNo);
        if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
            getCallIfRegularCall(
                *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))

    auto AddUserArgs = [&](Value &GTId) {
      for (Use &U : GTId.uses())
          if (CI->isArgOperand(&U))

    OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];

    GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
      if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))

    for (unsigned U = 0; U < GTIdArgs.size(); ++U)
      AddUserArgs(*GTIdArgs[U]);

  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }

    return getUniqueKernelFor(*I.getFunction());

  bool rewriteDeviceCodeStateMachine();
  template <typename RemarkKind, typename RemarkCallBack>
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
                 << " [" << RemarkName << "]";
        [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });

  template <typename RemarkKind, typename RemarkCallBack>
                  RemarkCallBack &&RemarkCB) const {
    auto &ORE = OREGetter(F);

          return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
                 << " [" << RemarkName << "]";
        [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });

  OptimizationRemarkGetter OREGetter;

  OMPInformationCache &OMPInfoCache;

  bool runAttributor(bool IsModulePass) {
    registerAAs(IsModulePass);
                      << " functions, result: " << Changed << ".\n");

    return Changed == ChangeStatus::CHANGED;

  void registerAAs(bool IsModulePass);
  if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))

  std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
    return *CachedKernel;

    return *CachedKernel;

  CachedKernel = nullptr;
  if (!F.hasLocalLinkage()) {
      return ORA << "Potentially unknown OpenMP target region caller.";
    emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);

  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
    if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
      if (Cmp->isEquality())
        return getUniqueKernelFor(*Cmp);

    if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
      if (CB->isCallee(&U))
        return getUniqueKernelFor(*CB);

      OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
          OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
      if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
        return getUniqueKernelFor(*CB);

  OMPInformationCache::foreachUse(F, [&](const Use &U) {
    PotentialKernels.insert(GetUniqueKernelForUse(U));

  if (PotentialKernels.size() == 1)
    K = *PotentialKernels.begin();

  UniqueKernelMap[&F] = K;
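/// If the parallel region wrapper passed to __kmpc_parallel_51 is reachable
/// from exactly one kernel and only used in the expected ways, the uses that
/// feed the generic-mode state machine can be replaced by a unique ID, which
/// is what the rewrite below (and the statistic it bumps) is about.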
bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
      OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];

  bool Changed = false;
  if (!KernelParallelRFI)

    bool UnknownUse = false;
    bool KernelParallelUse = false;
    unsigned NumDirectCalls = 0;

    OMPInformationCache::foreachUse(*F, [&](Use &U) {
      if (auto *CB = dyn_cast<CallBase>(U.getUser()))
        if (CB->isCallee(&U)) {

        ToBeReplacedStateMachineUses.push_back(&U);

          OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
      const unsigned int WrapperFunctionArgNo = 6;
      if (!KernelParallelUse && CI &&
          CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
        KernelParallelUse = true;
        ToBeReplacedStateMachineUses.push_back(&U);

    if (!KernelParallelUse)

    if (UnknownUse || NumDirectCalls != 1 ||
        ToBeReplacedStateMachineUses.size() > 2) {
        return ORA << "Parallel region is used in "
                   << (UnknownUse ? "unknown" : "unexpected")
                   << " ways. Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);

    Kernel K = getUniqueKernelFor(*F);
        return ORA << "Parallel region is not called from a unique kernel. "
                      "Will not attempt to rewrite the state machine.";
      emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);

    for (Use *U : ToBeReplacedStateMachineUses)

    ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
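/// Abstract Attribute that tracks the value of an OpenMP internal control
/// variable (ICV), e.g. nthreads-var. The specializations that follow track
/// the ICV within a function, at function returns, at call sites, and for
/// call site return values, so that getter calls can be folded to a unique
/// known value where one exists.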
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {

    if (!F || !A.isFunctionIPOAmendable(*F))
      indicatePessimisticFixpoint();

  bool isAssumedTracked() const { return getAssumed(); }

  bool isKnownTracked() const { return getAssumed(); }

    return std::nullopt;

  virtual std::optional<Value *>

  const std::string getName() const override { return "AAICVTracker"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;

struct AAICVTrackerFunction : public AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr() const override { return "ICVTrackerFunction"; }

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

      auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];

      auto &ValuesMap = ICVReplacementValuesMap[ICV];
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);

        if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
          HasChanged = ChangeStatus::CHANGED;

        std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
        if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
          HasChanged = ChangeStatus::CHANGED;

      SetterRFI.foreachUse(TrackValues, F);

      bool UsedAssumedInformation = false;
      A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
                                UsedAssumedInformation,

      if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
        ValuesMap.insert(std::make_pair(Entry, nullptr));
    const auto *CB = dyn_cast<CallBase>(&I);
    if (!CB || CB->hasFnAttr("no_openmp") ||
        CB->hasFnAttr("no_openmp_routines"))
      return std::nullopt;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
    auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
    Function *CalledFunction = CB->getCalledFunction();

    if (CalledFunction == nullptr)
    if (CalledFunction == GetterRFI.Declaration)
      return std::nullopt;
    if (CalledFunction == SetterRFI.Declaration) {
      if (ICVReplacementValuesMap[ICV].count(&I))
        return ICVReplacementValuesMap[ICV].lookup(&I);

    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (ICVTrackingAA.isAssumedTracked()) {
      std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);

  std::optional<Value *>
    return std::nullopt;

    const auto &ValuesMap = ICVReplacementValuesMap[ICV];
    if (ValuesMap.count(I))
      return ValuesMap.lookup(I);

    std::optional<Value *> ReplVal;

    while (!Worklist.empty()) {
      if (!Visited.insert(CurrInst).second)

      if (ValuesMap.count(CurrInst)) {
        std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
          ReplVal = NewReplVal;
        if (ReplVal != NewReplVal)

      std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
        ReplVal = NewReplVal;
        if (ReplVal != NewReplVal)

      if (CurrBB == I->getParent() && ReplVal)

        if (const Instruction *Terminator = Pred->getTerminator())
struct AAICVTrackerFunctionReturned : AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr() const override {
    return "ICVTrackerFunctionReturned";

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  std::optional<Value *>
    return ICVReplacementValuesMap[ICV];

    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      std::optional<Value *> UniqueICVValue;

        std::optional<Value *> NewReplVal =
            ICVTrackingAA.getReplacementValue(ICV, &I, A);

        if (UniqueICVValue && UniqueICVValue != NewReplVal)

        UniqueICVValue = NewReplVal;

      bool UsedAssumedInformation = false;
      if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
                                     UsedAssumedInformation,
        UniqueICVValue = nullptr;

      if (UniqueICVValue == ReplVal)

      ReplVal = UniqueICVValue;
      Changed = ChangeStatus::CHANGED;

struct AAICVTrackerCallSite : AAICVTracker {
      : AAICVTracker(IRP, A) {}

    if (!F || !A.isFunctionIPOAmendable(*F))
      indicatePessimisticFixpoint();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

      auto ICVInfo = OMPInfoCache.ICVs[ICV];
      auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
      if (Getter.Declaration == getAssociatedFunction()) {
        AssociatedICV = ICVInfo.Kind;

    indicatePessimisticFixpoint();

    if (!ReplVal || !*ReplVal)
      return ChangeStatus::UNCHANGED;

    A.deleteAfterManifest(*getCtxI());

    return ChangeStatus::CHANGED;

  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }

  void trackStatistics() const override {}

  std::optional<Value *> ReplVal;

    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(

    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

    std::optional<Value *> NewReplVal =
        ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);

    if (ReplVal == NewReplVal)
      return ChangeStatus::UNCHANGED;

    ReplVal = NewReplVal;
    return ChangeStatus::CHANGED;

  std::optional<Value *>

struct AAICVTrackerCallSiteReturned : AAICVTracker {
      : AAICVTracker(IRP, A) {}

  const std::string getAsStr() const override {
    return "ICVTrackerCallSiteReturned";

  void trackStatistics() const override {}

    return ChangeStatus::UNCHANGED;

                  InternalControlVar::ICV___last>
      ICVReplacementValuesMap;

  std::optional<Value *>
    return ICVReplacementValuesMap[ICV];

    const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
        DepClassTy::REQUIRED);

    if (!ICVTrackingAA.isAssumedTracked())
      return indicatePessimisticFixpoint();

      std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
      std::optional<Value *> NewReplVal =
          ICVTrackingAA.getUniqueReplacementValue(ICV);

      if (ReplVal == NewReplVal)

      ReplVal = NewReplVal;
      Changed = ChangeStatus::CHANGED;
  ~AAExecutionDomainFunction() { delete RPOT; }

  const std::string getAsStr() const override {
    unsigned TotalBlocks = 0, InitialThreadBlocks = 0;
    for (auto &It : BEDMap) {
      InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
    return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
           std::to_string(TotalBlocks) + " executed by initial thread only";

                 << BB.getName() << " is executed by a single thread.\n";

    auto HandleAlignedBarrier = [&](CallBase *CB) {
      const ExecutionDomainTy &ED = CEDMap[CB];
      if (!ED.IsReachedFromAlignedBarrierOnly ||
          ED.EncounteredNonLocalSideEffect)

        DeletedBarriers.insert(CB);
        A.deleteAfterManifest(*CB);
        ++NumBarriersEliminated;
      } else if (!ED.AlignedBarriers.empty()) {
        NumBarriersEliminated += ED.AlignedBarriers.size();
                        ED.AlignedBarriers.end());
        while (!Worklist.empty()) {
          CallBase *LastCB = Worklist.pop_back_val();
          if (!Visited.insert(LastCB))
          if (!DeletedBarriers.count(LastCB)) {
            A.deleteAfterManifest(*LastCB);
          const ExecutionDomainTy &LastED = CEDMap[LastCB];
          Worklist.append(LastED.AlignedBarriers.begin(),
                          LastED.AlignedBarriers.end());

      if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
        for (auto *AssumeCB : ED.EncounteredAssumes)
          A.deleteAfterManifest(*AssumeCB);

    for (auto *CB : AlignedBarriers)
      HandleAlignedBarrier(CB);

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    if (OMPInfoCache.Kernels.count(getAnchorScope()))
      HandleAlignedBarrier(nullptr);

  mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
                                           const ExecutionDomainTy &PredED);

  void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
                          const ExecutionDomainTy &PredED,
                          bool InitialEdgeOnly = false);

  void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED);

    return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;

    assert(I.getFunction() == getAnchorScope() &&
           "Instruction is out of scope!");

      auto *CB = dyn_cast<CallBase>(CurI);
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {

      const auto &It = CEDMap.find(CB);
      if (It == CEDMap.end())
      if (!It->getSecond().IsReachingAlignedBarrierOnly)

    if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)

      auto *CB = dyn_cast<CallBase>(CurI);
      if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {

      const auto &It = CEDMap.find(CB);
      if (It == CEDMap.end())
      if (It->getSecond().IsReachedFromAlignedBarrierOnly) {

      if (!EDAA.getState().isValidState())
      if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly)

      return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;

           "No request should be made against an invalid state!");
    return BEDMap.lookup(&BB);

           "No request should be made against an invalid state!");
    return CEDMap.lookup(&CB);

           "No request should be made against an invalid state!");
    return BEDMap.lookup(nullptr);
  if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())

    if (C->isAllOnesValue()) {
      auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
      CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;

      const int InitModeArgNo = 1;
      auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));

    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
      if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)

    if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
      if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)

void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
  for (auto *EA : PredED.EncounteredAssumes)
    ED.addAssumeInst(A, *EA);

  for (auto *AB : PredED.AlignedBarriers)
    ED.addAlignedBarrier(A, *AB);

void AAExecutionDomainFunction::mergeInPredecessor(
    Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
    bool InitialEdgeOnly) {

  ED.IsExecutedByInitialThreadOnly =
      InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
                          ED.IsExecutedByInitialThreadOnly);

  ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly &&
                                       PredED.IsReachedFromAlignedBarrierOnly;
  ED.EncounteredNonLocalSideEffect =
      ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect;
  if (ED.IsReachedFromAlignedBarrierOnly)
    mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
    ED.clearAssumeInstAndAlignedBarriers();

void AAExecutionDomainFunction::handleEntryBB(Attributor &A,
                                              ExecutionDomainTy &EntryBBED) {
        DepClassTy::OPTIONAL);
    if (!EDAA.getState().isValidState())
        EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));

  bool AllCallSitesKnown;
  if (A.checkForAllCallSites(PredForCallSite, *this,
                             AllCallSitesKnown)) {
    for (const auto &PredED : PredExecDomains)
      mergeInPredecessor(A, EntryBBED, PredED);

  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  if (OMPInfoCache.Kernels.count(getAnchorScope())) {
    EntryBBED.IsExecutedByInitialThreadOnly = false;
    EntryBBED.IsReachedFromAlignedBarrierOnly = true;
    EntryBBED.EncounteredNonLocalSideEffect = false;
    EntryBBED.IsExecutedByInitialThreadOnly = false;
    EntryBBED.IsReachedFromAlignedBarrierOnly = false;
    EntryBBED.EncounteredNonLocalSideEffect = true;

  auto &FnED = BEDMap[nullptr];
  FnED.IsReachingAlignedBarrierOnly &=
      EntryBBED.IsReachedFromAlignedBarrierOnly;

  bool Changed = false;
  auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) {
    Changed |= AlignedBarriers.insert(CB);
    auto &CallED = CEDMap[CB];
    mergeInPredecessor(A, CallED, ED);

    ED.EncounteredNonLocalSideEffect = false;
    ED.IsReachedFromAlignedBarrierOnly = true;

    ED.clearAssumeInstAndAlignedBarriers();
      ED.addAlignedBarrier(A, *CB);

      A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);

  auto SetAndRecord = [&](bool &R, bool V) {

  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

  bool IsKernel = OMPInfoCache.Kernels.count(F);

  for (auto &RIt : *RPOT) {
    bool IsEntryBB = &BB == &EntryBB;
    bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
    ExecutionDomainTy ED;
      handleEntryBB(A, ED);
      if (LivenessAA.isAssumedDead(&BB))
        if (LivenessAA.isEdgeDead(PredBB, &BB))
        bool InitialEdgeOnly = isInitialThreadOnlyEdge(
            A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
        mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);

      bool UsedAssumedInformation;
      if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation,
                          false, DepClassTy::OPTIONAL,

      if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
        if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
          ED.addAssumeInst(A, *AI);
        if (II->isAssumeLikeIntrinsic())

      auto *CB = dyn_cast<CallBase>(&I);
      bool IsAlignedBarrier =

      AlignedBarrierLastInBlock &= IsNoSync;

      if (IsAlignedBarrier) {
        HandleAlignedBarrier(CB, ED);
        AlignedBarrierLastInBlock = true;

      if (isa<MemIntrinsic>(&I)) {
        if (!ED.EncounteredNonLocalSideEffect &&
          ED.EncounteredNonLocalSideEffect = true;
          ED.IsReachedFromAlignedBarrierOnly = false;

        auto &CallED = CEDMap[CB];
        mergeInPredecessor(A, CallED, ED);

          if (EDAA.getState().isValidState()) {
            CallED.IsReachedFromAlignedBarrierOnly =
                CalleeED.IsReachedFromAlignedBarrierOnly;
            AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
            if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
              ED.EncounteredNonLocalSideEffect |=
                  CalleeED.EncounteredNonLocalSideEffect;
              ED.EncounteredNonLocalSideEffect =
                  CalleeED.EncounteredNonLocalSideEffect;
            if (!CalleeED.IsReachingAlignedBarrierOnly)
            if (CalleeED.IsReachedFromAlignedBarrierOnly)
              mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);

          ED.IsReachedFromAlignedBarrierOnly =
              CallED.IsReachedFromAlignedBarrierOnly = false;
          AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;

      if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())

        if (MemAA.getState().isValidState() &&
            MemAA.checkForAllAccessesToMemoryKind(

      if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I))

      if (auto *LI = dyn_cast<LoadInst>(&I))
        if (LI->hasMetadata(LLVMContext::MD_invariant_load))

      if (!ED.EncounteredNonLocalSideEffect &&
        ED.EncounteredNonLocalSideEffect = true;

    if (!isa<UnreachableInst>(BB.getTerminator()) &&
        !BB.getTerminator()->getNumSuccessors()) {

      auto &FnED = BEDMap[nullptr];
      mergeInPredecessor(A, FnED, ED);

        HandleAlignedBarrier(nullptr, ED);

    ExecutionDomainTy &StoredED = BEDMap[&BB];
    ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly;

    if (ED.IsExecutedByInitialThreadOnly !=
            StoredED.IsExecutedByInitialThreadOnly ||
        ED.IsReachedFromAlignedBarrierOnly !=
            StoredED.IsReachedFromAlignedBarrierOnly ||
        ED.EncounteredNonLocalSideEffect !=
            StoredED.EncounteredNonLocalSideEffect)

    StoredED = std::move(ED);

  while (!SyncInstWorklist.empty()) {
    bool HitAlignedBarrier = false;
      auto *CB = dyn_cast<CallBase>(CurInst);
      auto &CallED = CEDMap[CB];
      if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false))
      HitAlignedBarrier = AlignedBarriers.count(CB);
      if (HitAlignedBarrier)
    if (HitAlignedBarrier)
      if (LivenessAA.isEdgeDead(PredBB, SyncBB))
      if (!Visited.insert(PredBB))
      SyncInstWorklist.push_back(PredBB->getTerminator());
      auto &PredED = BEDMap[PredBB];
      if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false))
    if (SyncBB != &EntryBB)
    auto &FnED = BEDMap[nullptr];
    if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false))

  return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
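/// Abstract Attribute that turns __kmpc_alloc_shared "globalized" allocations
/// into static shared-memory buffers. A call stays eligible only while its
/// size is a compile-time constant, it is executed by the initial thread
/// only, and exactly one matching __kmpc_free_shared is found; the manifest
/// step then replaces the pair with a shared-memory global.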
struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {

  static AAHeapToShared &createForPosition(const IRPosition &IRP,

  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;

  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;

  const std::string getName() const override { return "AAHeapToShared"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;

struct AAHeapToSharedFunction : public AAHeapToShared {
      : AAHeapToShared(IRP, A) {}

  const std::string getAsStr() const override {
    return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
           " malloc calls eligible.";

  void trackStatistics() const override {}

  void findPotentialRemovedFreeCalls(Attributor &A) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

    PotentialRemovedFreeCalls.clear();

      for (auto *U : CB->users()) {
        if (C && C->getCalledFunction() == FreeRFI.Declaration)

      if (FreeCalls.size() != 1)

      PotentialRemovedFreeCalls.insert(FreeCalls.front());

      indicatePessimisticFixpoint();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)

                       bool &) -> std::optional<Value *> { return nullptr; };

    for (User *U : RFI.Declaration->users())
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        MallocCalls.insert(CB);

    findPotentialRemovedFreeCalls(A);

  bool isAssumedHeapToShared(CallBase &CB) const override {
    return isValidState() && MallocCalls.count(&CB);

  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
    return isValidState() && PotentialRemovedFreeCalls.count(&CB);

    if (MallocCalls.empty())
      return ChangeStatus::UNCHANGED;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];

        DepClassTy::OPTIONAL);

      if (HS && HS->isAssumedHeapToStack(*CB))

      for (auto *U : CB->users()) {
        if (C && C->getCalledFunction() == FreeCall.Declaration)

      if (FreeCalls.size() != 1)

                  << " with shared memory."
                  << " Shared memory usage is limited to "

                        << " with " << AllocSize->getZExtValue()
                        << " bytes of shared memory\n");

      Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
          static_cast<unsigned>(AddressSpace::Shared));

        return OR << "Replaced globalized variable with "
                  << ore::NV("SharedMemory", AllocSize->getZExtValue())
                  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
                  << "of shared memory.";

             "HeapToShared on allocation without alignment attribute");
      SharedMem->setAlignment(*Alignment);

      A.deleteAfterManifest(*CB);
      A.deleteAfterManifest(*FreeCalls.front());

      SharedMemoryUsed += AllocSize->getZExtValue();
      NumBytesMovedToSharedMemory = SharedMemoryUsed;
      Changed = ChangeStatus::CHANGED;

    if (MallocCalls.empty())
      return indicatePessimisticFixpoint();
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
    if (!RFI.Declaration)
      return ChangeStatus::UNCHANGED;

    auto NumMallocCalls = MallocCalls.size();

    for (User *U : RFI.Declaration->users()) {
      if (CallBase *CB = dyn_cast<CallBase>(U)) {
        if (CB->getCaller() != F)
        if (!MallocCalls.count(CB))
        if (!isa<ConstantInt>(CB->getArgOperand(0))) {
          MallocCalls.remove(CB);
        if (!ED.isExecutedByInitialThreadOnly(*CB))
          MallocCalls.remove(CB);

    findPotentialRemovedFreeCalls(A);

    if (NumMallocCalls != MallocCalls.size())
      return ChangeStatus::CHANGED;

    return ChangeStatus::UNCHANGED;

  unsigned SharedMemoryUsed = 0;
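/// Abstract Attribute that carries the KernelInfoState for a kernel and the
/// functions reachable from it: which parallel regions are reached, whether
/// all reached code is SPMD-compatible, and which __kmpc_target_init /
/// __kmpc_target_deinit calls frame the kernel. SPMD-ization and the custom
/// state machine rewrite are both driven from this state.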
struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {

  void trackStatistics() const override {}

  const std::string getAsStr() const override {
    if (!isValidState())
    return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
           std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
           std::string(" #PRs: ") +
           (ReachedKnownParallelRegions.isValidState()
                ? std::to_string(ReachedKnownParallelRegions.size())
           ", #Unknown PRs: " +
           (ReachedUnknownParallelRegions.isValidState()
           ", #Reaching Kernels: " +
           (ReachingKernelEntries.isValidState()
           (ParallelLevels.isValidState()

  const std::string getName() const override { return "AAKernelInfo"; }

  const char *getIdAddr() const override { return &ID; }

  static const char ID;

struct AAKernelInfoFunction : AAKernelInfo {
      : AAKernelInfo(IRP, A) {}

    return GuardedInstructions;

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
    OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];

    auto StoreCallBase = [](Use &U,
                            OMPInformationCache::RuntimeFunctionInfo &RFI,
      CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
             "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
             "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");

          StoreCallBase(U, InitRFI, KernelInitCB);
    DeinitRFI.foreachUse(
          StoreCallBase(U, DeinitRFI, KernelDeinitCB);

    if (!KernelInitCB || !KernelDeinitCB)

    ReachingKernelEntries.insert(Fn);
    IsKernelEntry = true;

        bool &UsedAssumedInformation) -> std::optional<Value *> {
      if (!ReachedKnownParallelRegions.isValidState())
        A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
      UsedAssumedInformation = !isAtFixpoint();

        bool &UsedAssumedInformation) -> std::optional<Value *> {
      if (!SPMDCompatibilityTracker.isValidState())
      if (!SPMDCompatibilityTracker.isAtFixpoint()) {
        A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
        UsedAssumedInformation = true;
        UsedAssumedInformation = false;

    constexpr const int InitModeArgNo = 1;
    constexpr const int DeinitModeArgNo = 1;
    constexpr const int InitUseStateMachineArgNo = 2;
    A.registerSimplificationCallback(
        StateMachineSimplifyCB);
    A.registerSimplificationCallback(
    A.registerSimplificationCallback(

        dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();
      SPMDCompatibilityTracker.indicatePessimisticFixpoint();

      if (!OMPInfoCache.RFIs[RFKind].Declaration)
      A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);

    auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
        A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);

      if (SPMDCompatibilityTracker.isValidState())
        return AddDependence(A, this, QueryingAA);
      if (!ReachedKnownParallelRegions.isValidState())
        return AddDependence(A, this, QueryingAA);

    if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
      RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
                         CustomStateMachineUseCB);
      RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
                         CustomStateMachineUseCB);

    if (SPMDCompatibilityTracker.isAtFixpoint())

      if (!SPMDCompatibilityTracker.isValidState())
        return AddDependence(A, this, QueryingAA);
    RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,

      if (!SPMDCompatibilityTracker.isValidState())
        return AddDependence(A, this, QueryingAA);
      if (SPMDCompatibilityTracker.empty())
        return AddDependence(A, this, QueryingAA);
      if (!mayContainParallelRegion())
        return AddDependence(A, this, QueryingAA);
    RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);

  static std::string sanitizeForGlobalName(std::string S) {
          return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
                   (C >= '0' && C <= '9') || C == '_');
3652 if (!KernelInitCB || !KernelDeinitCB)
3653 return ChangeStatus::UNCHANGED;
3667 if (!changeToSPMDMode(
A, Changed)) {
3668 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3669 return buildCustomStateMachine(
A);
3675 void insertInstructionGuardsHelper(Attributor &A) {
3676 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3678 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
                                    Instruction *RegionEndI) {
3712 DT, LI, MSU, "region.guarded.end");
3715 MSU, "region.barrier");
3718 DT, LI, MSU, "region.exit");
3720 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3723 "Expected a different CFG");
3726 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3729 A.registerManifestAddedBasicBlock(*RegionEndBB);
3730 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3731 A.registerManifestAddedBasicBlock(*RegionExitBB);
3732 A.registerManifestAddedBasicBlock(*RegionStartBB);
3733 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3735 bool HasBroadcastValues = false;
3740 for (User *Usr : I.users()) {
3743 OutsideUsers.insert(&UsrI);
3746 if (OutsideUsers.empty())
3749 HasBroadcastValues = true;
3754 M, I.getType(), false,
3756 sanitizeForGlobalName(
3757     (I.getName() + ".guarded.output.alloc").str()),
3759 static_cast<unsigned>(AddressSpace::Shared));
3765 I.getName() + ".guarded.output.load",
3773 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3779 InsertPointTy(ParentBB, ParentBB->end()), DL);
3780 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3783 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3785 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3791 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3792 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3794 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3795 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3797 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3799 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3800 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3801 OMPInfoCache.OMPBuilder.Builder
3802 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3808 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3809 M, OMPRTL___kmpc_barrier_simple_spmd);
3810 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3813 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3815 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3818 if (HasBroadcastValues) {
3822 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3826 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3828 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3830 if (!Visited.insert(BB).second)
3836 while (++IP != IPEnd) {
3837 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3840 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3842 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3843 LastEffect = nullptr;
3850 for (auto &Reorder : Reorders)
3856 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3858 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3861 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3862 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3864 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3867 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3871 if (SPMDCompatibilityTracker.contains(&I)) {
3872 CalleeAAFunction.getGuardedInstructions().insert(&I);
3873 if (GuardedRegionStart)
3874 GuardedRegionEnd = &I;
3876 GuardedRegionStart = GuardedRegionEnd = &I;
3883 if (GuardedRegionStart) {
3885 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3886 GuardedRegionStart = nullptr;
3887 GuardedRegionEnd = nullptr;
3892 for (auto &GR : GuardedRegions)
3893 CreateGuardedRegion(GR.first, GR.second);
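// If the kernel reaches no parallel region, SPMD-ization does not need
// guarding at all; instead the whole user code is executed by a single
// "main" thread per workgroup and all other threads exit early.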
3896 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
3905 auto &Ctx = getAnchorValue().getContext();
3912 KernelInitCB->getNextNode(), "main.thread.user_code");
3917 A.registerManifestAddedBasicBlock(*InitBB);
3918 A.registerManifestAddedBasicBlock(*UserCodeBB);
3919 A.registerManifestAddedBasicBlock(*ReturnBB);
3922 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3928 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3930 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3931 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3936 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
3943 "thread.is_main", InitBB);
// changeToSPMDMode: verify the required runtime functions are available and
// that no SPMD-incompatible instruction remains (otherwise emit analysis
// remarks and bail out) before the exec-mode is flipped and side effects are
// guarded.
3949 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3952 if (!OMPInfoCache.runtimeFnsAvailable(
3953         {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3954          OMPRTL___kmpc_barrier_simple_spmd}))
3957 if (!SPMDCompatibilityTracker.isAssumed()) {
3958 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3959 if (!NonCompatibleI)
3963 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3964 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3968 ORA << "Value has potential side effects preventing SPMD-mode "
3970 if (isa<CallBase>(NonCompatibleI)) {
3971 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3972        "the called function to override";
3980 << *NonCompatibleI << "\n");
3992 Kernel = CB->getCaller();
3994 assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3999 assert(ExecMode && "Kernel without exec mode?");
4004 "ExecMode is not an integer!");
4005 const int8_t ExecModeVal =
4011 Changed = ChangeStatus::CHANGED;
4015 if (mayContainParallelRegion())
4016 insertInstructionGuardsHelper(A);
4018 forceSingleThreadPerWorkgroupHelper(A);
4023 "Initially non-SPMD kernel has SPMD exec mode!");
4029 const int InitModeArgNo = 1;
4030 const int DeinitModeArgNo = 1;
4031 const int InitUseStateMachineArgNo = 2;
4033 auto &Ctx = getAnchorValue().getContext();
4034 A.changeUseAfterManifest(
4035 KernelInitCB->getArgOperandUse(InitModeArgNo),
4038 A.changeUseAfterManifest(
4039 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
4041 A.changeUseAfterManifest(
4042 KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
4046 ++NumOpenMPTargetRegionKernelsSPMD;
4049 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4058 return ChangeStatus::UNCHANGED;
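// buildCustomStateMachine (continued below): the rewrite is skipped unless
// the set of reached parallel regions is known and the required worker
// runtime functions are available; otherwise a specialized dispatch loop
// replaces the generic runtime state machine.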
4061 if (!ReachedKnownParallelRegions.isValidState())
4062 return ChangeStatus::UNCHANGED;
4064 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4065 if (!OMPInfoCache.runtimeFnsAvailable(
4066 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4067 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4068 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4069 return ChangeStatus::UNCHANGED;
4071 const int InitModeArgNo = 1;
4072 const int InitUseStateMachineArgNo = 2;
4078 ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
4079 KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
4081 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
4086 if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
4088 return ChangeStatus::UNCHANGED;
4091 auto &Ctx = getAnchorValue().getContext();
4093 A.changeUseAfterManifest(
4094 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
4100 if (!mayContainParallelRegion()) {
4101 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4104 return OR << "Removing unused state machine from generic-mode kernel.";
4108 return ChangeStatus::CHANGED;
4112 if (ReachedUnknownParallelRegions.empty()) {
4113 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4116 return OR << "Rewriting generic-mode kernel with a customized state "
4121 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4124 return OR << "Generic-mode kernel is executed with a customized state "
4125            "machine that requires a fallback.";
4130 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4131 if (!UnknownParallelRegionCB)
4134 return ORA << "Call may contain unknown parallel regions. Use "
4135            << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4176 BasicBlock *InitBB = KernelInitCB->getParent();
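// Rough sketch (assumed from the blocks and runtime calls created below) of
// the worker part of the generated CFG:
//   worker_state_machine.begin:
//     __kmpc_barrier_simple_generic(...)
//     IsActive = __kmpc_kernel_parallel(&WorkFn)
//     if (!WorkFn)   -> worker_state_machine.finished
//     if (!IsActive) -> worker_state_machine.done.barrier
//     if-cascade over known parallel regions, indirect-call fallback last
//     __kmpc_kernel_end_parallel(); barrier; back to begin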
4178 KernelInitCB->getNextNode(), "thread.user_code.check");
4182 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4184 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4186 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4189 Kernel, UserCodeEntryBB);
4192 Kernel, UserCodeEntryBB);
4194 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4195 A.registerManifestAddedBasicBlock(*InitBB);
4196 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4197 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4198 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4199 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4200 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4201 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4202 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4203 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4205 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4212 "thread.is_worker", InitBB);
4218 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4219 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4221 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4222 M, OMPRTL___kmpc_get_warp_size);
4225 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4229 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4232 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4236 "thread.is_main_or_worker", IsWorkerCheckBB);
4239 IsMainOrWorker, IsWorkerCheckBB);
4245 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4249 OMPInfoCache.OMPBuilder.updateToLocation(
4252 StateMachineBeginBB->end()),
4255 Value *Ident = KernelInitCB->getArgOperand(0);
4256 Value *GTid = KernelInitCB;
4259 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4260 M, OMPRTL___kmpc_barrier_simple_generic);
4263 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4267 (unsigned int)AddressSpace::Generic) {
4270 PointerType::getWithSamePointeeType(
4271     cast<PointerType>(WorkFnAI->getType()),
4272     (unsigned int)AddressSpace::Generic),
4273 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4278 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4279 M, OMPRTL___kmpc_kernel_parallel);
4281 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4282 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4285 StateMachineBeginBB);
4291 Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
4292     WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
4293     StateMachineBeginBB);
4298 StateMachineBeginBB);
4301 IsDone, StateMachineBeginBB)
4305 StateMachineDoneBarrierBB, IsActiveWorker,
4306 StateMachineIsActiveCheckBB)
4315 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4316 auto *ParallelRegion = ReachedKnownParallelRegions[I];
4318 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4319 StateMachineEndParallelBB);
4321 ->setDebugLoc(DLoc);
4327 Kernel, StateMachineEndParallelBB);
4332 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4335 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4343 StateMachineIfCascadeCurrentBB)
4345 StateMachineIfCascadeCurrentBB = PRNextBB;
4351 if (!ReachedUnknownParallelRegions.empty()) {
4352 StateMachineIfCascadeCurrentBB->setName(
4353     "worker_state_machine.parallel_region.fallback.execute");
4355 StateMachineIfCascadeCurrentBB)
4356 ->setDebugLoc(DLoc);
4359 StateMachineIfCascadeCurrentBB)
4363 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4364 M, OMPRTL___kmpc_kernel_end_parallel);
4367 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4373 ->setDebugLoc(DLoc);
4377 return ChangeStatus::CHANGED;
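// updateImpl: re-check all read/write instructions for SPMD compatibility,
// merge in the state of every reached call site, and propagate
// reaching-kernel and parallel-level information. Fixpoints are only claimed
// once no assumed information was involved.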
4383 KernelInfoState StateBefore = getState();
4388 if (isa<CallBase>(I))
4391 if (!I.mayWriteToMemory())
4393 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4396 DepClassTy::OPTIONAL);
4399 DepClassTy::OPTIONAL);
4400 if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) {
4401 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4405 auto *CB = dyn_cast<CallBase>(&Obj);
4406 return CB && HS.isAssumedHeapToStack(*CB);
4412 SPMDCompatibilityTracker.insert(&I);
4416 bool UsedAssumedInformationInCheckRWInst = false;
4417 if (!SPMDCompatibilityTracker.isAtFixpoint())
4418 if (!A.checkForAllReadWriteInstructions(
4419         CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4422 bool UsedAssumedInformationFromReachingKernels = false;
4423 if (!IsKernelEntry) {
4424 updateParallelLevels(A);
4426 bool AllReachingKernelsKnown = true;
4427 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4428 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4430 if (!SPMDCompatibilityTracker.empty()) {
4431 if (!ParallelLevels.isValidState())
4432 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4433 else if (!ReachingKernelEntries.isValidState())
4434 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4440 for (auto *Kernel : ReachingKernelEntries) {
4441 auto &CBAA = A.getAAFor<AAKernelInfo>(
4443 if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4444     CBAA.SPMDCompatibilityTracker.isAssumed())
4448 if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4449 UsedAssumedInformationFromReachingKernels = true;
4451 if (SPMD != 0 && Generic != 0)
4452 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4458 bool AllParallelRegionStatesWereFixed = true;
4459 bool AllSPMDStatesWereFixed = true;
4461 auto &CB = cast<CallBase>(I);
4462 auto &CBAA = A.getAAFor<AAKernelInfo>(
4464 getState() ^= CBAA.getState();
4465 AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4466 AllParallelRegionStatesWereFixed &=
4467     CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4468 AllParallelRegionStatesWereFixed &=
4469     CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4473 bool UsedAssumedInformationInCheckCallInst = false;
4474 if (!A.checkForAllCallLikeInstructions(
4475         CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4477 << "Failed to visit all call-like instructions!\n";);
4478 return indicatePessimisticFixpoint();
4483 if (!UsedAssumedInformationInCheckCallInst &&
4484 AllParallelRegionStatesWereFixed) {
4485 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4486 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4491 if (!UsedAssumedInformationInCheckRWInst &&
4492 !UsedAssumedInformationInCheckCallInst &&
4493 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4494 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4496 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4497 : ChangeStatus::CHANGED;
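// updateReachingKernelEntries: union the ReachingKernelEntries of all callers
// of the associated function; any call site we cannot analyze forces a
// pessimistic fixpoint.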
4503 bool &AllReachingKernelsKnown) {
4507 assert(Caller && "Caller is nullptr");
4509 auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4511 if (CAA.ReachingKernelEntries.isValidState()) {
4512 ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4518 ReachingKernelEntries.indicatePessimisticFixpoint();
4523 if (!A.checkForAllCallSites(PredCallSite, *this,
4525                             AllReachingKernelsKnown))
4526 ReachingKernelEntries.indicatePessimisticFixpoint();
4531 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4532 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4533     OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4538 assert(Caller && "Caller is nullptr");
4542 if (CAA.ParallelLevels.isValidState()) {
4548 if (Caller == Parallel51RFI.Declaration) {
4549 ParallelLevels.indicatePessimisticFixpoint();
4553 ParallelLevels ^= CAA.ParallelLevels;
4560 ParallelLevels.indicatePessimisticFixpoint();
4565 bool AllCallSitesKnown = true;
4566 if (!A.checkForAllCallSites(PredCallSite, *this,
4569 ParallelLevels.indicatePessimisticFixpoint();
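// Call-site variant: classify the callee. Known OpenMP runtime calls are
// handled per function kind below (e.g. __kmpc_parallel_51 records the
// outlined parallel region, tasking and unsupported schedules block
// SPMD-ization), while calls to unknown functions are treated as potential
// parallel regions unless annotated "omp_no_openmp"/"omp_no_parallelism".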
4576struct AAKernelInfoCallSite : AAKernelInfo {
4578 : AAKernelInfo(IRP, A) {}
4582 AAKernelInfo::initialize(A);
4584 CallBase &CB = cast<CallBase>(getAssociatedValue());
4591 if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4593 indicateOptimisticFixpoint();
4600 indicateOptimisticFixpoint();
4608 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4609 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4610 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4616 if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4617       AssumptionAA.hasAssumption("omp_no_parallelism")))
4618 ReachedUnknownParallelRegions.insert(&CB);
4622 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4623 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4624 SPMDCompatibilityTracker.insert(&CB);
4629 indicateOptimisticFixpoint();
4636 const unsigned int WrapperFunctionArgNo = 6;
4640 case OMPRTL___kmpc_is_spmd_exec_mode:
4641 case OMPRTL___kmpc_distribute_static_fini:
4642 case OMPRTL___kmpc_for_static_fini:
4643 case OMPRTL___kmpc_global_thread_num:
4644 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4645 case OMPRTL___kmpc_get_hardware_num_blocks:
4646 case OMPRTL___kmpc_single:
4647 case OMPRTL___kmpc_end_single:
4648 case OMPRTL___kmpc_master:
4649 case OMPRTL___kmpc_end_master:
4650 case OMPRTL___kmpc_barrier:
4651 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4652 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4653 case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4655 case OMPRTL___kmpc_distribute_static_init_4:
4656 case OMPRTL___kmpc_distribute_static_init_4u:
4657 case OMPRTL___kmpc_distribute_static_init_8:
4658 case OMPRTL___kmpc_distribute_static_init_8u:
4659 case OMPRTL___kmpc_for_static_init_4:
4660 case OMPRTL___kmpc_for_static_init_4u:
4661 case OMPRTL___kmpc_for_static_init_8:
4662 case OMPRTL___kmpc_for_static_init_8u: {
4664 unsigned ScheduleArgOpNo = 2;
4665 auto *ScheduleTypeCI =
4667 unsigned ScheduleTypeVal =
4668 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4670 case OMPScheduleType::UnorderedStatic:
4671 case OMPScheduleType::UnorderedStaticChunked:
4672 case OMPScheduleType::OrderedDistribute:
4673 case OMPScheduleType::OrderedDistributeChunked:
4676 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4677 SPMDCompatibilityTracker.insert(&CB);
4681 case OMPRTL___kmpc_target_init:
4684 case OMPRTL___kmpc_target_deinit:
4685 KernelDeinitCB = &CB;
4687 case OMPRTL___kmpc_parallel_51:
4688 if (auto *ParallelRegion = dyn_cast<Function>(
4690 ReachedKnownParallelRegions.insert(ParallelRegion);
4692 auto &FnAA = A.getAAFor<AAKernelInfo>(
4694 NestedParallelism |= !FnAA.getState().isValidState() ||
4695 !FnAA.ReachedKnownParallelRegions.empty() ||
4696 !FnAA.ReachedUnknownParallelRegions.empty();
4702 ReachedUnknownParallelRegions.insert(&CB);
4704 case OMPRTL___kmpc_omp_task:
4706 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4707 SPMDCompatibilityTracker.insert(&CB);
4708 ReachedUnknownParallelRegions.insert(&CB);
4710 case OMPRTL___kmpc_alloc_shared:
4711 case OMPRTL___kmpc_free_shared:
4717 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4718 SPMDCompatibilityTracker.insert(&CB);
4724 indicateOptimisticFixpoint();
4734 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4735 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4738 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4740 auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4741 if (getState() == FnAA.getState())
4742 return ChangeStatus::UNCHANGED;
4743 getState() = FnAA.getState();
4744 return ChangeStatus::CHANGED;
4749 KernelInfoState StateBefore = getState();
4750 assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4751 It->getSecond() == OMPRTL___kmpc_free_shared) &&
4752 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4754 CallBase &CB = cast<CallBase>(getAssociatedValue());
4758 auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4766 case OMPRTL___kmpc_alloc_shared:
4767 if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4768 !HeapToSharedAA.isAssumedHeapToShared(CB))
4769 SPMDCompatibilityTracker.insert(&CB);
4771 case OMPRTL___kmpc_free_shared:
4772 if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4773 !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4774 SPMDCompatibilityTracker.insert(&CB);
4777 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4778 SPMDCompatibilityTracker.insert(&CB);
4781 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4782 : ChangeStatus::CHANGED;
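// AAFoldRuntimeCall: try to fold OpenMP runtime query calls (exec mode,
// parallel level, hardware thread/team counts) to constants based on the
// kernels that can reach the call site.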
4786struct AAFoldRuntimeCall
4787     : public StateWrapper<BooleanState, AbstractAttribute> {
4793 void trackStatistics() const override {}
4796 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4800 const std::string getName() const override { return "AAFoldRuntimeCall"; }
4803 const char *getIdAddr() const override { return &ID; }
4811 static const char ID;
4814struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4816 : AAFoldRuntimeCall(IRP, A) {}
4819 const std::string getAsStr() const override {
4820 if (!isValidState())
4823 std::string Str("simplified value: ");
4825 if (!SimplifiedValue)
4826 return Str + std::string("none");
4828 if (!*SimplifiedValue)
4829 return Str + std::string("nullptr");
4831 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
4832 return Str + std::to_string(CI->getSExtValue());
4834 return Str + std::string("unknown");
4839 indicatePessimisticFixpoint();
4843 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4844 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4845 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4846        "Expected a known OpenMP runtime function");
4848 RFKind = It->getSecond();
4850 CallBase &CB = cast<CallBase>(getAssociatedValue());
4851 A.registerSimplificationCallback(
4854 bool &UsedAssumedInformation) -> std::optional<Value *> {
4855 assert((isValidState() ||
4856         (SimplifiedValue && *SimplifiedValue == nullptr)) &&
4857        "Unexpected invalid state!");
4859 if (!isAtFixpoint()) {
4860 UsedAssumedInformation = true;
4862 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4864 return SimplifiedValue;
4871 case OMPRTL___kmpc_is_spmd_exec_mode:
4872 Changed |= foldIsSPMDExecMode(A);
4874 case OMPRTL___kmpc_parallel_level:
4875 Changed |= foldParallelLevel(A);
4877 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4878 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4880 case OMPRTL___kmpc_get_hardware_num_blocks:
4881 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4893 if (SimplifiedValue && *SimplifiedValue) {
4896 A.deleteAfterManifest(I);
4900 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4901 return OR << "Replacing OpenMP runtime call "
4903           << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4904 return OR << "Replacing OpenMP runtime call "
4912 << **SimplifiedValue << "\n");
4914 Changed = ChangeStatus::CHANGED;
4921 SimplifiedValue = nullptr;
4922 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
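// foldIsSPMDExecMode: the call can only be folded if every kernel reaching it
// agrees on the execution mode; a mix of SPMD and generic reaching kernels
// forces a pessimistic fixpoint.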
4928 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4930 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4931 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4932 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4935 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4936 return indicatePessimisticFixpoint();
4938 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4940 DepClassTy::REQUIRED);
4942 if (!AA.isValidState()) {
4943 SimplifiedValue = nullptr;
4944 return indicatePessimisticFixpoint();
4947 if (AA.SPMDCompatibilityTracker.isAssumed()) {
4948 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4953 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4954 ++KnownNonSPMDCount;
4956 ++AssumedNonSPMDCount;
4960 if ((AssumedSPMDCount + KnownSPMDCount) &&
4961 (AssumedNonSPMDCount + KnownNonSPMDCount))
4962 return indicatePessimisticFixpoint();
4964 auto &Ctx = getAnchorValue().getContext();
4965 if (KnownSPMDCount || AssumedSPMDCount) {
4966 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4967 "Expected only SPMD kernels!");
4971 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4972 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4973        "Expected only non-SPMD kernels!");
4981 assert(!SimplifiedValue && "SimplifiedValue should be none");
4984 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4985 : ChangeStatus::CHANGED;
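// foldParallelLevel: same reasoning as above, but for __kmpc_parallel_level;
// folding requires agreement across all reaching kernels.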
4990 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4992 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4995 if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4996 return indicatePessimisticFixpoint();
4998 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4999 return indicatePessimisticFixpoint();
5001 if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
5002 assert(!SimplifiedValue &&
5003 "SimplifiedValue should keep none at this point");
5004 return ChangeStatus::UNCHANGED;
5007 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5008 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5009 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
5011 DepClassTy::REQUIRED);
5012 if (!AA.SPMDCompatibilityTracker.isValidState())
5013 return indicatePessimisticFixpoint();
5015 if (AA.SPMDCompatibilityTracker.isAssumed()) {
5016 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
5021 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
5022 ++KnownNonSPMDCount;
5024 ++AssumedNonSPMDCount;
5028 if ((AssumedSPMDCount + KnownSPMDCount) &&
5029 (AssumedNonSPMDCount + KnownNonSPMDCount))
5030 return indicatePessimisticFixpoint();
5032 auto &Ctx = getAnchorValue().getContext();
5036 if (AssumedSPMDCount || KnownSPMDCount) {
5037 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5038 "Expected only SPMD kernels!");
5041 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5042 "Expected only non-SPMD kernels!");
5045 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5046 : ChangeStatus::CHANGED;
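// foldKernelFnAttribute: read the named kernel attribute (e.g.
// "omp_target_thread_limit") from every reaching kernel and fold the call to
// that constant only if all kernels agree on the value.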
5051 int32_t CurrentAttrValue = -1;
5052 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5054 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5057 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
5058 return indicatePessimisticFixpoint();
5061 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
5064 if (NextAttrVal == -1 ||
5065 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5066 return indicatePessimisticFixpoint();
5067 CurrentAttrValue = NextAttrVal;
5070 if (CurrentAttrValue != -1) {
5071 auto &Ctx = getAnchorValue().getContext();
5075 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5076 : ChangeStatus::CHANGED;
5082 std::optional<Value *> SimplifiedValue;
5092 auto &RFI = OMPInfoCache.RFIs[RF];
5094 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5097 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5099 DepClassTy::NONE, false,
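// registerAAs: seed the Attributor with AAKernelInfo for every kernel init
// call, register the foldable runtime calls, create ICV trackers for ICV
// getter uses, and add the generic per-function AAs for every function in
// the SCC.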
5105void OpenMPOpt::registerAAs(bool IsModulePass) {
5115 A.getOrCreateAAFor<AAKernelInfo>(
5117 DepClassTy::NONE, false,
5121 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5122 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5123 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5125 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5126 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5127 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5128 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5133 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5136 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5139 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5143 auto &CB = cast<CallBase>(*CI);
5146 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5150 GetterRFI.foreachUse(SCC, CreateAA);
5159 for (auto *F : SCC) {
5160 if (F->isDeclaration())
5166 if (F->hasLocalLinkage()) {
5168 const auto *CB = dyn_cast<CallBase>(U.getUser());
5169 return CB && CB->isCallee(&U) &&
5170        A.isRunOn(const_cast<Function *>(CB->getCaller()));
5174 registerAAsForFunction(A, *F);
5186 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5187 bool UsedAssumedInformation = false;
5192 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5196 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5197 if (II->getIntrinsicID() == Intrinsic::assume) {
5206const char AAICVTracker::ID = 0;
5207const char AAKernelInfo::ID = 0;
5209const char AAHeapToShared::ID = 0;
5210const char AAFoldRuntimeCall::ID = 0;
5212AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5214 AAICVTracker *AA = nullptr;
5222 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5225 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5228 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5231 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5240 AAExecutionDomainFunction *AA = nullptr;
5250 "AAExecutionDomain can only be created for function position!");
5252 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5259AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5261 AAHeapToSharedFunction *AA = nullptr;
5271 "AAHeapToShared can only be created for function position!");
5273 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5280AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5282 AAKernelInfo *AA = nullptr;
5292 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5295 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5302AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5304 AAFoldRuntimeCall *AA = nullptr;
5313 llvm_unreachable("KernelInfo can only be created for call site position!");
5315 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5336 if (Kernels.contains(&F))
5338 for (const User *U : F.users())
5339 if (!isa<BlockAddress>(U))
5348 return ORA << "Could not internalize function. "
5349            << "Some optimizations may not be possible. [OMP140]";
5359 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5363 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5375 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {