OpenMPOpt.cpp
1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
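// As a sketch of the first item above (hedged; the exact IR is produced by
// Clang and the function names are the standard OpenMP API):
//
//   int A = omp_get_thread_num();   // side-effect-free runtime query
//   ...
//   int B = omp_get_thread_num();   // identical query in the same region
//
// after OpenMPOpt the second call is folded into the first:
//
//   int A = omp_get_thread_num();
//   ...
//   int B = A;
//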
18//===----------------------------------------------------------------------===//
19
20#include "llvm/Transforms/IPO/OpenMPOpt.h"
21
24#include "llvm/ADT/SetVector.h"
27#include "llvm/ADT/Statistic.h"
29#include "llvm/ADT/StringRef.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Constants.h"
42#include "llvm/IR/Dominators.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
46#include "llvm/IR/InstrTypes.h"
47#include "llvm/IR/Instruction.h"
50#include "llvm/IR/IntrinsicsAMDGPU.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/LLVMContext.h"
55#include "llvm/Support/Debug.h"
59
60#include <algorithm>
61#include <optional>
62#include <string>
63
64using namespace llvm;
65using namespace omp;
66
67#define DEBUG_TYPE "openmp-opt"
68
69static cl::opt<bool> DisableOpenMPOptimizations(
70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
73static cl::opt<bool> EnableParallelRegionMerging(
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86 cl::Hidden);
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
90static cl::opt<bool> HideMemoryTransferLatency(
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
96static cl::opt<bool> DisableOpenMPOptDeglobalization(
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
101static cl::opt<bool> DisableOpenMPOptSPMDization(
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
106static cl::opt<bool> DisableOpenMPOptFolding(
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
111static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
116static cl::opt<bool> DisableOpenMPOptBarrierElimination(
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
121static cl::opt<bool> PrintModuleAfterOptimizations(
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
126static cl::opt<bool> PrintModuleBeforeOptimizations(
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
131static cl::opt<bool> AlwaysInlineDeviceFunctions(
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
141static cl::opt<unsigned>
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
146static cl::opt<unsigned>
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
151STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152 "Number of OpenMP runtime calls deduplicated");
153STATISTIC(NumOpenMPParallelRegionsDeleted,
154 "Number of OpenMP parallel regions deleted");
155STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156 "Number of OpenMP runtime functions identified");
157STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158 "Number of OpenMP runtime function uses identified");
159STATISTIC(NumOpenMPTargetRegionKernels,
160 "Number of OpenMP target region entry points (=kernels) identified");
161STATISTIC(NumNonOpenMPTargetRegionKernels,
162 "Number of non-OpenMP target region kernels identified");
163STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164 "Number of OpenMP target region entry points (=kernels) executed in "
165 "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167 "Number of OpenMP target region entry points (=kernels) executed in "
168 "generic-mode without a state machine");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170 "Number of OpenMP target region entry points (=kernels) executed in "
171 "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173 "Number of OpenMP target region entry points (=kernels) executed in "
174 "generic-mode with customized state machines without fallback");
175STATISTIC(
176 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178STATISTIC(NumOpenMPParallelRegionsMerged,
179 "Number of OpenMP parallel regions merged");
180STATISTIC(NumBytesMovedToSharedMemory,
181 "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191// uint8_t UseGenericStateMachine;
192// uint8_t MayUseNestedParallelism;
193// llvm::omp::OMPTgtExecModeFlags ExecMode;
194// int32_t MinThreads;
195// int32_t MaxThreads;
196// int32_t MinTeams;
197// int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201// uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205// ConfigurationEnvironmentTy Configuration;
206// IdentTy *Ident;
207// DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
211 constexpr const unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214KERNEL_ENVIRONMENT_IDX(Ident, 1)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
219 constexpr const unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
232 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
234 }
235
236KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
242 ConstantInt *get##MEMBER##FromKernelEnvironment( \
243 ConstantStruct *KernelEnvC) { \
244 ConstantStruct *ConfigC = \
245 getConfigurationFromKernelEnvironment(KernelEnvC); \
246 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
247 }
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
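// A usage sketch for the getters generated above (the surrounding variable
// names are hypothetical; only the getter names come from the macro):
//
//   ConstantStruct *KernelEnvC =
//       getKernelEnvironementFromKernelInitCB(KernelInitCB);
//   ConstantInt *UseGenericSM =
//       getUseGenericStateMachineFromKernelEnvironment(KernelEnvC);
//   ConstantInt *MayNest =
//       getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);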
258
259GlobalVariable *
260getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261 constexpr const int InitKernelEnvironmentArgNo = 0;
262 return cast<GlobalVariable>(
263 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264 ->stripPointerCasts());
265}
266
267ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268 GlobalVariable *KernelEnvGV =
269 getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270 return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271}
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285 bool OpenMPPostLink)
286 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287 OpenMPPostLink(OpenMPPostLink) {
288
289 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290 OMPBuilder.initialize();
291 initializeRuntimeFunctions(M);
292 initializeInternalControlVars();
293 }
294
295 /// Generic information that describes an internal control variable.
296 struct InternalControlVarInfo {
297 /// The kind, as described by InternalControlVar enum.
298 InternalControlVar Kind;
299
300 /// The name of the ICV.
301 StringRef Name;
302
303 /// Environment variable associated with this ICV.
304 StringRef EnvVarName;
305
306 /// Initial value kind.
307 ICVInitValue InitKind;
308
309 /// Initial value.
310 ConstantInt *InitValue;
311
312 /// Setter RTL function associated with this ICV.
313 RuntimeFunction Setter;
314
315 /// Getter RTL function associated with this ICV.
316 RuntimeFunction Getter;
317
318 /// RTL Function corresponding to the override clause of this ICV
319 RuntimeFunction Clause;
320 };
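  // As an example (a sketch; the authoritative values live in OMPKinds.def),
  // the nthreads ICV is described roughly as:
  //   Kind       = ICV_nthreads
  //   Name       = "nthreads"
  //   EnvVarName = "OMP_NUM_THREADS"
  //   InitKind   = ICV_IMPLEMENTATION_DEFINED (so InitValue == nullptr)
  //   Setter     = OMPRTL_omp_set_num_threads
  //   Getter     = OMPRTL_omp_get_max_threads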
321
322 /// Generic information that describes a runtime function
323 struct RuntimeFunctionInfo {
324
325 /// The kind, as described by the RuntimeFunction enum.
326 RuntimeFunction Kind;
327
328 /// The name of the function.
329 StringRef Name;
330
331 /// Flag to indicate a variadic function.
332 bool IsVarArg;
333
334 /// The return type of the function.
335 Type *ReturnType;
336
337 /// The argument types of the function.
338 SmallVector<Type *, 8> ArgumentTypes;
339
340 /// The declaration if available.
341 Function *Declaration = nullptr;
342
343 /// Uses of this runtime function per function containing the use.
344 using UseVector = SmallVector<Use *, 16>;
345
346 /// Clear UsesMap for runtime function.
347 void clearUsesMap() { UsesMap.clear(); }
348
349 /// Boolean conversion that is true if the runtime function was found.
350 operator bool() const { return Declaration; }
351
352 /// Return the vector of uses in function \p F.
353 UseVector &getOrCreateUseVector(Function *F) {
354 std::shared_ptr<UseVector> &UV = UsesMap[F];
355 if (!UV)
356 UV = std::make_shared<UseVector>();
357 return *UV;
358 }
359
360 /// Return the vector of uses in function \p F or `nullptr` if there are
361 /// none.
362 const UseVector *getUseVector(Function &F) const {
363 auto I = UsesMap.find(&F);
364 if (I != UsesMap.end())
365 return I->second.get();
366 return nullptr;
367 }
368
369 /// Return how many functions contain uses of this runtime function.
370 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
371
372 /// Return the number of arguments (or the minimal number for variadic
373 /// functions).
374 size_t getNumArgs() const { return ArgumentTypes.size(); }
375
376 /// Run the callback \p CB on each use and forget the use if the result is
377 /// true. The callback will be fed the function in which the use was
378 /// encountered as second argument.
379 void foreachUse(SmallVectorImpl<Function *> &SCC,
380 function_ref<bool(Use &, Function &)> CB) {
381 for (Function *F : SCC)
382 foreachUse(CB, F);
383 }
384
385 /// Run the callback \p CB on each use within the function \p F and forget
386 /// the use if the result is true.
387 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
388 SmallVector<unsigned, 8> ToBeDeleted;
389 ToBeDeleted.clear();
390
391 unsigned Idx = 0;
392 UseVector &UV = getOrCreateUseVector(F);
393
394 for (Use *U : UV) {
395 if (CB(*U, *F))
396 ToBeDeleted.push_back(Idx);
397 ++Idx;
398 }
399
400 // Remove the to-be-deleted indices in reverse order as prior
401 // modifications will not modify the smaller indices.
402 while (!ToBeDeleted.empty()) {
403 unsigned Idx = ToBeDeleted.pop_back_val();
404 UV[Idx] = UV.back();
405 UV.pop_back();
406 }
407 }
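    // A usage sketch for foreachUse (the runtime function and the callback
    // body are illustrative; deleteParallelRegions() later in this file
    // follows the same pattern):
    //
    //   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_barrier];
    //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    //     CallInst *CI = getCallIfRegularCall(U, &RFI);
    //     if (!CI)
    //       return false; // keep the use
    //     CI->eraseFromParent();
    //     return true;    // forget the use
    //   });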
408
409 private:
410 /// Map from functions to all uses of this runtime function contained in
411 /// them.
412 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
413
414 public:
415 /// Iterators for the uses of this runtime function.
416 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
417 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
418 };
419
420 /// An OpenMP-IR-Builder instance
421 OpenMPIRBuilder OMPBuilder;
422
423 /// Map from runtime function kind to the runtime function description.
424 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
425 RuntimeFunction::OMPRTL___last>
426 RFIs;
427
428 /// Map from function declarations/definitions to their runtime enum type.
429 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
430
431 /// Map from ICV kind to the ICV description.
432 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
433 InternalControlVar::ICV___last>
434 ICVs;
435
436 /// Helper to initialize all internal control variable information for those
437 /// defined in OMPKinds.def.
438 void initializeInternalControlVars() {
439#define ICV_RT_SET(_Name, RTL) \
440 { \
441 auto &ICV = ICVs[_Name]; \
442 ICV.Setter = RTL; \
443 }
444#define ICV_RT_GET(Name, RTL) \
445 { \
446 auto &ICV = ICVs[Name]; \
447 ICV.Getter = RTL; \
448 }
449#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
450 { \
451 auto &ICV = ICVs[Enum]; \
452 ICV.Name = _Name; \
453 ICV.Kind = Enum; \
454 ICV.InitKind = Init; \
455 ICV.EnvVarName = _EnvVarName; \
456 switch (ICV.InitKind) { \
457 case ICV_IMPLEMENTATION_DEFINED: \
458 ICV.InitValue = nullptr; \
459 break; \
460 case ICV_ZERO: \
461 ICV.InitValue = ConstantInt::get( \
462 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
463 break; \
464 case ICV_FALSE: \
465 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
466 break; \
467 case ICV_LAST: \
468 break; \
469 } \
470 }
471#include "llvm/Frontend/OpenMP/OMPKinds.def"
472 }
473
474 /// Returns true if the function declaration \p F matches the runtime
475 /// function types, that is, return type \p RTFRetType, and argument types
476 /// \p RTFArgTypes.
477 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
478 SmallVector<Type *, 8> &RTFArgTypes) {
479 // TODO: We should output information to the user (under debug output
480 // and via remarks).
481
482 if (!F)
483 return false;
484 if (F->getReturnType() != RTFRetType)
485 return false;
486 if (F->arg_size() != RTFArgTypes.size())
487 return false;
488
489 auto *RTFTyIt = RTFArgTypes.begin();
490 for (Argument &Arg : F->args()) {
491 if (Arg.getType() != *RTFTyIt)
492 return false;
493
494 ++RTFTyIt;
495 }
496
497 return true;
498 }
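  // For instance (a sketch), a declaration such as
  //   declare i32 @omp_get_thread_num()
  // matches RTFRetType == Int32 with an empty RTFArgTypes list, while a
  // declaration with a different return type or arity is rejected and the
  // corresponding RuntimeFunctionInfo keeps Declaration == nullptr.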
499
500 // Helper to collect all uses of the declaration in the UsesMap.
501 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
502 unsigned NumUses = 0;
503 if (!RFI.Declaration)
504 return NumUses;
505 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
506
507 if (CollectStats) {
508 NumOpenMPRuntimeFunctionsIdentified += 1;
509 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
510 }
511
512 // TODO: We directly convert uses into proper calls and unknown uses.
513 for (Use &U : RFI.Declaration->uses()) {
514 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
515 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
516 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
517 ++NumUses;
518 }
519 } else {
520 RFI.getOrCreateUseVector(nullptr).push_back(&U);
521 ++NumUses;
522 }
523 }
524 return NumUses;
525 }
526
527 // Helper function to recollect uses of a runtime function.
528 void recollectUsesForFunction(RuntimeFunction RTF) {
529 auto &RFI = RFIs[RTF];
530 RFI.clearUsesMap();
531 collectUses(RFI, /*CollectStats*/ false);
532 }
533
534 // Helper function to recollect uses of all runtime functions.
535 void recollectUses() {
536 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
537 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
538 }
539
540 // Helper function to inherit the calling convention of the function callee.
541 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
542 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
543 CI->setCallingConv(Fn->getCallingConv());
544 }
545
546 // Helper function to determine if it's legal to create a call to the runtime
547 // functions.
548 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
549 // We can always emit calls if we haven't yet linked in the runtime.
550 if (!OpenMPPostLink)
551 return true;
552
553 // Once the runtime has already been linked in, we cannot emit calls to
554 // any undefined functions.
555 for (RuntimeFunction Fn : Fns) {
556 RuntimeFunctionInfo &RFI = RFIs[Fn];
557
558 if (RFI.Declaration && RFI.Declaration->isDeclaration())
559 return false;
560 }
561 return true;
562 }
563
564 /// Helper to initialize all runtime function information for those defined
565 /// in OpenMPKinds.def.
566 void initializeRuntimeFunctions(Module &M) {
567
568 // Helper macros for handling __VA_ARGS__ in OMP_RTL
569#define OMP_TYPE(VarName, ...) \
570 Type *VarName = OMPBuilder.VarName; \
571 (void)VarName;
572
573#define OMP_ARRAY_TYPE(VarName, ...) \
574 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
575 (void)VarName##Ty; \
576 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
577 (void)VarName##PtrTy;
578
579#define OMP_FUNCTION_TYPE(VarName, ...) \
580 FunctionType *VarName = OMPBuilder.VarName; \
581 (void)VarName; \
582 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
583 (void)VarName##Ptr;
584
585#define OMP_STRUCT_TYPE(VarName, ...) \
586 StructType *VarName = OMPBuilder.VarName; \
587 (void)VarName; \
588 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
589 (void)VarName##Ptr;
590
591#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
592 { \
593 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
594 Function *F = M.getFunction(_Name); \
595 RTLFunctions.insert(F); \
596 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
597 RuntimeFunctionIDMap[F] = _Enum; \
598 auto &RFI = RFIs[_Enum]; \
599 RFI.Kind = _Enum; \
600 RFI.Name = _Name; \
601 RFI.IsVarArg = _IsVarArg; \
602 RFI.ReturnType = OMPBuilder._ReturnType; \
603 RFI.ArgumentTypes = std::move(ArgsTypes); \
604 RFI.Declaration = F; \
605 unsigned NumUses = collectUses(RFI); \
606 (void)NumUses; \
607 LLVM_DEBUG({ \
608 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
609 << " found\n"; \
610 if (RFI.Declaration) \
611 dbgs() << TAG << "-> got " << NumUses << " uses in " \
612 << RFI.getNumFunctionsWithUses() \
613 << " different functions.\n"; \
614 }); \
615 } \
616 }
617#include "llvm/Frontend/OpenMP/OMPKinds.def"
618
619 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
620 // functions, except if `optnone` is present.
621 if (isOpenMPDevice(M)) {
622 for (Function &F : M) {
623 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
624 if (F.hasFnAttribute(Attribute::NoInline) &&
625 F.getName().starts_with(Prefix) &&
626 !F.hasFnAttribute(Attribute::OptimizeNone))
627 F.removeFnAttr(Attribute::NoInline);
628 }
629 }
630
631 // TODO: We should attach the attributes defined in OMPKinds.def.
632 }
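  // As a concrete sketch of what one OMP_RTL expansion above produces: the
  // OMPKinds.def entry for __kmpc_barrier (roughly OMP_RTL(__kmpc_barrier,
  // false, Void, IdentPtr, Int32)) yields an RFI with
  //   Kind = OMPRTL___kmpc_barrier, Name = "__kmpc_barrier",
  //   IsVarArg = false, ReturnType = void, ArgumentTypes = {ptr, i32},
  // plus all uses of the declaration in the module collected via collectUses().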
633
634 /// Collection of known OpenMP runtime functions.
635 DenseSet<const Function *> RTLFunctions;
636
637 /// Indicates if we have already linked in the OpenMP device library.
638 bool OpenMPPostLink = false;
639};
640
641template <typename Ty, bool InsertInvalidates = true>
642struct BooleanStateWithSetVector : public BooleanState {
643 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
644 bool insert(const Ty &Elem) {
645 if (InsertInvalidates)
646 BooleanState::indicatePessimisticFixpoint();
647 return Set.insert(Elem);
648 }
649
650 const Ty &operator[](int Idx) const { return Set[Idx]; }
651 bool operator==(const BooleanStateWithSetVector &RHS) const {
652 return BooleanState::operator==(RHS) && Set == RHS.Set;
653 }
654 bool operator!=(const BooleanStateWithSetVector &RHS) const {
655 return !(*this == RHS);
656 }
657
658 bool empty() const { return Set.empty(); }
659 size_t size() const { return Set.size(); }
660
661 /// "Clamp" this state with \p RHS.
662 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
663 BooleanState::operator^=(RHS);
664 Set.insert(RHS.Set.begin(), RHS.Set.end());
665 return *this;
666 }
667
668private:
669 /// A set to keep track of elements.
670 SetVector<Ty> Set;
671
672public:
673 typename decltype(Set)::iterator begin() { return Set.begin(); }
674 typename decltype(Set)::iterator end() { return Set.end(); }
675 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
676 typename decltype(Set)::const_iterator end() const { return Set.end(); }
677};
678
679template <typename Ty, bool InsertInvalidates = true>
680using BooleanStateWithPtrSetVector =
681 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
682
683struct KernelInfoState : AbstractState {
684 /// Flag to track if we reached a fixpoint.
685 bool IsAtFixpoint = false;
686
687 /// The parallel regions (identified by the outlined parallel functions) that
688 /// can be reached from the associated function.
689 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
690 ReachedKnownParallelRegions;
691
692 /// State to track what parallel region we might reach.
693 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
694
695 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
696 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
697 /// false.
698 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
699
700 /// The __kmpc_target_init call in this kernel, if any. If we find more than
701 /// one we abort as the kernel is malformed.
702 CallBase *KernelInitCB = nullptr;
703
704 /// The constant kernel environment as taken from and passed to
705 /// __kmpc_target_init.
706 ConstantStruct *KernelEnvC = nullptr;
707
708 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
709 /// one we abort as the kernel is malformed.
710 CallBase *KernelDeinitCB = nullptr;
711
712 /// Flag to indicate if the associated function is a kernel entry.
713 bool IsKernelEntry = false;
714
715 /// State to track what kernel entries can reach the associated function.
716 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
717
718 /// State to indicate if we can track parallel level of the associated
719 /// function. We will give up tracking if we encounter unknown caller or the
720 /// caller is __kmpc_parallel_51.
721 BooleanStateWithSetVector<uint8_t> ParallelLevels;
722
723 /// Flag that indicates if the kernel has nested parallelism.
724 bool NestedParallelism = false;
725
726 /// Abstract State interface
727 ///{
728
729 KernelInfoState() = default;
730 KernelInfoState(bool BestState) {
731 if (!BestState)
732 indicatePessimisticFixpoint();
733 }
734
735 /// See AbstractState::isValidState(...)
736 bool isValidState() const override { return true; }
737
738 /// See AbstractState::isAtFixpoint(...)
739 bool isAtFixpoint() const override { return IsAtFixpoint; }
740
741 /// See AbstractState::indicatePessimisticFixpoint(...)
742 ChangeStatus indicatePessimisticFixpoint() override {
743 IsAtFixpoint = true;
744 ParallelLevels.indicatePessimisticFixpoint();
745 ReachingKernelEntries.indicatePessimisticFixpoint();
746 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
747 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
748 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
749 NestedParallelism = true;
750 return ChangeStatus::CHANGED;
751 }
752
753 /// See AbstractState::indicateOptimisticFixpoint(...)
754 ChangeStatus indicateOptimisticFixpoint() override {
755 IsAtFixpoint = true;
756 ParallelLevels.indicateOptimisticFixpoint();
757 ReachingKernelEntries.indicateOptimisticFixpoint();
758 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
759 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
760 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
761 return ChangeStatus::UNCHANGED;
762 }
763
764 /// Return the assumed state
765 KernelInfoState &getAssumed() { return *this; }
766 const KernelInfoState &getAssumed() const { return *this; }
767
768 bool operator==(const KernelInfoState &RHS) const {
769 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
770 return false;
771 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
772 return false;
773 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
774 return false;
775 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
776 return false;
777 if (ParallelLevels != RHS.ParallelLevels)
778 return false;
779 if (NestedParallelism != RHS.NestedParallelism)
780 return false;
781 return true;
782 }
783
784 /// Returns true if this kernel contains any OpenMP parallel regions.
785 bool mayContainParallelRegion() {
786 return !ReachedKnownParallelRegions.empty() ||
787 !ReachedUnknownParallelRegions.empty();
788 }
789
790 /// Return empty set as the best state of potential values.
791 static KernelInfoState getBestState() { return KernelInfoState(true); }
792
793 static KernelInfoState getBestState(KernelInfoState &KIS) {
794 return getBestState();
795 }
796
797 /// Return full set as the worst state of potential values.
798 static KernelInfoState getWorstState() { return KernelInfoState(false); }
799
800 /// "Clamp" this state with \p KIS.
801 KernelInfoState operator^=(const KernelInfoState &KIS) {
802 // Do not merge two different _init and _deinit call sites.
803 if (KIS.KernelInitCB) {
804 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
805 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
806 "assumptions.");
807 KernelInitCB = KIS.KernelInitCB;
808 }
809 if (KIS.KernelDeinitCB) {
810 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
811 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
812 "assumptions.");
813 KernelDeinitCB = KIS.KernelDeinitCB;
814 }
815 if (KIS.KernelEnvC) {
816 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
817 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
818 "assumptions.");
819 KernelEnvC = KIS.KernelEnvC;
820 }
821 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
822 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
823 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
824 NestedParallelism |= KIS.NestedParallelism;
825 return *this;
826 }
827
828 KernelInfoState operator&=(const KernelInfoState &KIS) {
829 return (*this ^= KIS);
830 }
831
832 ///}
833};
834
835/// Used to map the values physically (in the IR) stored in an offload
836/// array, to a vector in memory.
837struct OffloadArray {
838 /// Physical array (in the IR).
839 AllocaInst *Array = nullptr;
840 /// Mapped values.
841 SmallVector<Value *, 8> StoredValues;
842 /// Last stores made in the offload array.
843 SmallVector<StoreInst *, 8> LastAccesses;
844
845 OffloadArray() = default;
846
847 /// Initializes the OffloadArray with the values stored in \p Array before
848 /// instruction \p Before is reached. Returns false if the initialization
849 /// fails.
850 /// This MUST be used immediately after the construction of the object.
851 bool initialize(AllocaInst &Array, Instruction &Before) {
852 if (!Array.getAllocatedType()->isArrayTy())
853 return false;
854
855 if (!getValues(Array, Before))
856 return false;
857
858 this->Array = &Array;
859 return true;
860 }
861
862 static const unsigned DeviceIDArgNum = 1;
863 static const unsigned BasePtrsArgNum = 3;
864 static const unsigned PtrsArgNum = 4;
865 static const unsigned SizesArgNum = 5;
866
867private:
868 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
869 /// \p Array, leaving StoredValues with the values stored before the
870 /// instruction \p Before is reached.
871 bool getValues(AllocaInst &Array, Instruction &Before) {
872 // Initialize container.
873 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
874 StoredValues.assign(NumValues, nullptr);
875 LastAccesses.assign(NumValues, nullptr);
876
877 // TODO: This assumes the instruction \p Before is in the same
878 // BasicBlock as Array. Make it general, for any control flow graph.
879 BasicBlock *BB = Array.getParent();
880 if (BB != Before.getParent())
881 return false;
882
883 const DataLayout &DL = Array.getModule()->getDataLayout();
884 const unsigned int PointerSize = DL.getPointerSize();
885
886 for (Instruction &I : *BB) {
887 if (&I == &Before)
888 break;
889
890 if (!isa<StoreInst>(&I))
891 continue;
892
893 auto *S = cast<StoreInst>(&I);
894 int64_t Offset = -1;
895 auto *Dst =
896 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
897 if (Dst == &Array) {
898 int64_t Idx = Offset / PointerSize;
899 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
900 LastAccesses[Idx] = S;
901 }
902 }
903
904 return isFilled();
905 }
906
907 /// Returns true if all values in StoredValues and
908 /// LastAccesses are not nullptrs.
909 bool isFilled() {
910 const unsigned NumValues = StoredValues.size();
911 for (unsigned I = 0; I < NumValues; ++I) {
912 if (!StoredValues[I] || !LastAccesses[I])
913 return false;
914 }
915
916 return true;
917 }
918};
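// A sketch of the IR pattern OffloadArray::initialize() expects (value names
// are illustrative):
//
//   %offload_baseptrs = alloca [2 x ptr]
//   ...
//   store ptr %a, ptr %offload_baseptrs
//   %gep = getelementptr inbounds [2 x ptr], ptr %offload_baseptrs, i64 0, i64 1
//   store ptr %b, ptr %gep
//   call void @__tgt_target_data_begin_mapper(..., ptr %offload_baseptrs, ...)
//
// After initialize(<the alloca>, <the call>), StoredValues holds the
// underlying objects of %a and %b and LastAccesses holds the two stores.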
919
920struct OpenMPOpt {
921
922 using OptimizationRemarkGetter =
923 function_ref<OptimizationRemarkEmitter &(Function *)>;
924
925 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
926 OptimizationRemarkGetter OREGetter,
927 OMPInformationCache &OMPInfoCache, Attributor &A)
928 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
929 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
930
931 /// Check if any remarks are enabled for openmp-opt
932 bool remarksEnabled() {
933 auto &Ctx = M.getContext();
934 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
935 }
936
937 /// Run all OpenMP optimizations on the underlying SCC.
938 bool run(bool IsModulePass) {
939 if (SCC.empty())
940 return false;
941
942 bool Changed = false;
943
944 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
945 << " functions\n");
946
947 if (IsModulePass) {
948 Changed |= runAttributor(IsModulePass);
949
950 // Recollect uses, in case Attributor deleted any.
951 OMPInfoCache.recollectUses();
952
953 // TODO: This should be folded into buildCustomStateMachine.
954 Changed |= rewriteDeviceCodeStateMachine();
955
956 if (remarksEnabled())
957 analysisGlobalization();
958 } else {
959 if (PrintICVValues)
960 printICVs();
961 if (PrintOpenMPKernels)
962 printKernels();
963
964 Changed |= runAttributor(IsModulePass);
965
966 // Recollect uses, in case Attributor deleted any.
967 OMPInfoCache.recollectUses();
968
969 Changed |= deleteParallelRegions();
970
971 if (HideMemoryTransferLatency)
972 Changed |= hideMemTransfersLatency();
973 Changed |= deduplicateRuntimeCalls();
974 if (EnableParallelRegionMerging) {
975 if (mergeParallelRegions()) {
976 deduplicateRuntimeCalls();
977 Changed = true;
978 }
979 }
980 }
981
982 if (OMPInfoCache.OpenMPPostLink)
983 Changed |= removeRuntimeSymbols();
984
985 return Changed;
986 }
987
988 /// Print initial ICV values for testing.
989 /// FIXME: This should be done from the Attributor once it is added.
990 void printICVs() const {
991 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
992 ICV_proc_bind};
993
994 for (Function *F : SCC) {
995 for (auto ICV : ICVs) {
996 auto ICVInfo = OMPInfoCache.ICVs[ICV];
997 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
998 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
999 << " Value: "
1000 << (ICVInfo.InitValue
1001 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1002 : "IMPLEMENTATION_DEFINED");
1003 };
1004
1005 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1006 }
1007 }
1008 }
1009
1010 /// Print OpenMP GPU kernels for testing.
1011 void printKernels() const {
1012 for (Function *F : SCC) {
1013 if (!omp::isOpenMPKernel(*F))
1014 continue;
1015
1016 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1017 return ORA << "OpenMP GPU kernel "
1018 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1019 };
1020
1021 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
1022 }
1023 }
1024
1025 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1026 /// given it has to be the callee or a nullptr is returned.
1027 static CallInst *getCallIfRegularCall(
1028 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1029 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1030 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1031 (!RFI ||
1032 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1033 return CI;
1034 return nullptr;
1035 }
1036
1037 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1038 /// the callee or a nullptr is returned.
1039 static CallInst *getCallIfRegularCall(
1040 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(&V);
1042 if (CI && !CI->hasOperandBundles() &&
1043 (!RFI ||
1044 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045 return CI;
1046 return nullptr;
1047 }
1048
1049private:
1050 /// Merge parallel regions when it is safe.
1051 bool mergeParallelRegions() {
1052 const unsigned CallbackCalleeOperand = 2;
1053 const unsigned CallbackFirstArgOperand = 3;
1054 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1055
1056 // Check if there are any __kmpc_fork_call calls to merge.
1057 OMPInformationCache::RuntimeFunctionInfo &RFI =
1058 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1059
1060 if (!RFI.Declaration)
1061 return false;
1062
1063 // Unmergable calls that prevent merging a parallel region.
1064 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1065 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1066 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1067 };
1068
1069 bool Changed = false;
1070 LoopInfo *LI = nullptr;
1071 DominatorTree *DT = nullptr;
1072
1073 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1074
1075 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1076 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1077 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1078 BasicBlock *CGEndBB =
1079 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1080 assert(StartBB != nullptr && "StartBB should not be null");
1081 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1082 assert(EndBB != nullptr && "EndBB should not be null");
1083 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1084 };
1085
1086 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1087 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1088 ReplacementValue = &Inner;
1089 return CodeGenIP;
1090 };
1091
1092 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1093
1094 /// Create a sequential execution region within a merged parallel region,
1095 /// encapsulated in a master construct with a barrier for synchronization.
1096 auto CreateSequentialRegion = [&](Function *OuterFn,
1097 BasicBlock *OuterPredBB,
1098 Instruction *SeqStartI,
1099 Instruction *SeqEndI) {
1100 // Isolate the instructions of the sequential region to a separate
1101 // block.
1102 BasicBlock *ParentBB = SeqStartI->getParent();
1103 BasicBlock *SeqEndBB =
1104 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1105 BasicBlock *SeqAfterBB =
1106 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1107 BasicBlock *SeqStartBB =
1108 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1109
1110 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1111 "Expected a different CFG");
1112 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1113 ParentBB->getTerminator()->eraseFromParent();
1114
1115 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1116 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1117 BasicBlock *CGEndBB =
1118 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1119 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1120 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1121 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1122 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1123 };
1124 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1125
1126 // Find outputs from the sequential region to outside users and
1127 // broadcast their values to them.
1128 for (Instruction &I : *SeqStartBB) {
1129 SmallPtrSet<Instruction *, 4> OutsideUsers;
1130 for (User *Usr : I.users()) {
1131 Instruction &UsrI = *cast<Instruction>(Usr);
1132 // Ignore outputs to LT intrinsics, code extraction for the merged
1133 // parallel region will fix them.
1134 if (UsrI.isLifetimeStartOrEnd())
1135 continue;
1136
1137 if (UsrI.getParent() != SeqStartBB)
1138 OutsideUsers.insert(&UsrI);
1139 }
1140
1141 if (OutsideUsers.empty())
1142 continue;
1143
1144 // Emit an alloca in the outer region to store the broadcasted
1145 // value.
1146 const DataLayout &DL = M.getDataLayout();
1147 AllocaInst *AllocaI = new AllocaInst(
1148 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1149 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1150
1151 // Emit a store instruction in the sequential BB to update the
1152 // value.
1153 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1154
1155 // Emit a load instruction and replace the use of the output value
1156 // with it.
1157 for (Instruction *UsrI : OutsideUsers) {
1158 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1159 I.getName() + ".seq.output.load",
1160 UsrI->getIterator());
1161 UsrI->replaceUsesOfWith(&I, LoadI);
1162 }
1163 }
1164
1165 OpenMPIRBuilder::LocationDescription Loc(
1166 InsertPointTy(ParentBB, ParentBB->end()), DL);
1167 InsertPointTy SeqAfterIP =
1168 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1169
1170 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1171
1172 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1173
1174 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1175 << "\n");
1176 };
1177
1178 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1179 // contained in BB and only separated by instructions that can be
1180 // redundantly executed in parallel. The block BB is split before the first
1181 // call (in MergableCIs) and after the last so the entire region we merge
1182 // into a single parallel region is contained in a single basic block
1183 // without any other instructions. We use the OpenMPIRBuilder to outline
1184 // that block and call the resulting function via __kmpc_fork_call.
1185 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1186 BasicBlock *BB) {
1187 // TODO: Change the interface to allow single CIs expanded, e.g, to
1188 // include an outer loop.
1189 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1190
1191 auto Remark = [&](OptimizationRemark OR) {
1192 OR << "Parallel region merged with parallel region"
1193 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1194 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1195 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1196 if (CI != MergableCIs.back())
1197 OR << ", ";
1198 }
1199 return OR << ".";
1200 };
1201
1202 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1203
1204 Function *OriginalFn = BB->getParent();
1205 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1206 << " parallel regions in " << OriginalFn->getName()
1207 << "\n");
1208
1209 // Isolate the calls to merge in a separate block.
1210 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1211 BasicBlock *AfterBB =
1212 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1213 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1214 "omp.par.merged");
1215
1216 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1217 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1218 BB->getTerminator()->eraseFromParent();
1219
1220 // Create sequential regions for sequential instructions that are
1221 // in-between mergable parallel regions.
1222 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1223 It != End; ++It) {
1224 Instruction *ForkCI = *It;
1225 Instruction *NextForkCI = *(It + 1);
1226
1227 // Continue if there are not in-between instructions.
1228 if (ForkCI->getNextNode() == NextForkCI)
1229 continue;
1230
1231 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1232 NextForkCI->getPrevNode());
1233 }
1234
1235 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1236 DL);
1237 IRBuilder<>::InsertPoint AllocaIP(
1238 &OriginalFn->getEntryBlock(),
1239 OriginalFn->getEntryBlock().getFirstInsertionPt());
1240 // Create the merged parallel region with default proc binding, to
1241 // avoid overriding binding settings, and without explicit cancellation.
1242 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1243 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1244 OMP_PROC_BIND_default, /* IsCancellable */ false);
1245 BranchInst::Create(AfterBB, AfterIP.getBlock());
1246
1247 // Perform the actual outlining.
1248 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1249
1250 Function *OutlinedFn = MergableCIs.front()->getCaller();
1251
1252 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1253 // callbacks.
1254 SmallVector<Value *, 8> Args;
1255 for (auto *CI : MergableCIs) {
1256 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1257 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1258 Args.clear();
1259 Args.push_back(OutlinedFn->getArg(0));
1260 Args.push_back(OutlinedFn->getArg(1));
1261 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1262 ++U)
1263 Args.push_back(CI->getArgOperand(U));
1264
1265 CallInst *NewCI =
1266 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1267 if (CI->getDebugLoc())
1268 NewCI->setDebugLoc(CI->getDebugLoc());
1269
1270 // Forward parameter attributes from the callback to the callee.
1271 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1272 ++U)
1273 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1274 NewCI->addParamAttr(
1275 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1276
1277 // Emit an explicit barrier to replace the implicit fork-join barrier.
1278 if (CI != MergableCIs.back()) {
1279 // TODO: Remove barrier if the merged parallel region includes the
1280 // 'nowait' clause.
1281 OMPInfoCache.OMPBuilder.createBarrier(
1282 InsertPointTy(NewCI->getParent(),
1283 NewCI->getNextNode()->getIterator()),
1284 OMPD_parallel);
1285 }
1286
1287 CI->eraseFromParent();
1288 }
1289
1290 assert(OutlinedFn != OriginalFn && "Outlining failed");
1291 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1292 CGUpdater.reanalyzeFunction(*OriginalFn);
1293
1294 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1295
1296 return true;
1297 };
1298
1299 // Helper function that identifies sequences of
1300 // __kmpc_fork_call uses in a basic block.
1301 auto DetectPRsCB = [&](Use &U, Function &F) {
1302 CallInst *CI = getCallIfRegularCall(U, &RFI);
1303 BB2PRMap[CI->getParent()].insert(CI);
1304
1305 return false;
1306 };
1307
1308 BB2PRMap.clear();
1309 RFI.foreachUse(SCC, DetectPRsCB);
1310 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1311 // Find mergable parallel regions within a basic block that are
1312 // safe to merge, that is any in-between instructions can safely
1313 // execute in parallel after merging.
1314 // TODO: support merging across basic-blocks.
1315 for (auto &It : BB2PRMap) {
1316 auto &CIs = It.getSecond();
1317 if (CIs.size() < 2)
1318 continue;
1319
1320 BasicBlock *BB = It.getFirst();
1321 SmallVector<CallInst *, 4> MergableCIs;
1322
1323 /// Returns true if the instruction is mergable, false otherwise.
1324 /// A terminator instruction is unmergable by definition since merging
1325 /// works within a BB. Instructions before the mergable region are
1326 /// mergable if they are not calls to OpenMP runtime functions that may
1327 /// set different execution parameters for subsequent parallel regions.
1328 /// Instructions in-between parallel regions are mergable if they are not
1329 /// calls to any non-intrinsic function since that may call a non-mergable
1330 /// OpenMP runtime function.
1331 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1332 // We do not merge across BBs, hence return false (unmergable) if the
1333 // instruction is a terminator.
1334 if (I.isTerminator())
1335 return false;
1336
1337 if (!isa<CallInst>(&I))
1338 return true;
1339
1340 CallInst *CI = cast<CallInst>(&I);
1341 if (IsBeforeMergableRegion) {
1342 Function *CalledFunction = CI->getCalledFunction();
1343 if (!CalledFunction)
1344 return false;
1345 // Return false (unmergable) if the call before the parallel
1346 // region calls an explicit affinity (proc_bind) or number of
1347 // threads (num_threads) compiler-generated function. Those settings
1348 // may be incompatible with following parallel regions.
1349 // TODO: ICV tracking to detect compatibility.
1350 for (const auto &RFI : UnmergableCallsInfo) {
1351 if (CalledFunction == RFI.Declaration)
1352 return false;
1353 }
1354 } else {
1355 // Return false (unmergable) if there is a call instruction
1356 // in-between parallel regions when it is not an intrinsic. It
1357 // may call an unmergable OpenMP runtime function in its callpath.
1358 // TODO: Keep track of possible OpenMP calls in the callpath.
1359 if (!isa<IntrinsicInst>(CI))
1360 return false;
1361 }
1362
1363 return true;
1364 };
1365 // Find maximal number of parallel region CIs that are safe to merge.
1366 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1367 Instruction &I = *It;
1368 ++It;
1369
1370 if (CIs.count(&I)) {
1371 MergableCIs.push_back(cast<CallInst>(&I));
1372 continue;
1373 }
1374
1375 // Continue expanding if the instruction is mergable.
1376 if (IsMergable(I, MergableCIs.empty()))
1377 continue;
1378
1379 // Forward the instruction iterator to skip the next parallel region
1380 // since there is an unmergable instruction which can affect it.
1381 for (; It != End; ++It) {
1382 Instruction &SkipI = *It;
1383 if (CIs.count(&SkipI)) {
1384 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1385 << " due to " << I << "\n");
1386 ++It;
1387 break;
1388 }
1389 }
1390
1391 // Store mergable regions found.
1392 if (MergableCIs.size() > 1) {
1393 MergableCIsVector.push_back(MergableCIs);
1394 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1395 << " parallel regions in block " << BB->getName()
1396 << " of function " << BB->getParent()->getName()
1397 << "\n";);
1398 }
1399
1400 MergableCIs.clear();
1401 }
1402
1403 if (!MergableCIsVector.empty()) {
1404 Changed = true;
1405
1406 for (auto &MergableCIs : MergableCIsVector)
1407 Merge(MergableCIs, BB);
1408 MergableCIsVector.clear();
1409 }
1410 }
1411
1412 if (Changed) {
1413 /// Re-collect use for fork calls, emitted barrier calls, and
1414 /// any emitted master/end_master calls.
1415 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1416 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1417 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1418 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1419 }
1420
1421 return Changed;
1422 }
1423
1424 /// Try to delete parallel regions if possible.
1425 bool deleteParallelRegions() {
1426 const unsigned CallbackCalleeOperand = 2;
1427
1428 OMPInformationCache::RuntimeFunctionInfo &RFI =
1429 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1430
1431 if (!RFI.Declaration)
1432 return false;
1433
1434 bool Changed = false;
1435 auto DeleteCallCB = [&](Use &U, Function &) {
1436 CallInst *CI = getCallIfRegularCall(U);
1437 if (!CI)
1438 return false;
1439 auto *Fn = dyn_cast<Function>(
1440 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1441 if (!Fn)
1442 return false;
1443 if (!Fn->onlyReadsMemory())
1444 return false;
1445 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1446 return false;
1447
1448 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1449 << CI->getCaller()->getName() << "\n");
1450
1451 auto Remark = [&](OptimizationRemark OR) {
1452 return OR << "Removing parallel region with no side-effects.";
1453 };
1454 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1455
1456 CGUpdater.removeCallSite(*CI);
1457 CI->eraseFromParent();
1458 Changed = true;
1459 ++NumOpenMPParallelRegionsDeleted;
1460 return true;
1461 };
1462
1463 RFI.foreachUse(SCC, DeleteCallCB);
1464
1465 return Changed;
1466 }
1467
1468 /// Try to eliminate runtime calls by reusing existing ones.
1469 bool deduplicateRuntimeCalls() {
1470 bool Changed = false;
1471
1472 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1473 OMPRTL_omp_get_num_threads,
1474 OMPRTL_omp_in_parallel,
1475 OMPRTL_omp_get_cancellation,
1476 OMPRTL_omp_get_supported_active_levels,
1477 OMPRTL_omp_get_level,
1478 OMPRTL_omp_get_ancestor_thread_num,
1479 OMPRTL_omp_get_team_size,
1480 OMPRTL_omp_get_active_level,
1481 OMPRTL_omp_in_final,
1482 OMPRTL_omp_get_proc_bind,
1483 OMPRTL_omp_get_num_places,
1484 OMPRTL_omp_get_num_procs,
1485 OMPRTL_omp_get_place_num,
1486 OMPRTL_omp_get_partition_num_places,
1487 OMPRTL_omp_get_partition_place_nums};
1488
1489 // Global-tid is handled separately.
1490 SmallSetVector<Value *, 16> GTIdArgs;
1491 collectGlobalThreadIdArguments(GTIdArgs);
1492 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1493 << " global thread ID arguments\n");
1494
1495 for (Function *F : SCC) {
1496 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1497 Changed |= deduplicateRuntimeCalls(
1498 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1499
1500 // __kmpc_global_thread_num is special as we can replace it with an
1501 // argument in enough cases to make it worth trying.
1502 Value *GTIdArg = nullptr;
1503 for (Argument &Arg : F->args())
1504 if (GTIdArgs.count(&Arg)) {
1505 GTIdArg = &Arg;
1506 break;
1507 }
1508 Changed |= deduplicateRuntimeCalls(
1509 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1510 }
1511
1512 return Changed;
1513 }
1514
1515 /// Tries to remove known runtime symbols that are optional from the module.
1516 bool removeRuntimeSymbols() {
1517 // The RPC client symbol is defined in `libc` and indicates that something
1518 // required an RPC server. If its users were all optimized out then we can
1519 // safely remove it.
1520 // TODO: This should be somewhere more common in the future.
1521 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
1522 if (!GV->getType()->isPointerTy())
1523 return false;
1524
1525 Constant *C = GV->getInitializer();
1526 if (!C)
1527 return false;
1528
1529 // Check to see if the only user of the RPC client is the external handle.
1530 GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
1531 if (!Client || Client->getNumUses() > 1 ||
1532 Client->user_back() != GV->getInitializer())
1533 return false;
1534
1535 Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
1536 Client->eraseFromParent();
1537
1538 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1539 GV->eraseFromParent();
1540
1541 return true;
1542 }
1543 return false;
1544 }
1545
1546 /// Tries to hide the latency of runtime calls that involve host to
1547 /// device memory transfers by splitting them into their "issue" and "wait"
1548 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1549 /// moved downwards as much as possible. The "issue" issues the memory transfer
1550 /// asynchronously, returning a handle. The "wait" waits on the returned
1551 /// handle for the memory transfer to finish.
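  ///
  /// A sketch of the split (hedged; IR value names are illustrative, the
  /// rewrite itself is done in splitTargetDataBeginRTC below):
  ///
  ///   call void @__tgt_target_data_begin_mapper(%device_id, ...)
  ///   ... code that does not touch the mapped memory ...
  ///   use of the mapped memory
  ///
  /// becomes
  ///
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(%device_id, ..., ptr %handle)
  ///   ... code that does not touch the mapped memory ...
  ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, ptr %handle)
  ///   use of the mapped memory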
1552 bool hideMemTransfersLatency() {
1553 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1554 bool Changed = false;
1555 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1556 auto *RTCall = getCallIfRegularCall(U, &RFI);
1557 if (!RTCall)
1558 return false;
1559
1560 OffloadArray OffloadArrays[3];
1561 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1562 return false;
1563
1564 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1565
1566 // TODO: Check if can be moved upwards.
1567 bool WasSplit = false;
1568 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1569 if (WaitMovementPoint)
1570 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1571
1572 Changed |= WasSplit;
1573 return WasSplit;
1574 };
1575 if (OMPInfoCache.runtimeFnsAvailable(
1576 {OMPRTL___tgt_target_data_begin_mapper_issue,
1577 OMPRTL___tgt_target_data_begin_mapper_wait}))
1578 RFI.foreachUse(SCC, SplitMemTransfers);
1579
1580 return Changed;
1581 }
1582
1583 void analysisGlobalization() {
1584 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1585
1586 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1587 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1588 auto Remark = [&](OptimizationRemarkMissed ORM) {
1589 return ORM
1590 << "Found thread data sharing on the GPU. "
1591 << "Expect degraded performance due to data globalization.";
1592 };
1593 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1594 }
1595
1596 return false;
1597 };
1598
1599 RFI.foreachUse(SCC, CheckGlobalization);
1600 }
1601
1602 /// Maps the values stored in the offload arrays passed as arguments to
1603 /// \p RuntimeCall into the offload arrays in \p OAs.
1604 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1605 MutableArrayRef<OffloadArray> OAs) {
1606 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1607
1608 // A runtime call that involves memory offloading looks something like:
1609 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1610 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1611 // ...)
1612 // So, the idea is to access the allocas that allocate space for these
1613 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1614 // Therefore:
1615 // i8** %offload_baseptrs.
1616 Value *BasePtrsArg =
1617 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1618 // i8** %offload_ptrs.
1619 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1620 // i8** %offload_sizes.
1621 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1622
1623 // Get values stored in **offload_baseptrs.
1624 auto *V = getUnderlyingObject(BasePtrsArg);
1625 if (!isa<AllocaInst>(V))
1626 return false;
1627 auto *BasePtrsArray = cast<AllocaInst>(V);
1628 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1629 return false;
1630
1631 // Get values stored in **offload_ptrs.
1632 V = getUnderlyingObject(PtrsArg);
1633 if (!isa<AllocaInst>(V))
1634 return false;
1635 auto *PtrsArray = cast<AllocaInst>(V);
1636 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1637 return false;
1638
1639 // Get values stored in **offload_sizes.
1640 V = getUnderlyingObject(SizesArg);
1641 // If it's a [constant] global array don't analyze it.
1642 if (isa<GlobalValue>(V))
1643 return isa<Constant>(V);
1644 if (!isa<AllocaInst>(V))
1645 return false;
1646
1647 auto *SizesArray = cast<AllocaInst>(V);
1648 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1649 return false;
1650
1651 return true;
1652 }
1653
1654 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1655 /// For now this is a way to test that the function getValuesInOffloadArrays
1656 /// is working properly.
1657 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1658 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1659 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1660
1661 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1662 std::string ValuesStr;
1663 raw_string_ostream Printer(ValuesStr);
1664 std::string Separator = " --- ";
1665
1666 for (auto *BP : OAs[0].StoredValues) {
1667 BP->print(Printer);
1668 Printer << Separator;
1669 }
1670 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1671 ValuesStr.clear();
1672
1673 for (auto *P : OAs[1].StoredValues) {
1674 P->print(Printer);
1675 Printer << Separator;
1676 }
1677 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1678 ValuesStr.clear();
1679
1680 for (auto *S : OAs[2].StoredValues) {
1681 S->print(Printer);
1682 Printer << Separator;
1683 }
1684 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1685 }
1686
1687 /// Returns the instruction where the "wait" counterpart of \p RuntimeCall can
1688 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1689 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1690 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1691 // Make it traverse the CFG.
1692
1693 Instruction *CurrentI = &RuntimeCall;
1694 bool IsWorthIt = false;
1695 while ((CurrentI = CurrentI->getNextNode())) {
1696
1697 // TODO: Once we detect the regions to be offloaded we should use the
1698 // alias analysis manager to check if CurrentI may modify one of
1699 // the offloaded regions.
1700 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1701 if (IsWorthIt)
1702 return CurrentI;
1703
1704 return nullptr;
1705 }
1706
1707 // FIXME: For now, moving the call over anything without side effects
1708 // is considered worth it.
1709 IsWorthIt = true;
1710 }
1711
1712 // Return end of BasicBlock.
1713 return RuntimeCall.getParent()->getTerminator();
1714 }
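 // A minimal sketch of what canBeMovedDownwards computes, on hypothetical IR
 // (names made up for illustration):
 //   call void @__tgt_target_data_begin_mapper(...)  ; RuntimeCall ("issue")
 //   %a = add i32 %x, 1                              ; no side effects, crossed
 //   %b = mul i32 %a, 2                              ; no side effects, crossed
 //   store i32 %b, ptr %p                            ; may write memory
 // The store is returned: the "wait" counterpart can be placed right before
 // it, since at least one side-effect-free instruction was crossed. If no
 // memory access follows, the block terminator is returned instead.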
1715
1716 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1717 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1718 Instruction &WaitMovementPoint) {
1719 // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1720 // the function. It stores information about the async transfer, allowing
1721 // us to wait on it later.
1722 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1723 Function *F = RuntimeCall.getCaller();
1724 BasicBlock &Entry = F->getEntryBlock();
1725 IRBuilder.Builder.SetInsertPoint(&Entry,
1726 Entry.getFirstNonPHIOrDbgOrAlloca());
1727 Value *Handle = IRBuilder.Builder.CreateAlloca(
1728 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1729 Handle =
1730 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1731
1732 // Add "issue" runtime call declaration:
1733 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1734 // i8**, i8**, i64*, i64*)
1735 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1736 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1737
1738 // Replace the RuntimeCall call site with its asynchronous version.
1739 SmallVector<Value *, 16> Args;
1740 for (auto &Arg : RuntimeCall.args())
1741 Args.push_back(Arg.get());
1742 Args.push_back(Handle);
1743
1744 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1745 RuntimeCall.getIterator());
1746 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1747 RuntimeCall.eraseFromParent();
1748
1749 // Add "wait" runtime call declaration:
1750 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1751 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1752 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1753
1754 Value *WaitParams[2] = {
1755 IssueCallsite->getArgOperand(
1756 OffloadArray::DeviceIDArgNum), // device_id.
1757 Handle // handle to wait on.
1758 };
1759 CallInst *WaitCallsite = CallInst::Create(
1760 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1761 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1762
1763 return true;
1764 }
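 // Taken together, getValuesInOffloadArrays, canBeMovedDownwards, and
 // splitTargetDataBeginRTC sketch the (still work-in-progress) scheme: the
 // mapper call is split into an "issue" that starts the transfer and a
 // "wait" that is sunk past code that does not touch memory, so the transfer
 // can overlap with that independent code.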
1765
1766 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1767 bool GlobalOnly, bool &SingleChoice) {
1768 if (CurrentIdent == NextIdent)
1769 return CurrentIdent;
1770
1771 // TODO: Figure out how to actually combine multiple debug locations. For
1772 // now we just keep an existing one if there is a single choice.
1773 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1774 SingleChoice = !CurrentIdent;
1775 return NextIdent;
1776 }
1777 return nullptr;
1778 }
1779
1780 /// Return a `struct ident_t*` value that represents the ones used in the
1781 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1782 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1783 /// return value we create one from scratch. We also do not yet combine
1784 /// information, e.g., the source locations, see combinedIdentStruct.
1785 Value *
1786 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 Function &F, bool GlobalOnly) {
1788 bool SingleChoice = true;
1789 Value *Ident = nullptr;
1790 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1791 CallInst *CI = getCallIfRegularCall(U, &RFI);
1792 if (!CI || &F != &Caller)
1793 return false;
1794 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1795 /* GlobalOnly */ true, SingleChoice);
1796 return false;
1797 };
1798 RFI.foreachUse(SCC, CombineIdentStruct);
1799
1800 if (!Ident || !SingleChoice) {
1801 // The IRBuilder uses the insertion block to get to the module; this is
1802 // unfortunate, but we work around it for now.
1803 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1805 &F.getEntryBlock(), F.getEntryBlock().begin()));
1806 // Create a fallback location if none was found.
1807 // TODO: Use the debug locations of the calls instead.
1808 uint32_t SrcLocStrSize;
1809 Constant *Loc =
1810 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1811 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1812 }
1813 return Ident;
1814 }
1815
1816 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1817 /// \p ReplVal if given.
1818 bool deduplicateRuntimeCalls(Function &F,
1819 OMPInformationCache::RuntimeFunctionInfo &RFI,
1820 Value *ReplVal = nullptr) {
1821 auto *UV = RFI.getUseVector(F);
1822 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1823 return false;
1824
1825 LLVM_DEBUG(
1826 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1827 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1828
1829 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1830 cast<Argument>(ReplVal)->getParent() == &F)) &&
1831 "Unexpected replacement value!");
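 // A minimal sketch of the intended effect, assuming two calls to the same
 // runtime function (e.g., __kmpc_global_thread_num) with the same ident:
 //   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
 //   ...
 //   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
 // One call is kept (moved to a common dominator if needed) and the other is
 // replaced by its result and erased; see ReplaceAndDeleteCB below.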
1832
1833 // TODO: Use dominance to find a good position instead.
1834 auto CanBeMoved = [this](CallBase &CB) {
1835 unsigned NumArgs = CB.arg_size();
1836 if (NumArgs == 0)
1837 return true;
1838 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 return false;
1840 for (unsigned U = 1; U < NumArgs; ++U)
1841 if (isa<Instruction>(CB.getArgOperand(U)))
1842 return false;
1843 return true;
1844 };
1845
1846 if (!ReplVal) {
1847 auto *DT =
1848 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1849 if (!DT)
1850 return false;
1851 Instruction *IP = nullptr;
1852 for (Use *U : *UV) {
1853 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1854 if (IP)
1855 IP = DT->findNearestCommonDominator(IP, CI);
1856 else
1857 IP = CI;
1858 if (!CanBeMoved(*CI))
1859 continue;
1860 if (!ReplVal)
1861 ReplVal = CI;
1862 }
1863 }
1864 if (!ReplVal)
1865 return false;
1866 assert(IP && "Expected insertion point!");
1867 cast<Instruction>(ReplVal)->moveBefore(IP);
1868 }
1869
1870 // If we use a call as a replacement value we need to make sure the ident is
1871 // valid at the new location. For now we just pick a global one, either
1872 // existing and used by one of the calls, or created from scratch.
1873 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1874 if (!CI->arg_empty() &&
1875 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1876 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1877 /* GlobalOnly */ true);
1878 CI->setArgOperand(0, Ident);
1879 }
1880 }
1881
1882 bool Changed = false;
1883 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1884 CallInst *CI = getCallIfRegularCall(U, &RFI);
1885 if (!CI || CI == ReplVal || &F != &Caller)
1886 return false;
1887 assert(CI->getCaller() == &F && "Unexpected call!");
1888
1889 auto Remark = [&](OptimizationRemark OR) {
1890 return OR << "OpenMP runtime call "
1891 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1892 };
1893 if (CI->getDebugLoc())
1894 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1895 else
1896 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1897
1898 CGUpdater.removeCallSite(*CI);
1899 CI->replaceAllUsesWith(ReplVal);
1900 CI->eraseFromParent();
1901 ++NumOpenMPRuntimeCallsDeduplicated;
1902 Changed = true;
1903 return true;
1904 };
1905 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1906
1907 return Changed;
1908 }
1909
1910 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1911 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1912 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1913 // initialization. We could define an AbstractAttribute instead and
1914 // run the Attributor here once it can be run as an SCC pass.
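 // Sketch of the propagation below: results of __kmpc_global_thread_num
 // calls seed the set; an argument of a local function is added if every
 // call site passes a known GTId (or the result of such a call) in that
 // position, and newly added arguments are processed again until the set
 // stops growing.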
1915
1916 // Helper to check the argument \p ArgNo at all call sites of \p F for
1917 // a GTId.
1918 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1919 if (!F.hasLocalLinkage())
1920 return false;
1921 for (Use &U : F.uses()) {
1922 if (CallInst *CI = getCallIfRegularCall(U)) {
1923 Value *ArgOp = CI->getArgOperand(ArgNo);
1924 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1925 getCallIfRegularCall(
1926 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1927 continue;
1928 }
1929 return false;
1930 }
1931 return true;
1932 };
1933
1934 // Helper to identify uses of a GTId as GTId arguments.
1935 auto AddUserArgs = [&](Value &GTId) {
1936 for (Use &U : GTId.uses())
1937 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1938 if (CI->isArgOperand(&U))
1939 if (Function *Callee = CI->getCalledFunction())
1940 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1941 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1942 };
1943
1944 // The argument users of __kmpc_global_thread_num calls are GTIds.
1945 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1946 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1947
1948 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1949 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1950 AddUserArgs(*CI);
1951 return false;
1952 });
1953
1954 // Transitively search for more arguments by looking at the users of the
1955 // ones we know already. During the search the GTIdArgs vector is extended
1956 // so we cannot cache the size nor can we use a range based for.
1957 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1958 AddUserArgs(*GTIdArgs[U]);
1959 }
1960
1961 /// Kernel (=GPU) optimizations and utility functions
1962 ///
1963 ///{{
1964
1965 /// Cache to remember the unique kernel for a function.
1966 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1967
1968 /// Find the unique kernel that will execute \p F, if any.
1969 Kernel getUniqueKernelFor(Function &F);
1970
1971 /// Find the unique kernel that will execute \p I, if any.
1972 Kernel getUniqueKernelFor(Instruction &I) {
1973 return getUniqueKernelFor(*I.getFunction());
1974 }
1975
1976 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1977 /// the cases where we can avoid taking the address of a function.
1978 bool rewriteDeviceCodeStateMachine();
1979
1980 ///
1981 ///}}
1982
1983 /// Emit a remark generically
1984 ///
1985 /// This template function can be used to generically emit a remark. The
1986 /// RemarkKind should be one of the following:
1987 /// - OptimizationRemark to indicate a successful optimization attempt
1988 /// - OptimizationRemarkMissed to report a failed optimization attempt
1989 /// - OptimizationRemarkAnalysis to provide additional information about an
1990 /// optimization attempt
1991 ///
1992 /// The remark is built using a callback function provided by the caller that
1993 /// takes a RemarkKind as input and returns a RemarkKind.
1994 template <typename RemarkKind, typename RemarkCallBack>
1995 void emitRemark(Instruction *I, StringRef RemarkName,
1996 RemarkCallBack &&RemarkCB) const {
1997 Function *F = I->getParent()->getParent();
1998 auto &ORE = OREGetter(F);
1999
2000 if (RemarkName.starts_with("OMP"))
2001 ORE.emit([&]() {
2002 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2003 << " [" << RemarkName << "]";
2004 });
2005 else
2006 ORE.emit(
2007 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2008 }
2009
2010 /// Emit a remark on a function.
2011 template <typename RemarkKind, typename RemarkCallBack>
2012 void emitRemark(Function *F, StringRef RemarkName,
2013 RemarkCallBack &&RemarkCB) const {
2014 auto &ORE = OREGetter(F);
2015
2016 if (RemarkName.starts_with("OMP"))
2017 ORE.emit([&]() {
2018 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2019 << " [" << RemarkName << "]";
2020 });
2021 else
2022 ORE.emit(
2023 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2024 }
2025
2026 /// The underlying module.
2027 Module &M;
2028
2029 /// The SCC we are operating on.
2030 SmallVectorImpl<Function *> &SCC;
2031
2032 /// Callback to update the call graph, the first argument is a removed call,
2033 /// the second an optional replacement call.
2034 CallGraphUpdater &CGUpdater;
2035
2036 /// Callback to get an OptimizationRemarkEmitter from a Function *
2037 OptimizationRemarkGetter OREGetter;
2038
2039 /// OpenMP-specific information cache. Also used for Attributor runs.
2040 OMPInformationCache &OMPInfoCache;
2041
2042 /// Attributor instance.
2043 Attributor &A;
2044
2045 /// Helper function to run Attributor on SCC.
2046 bool runAttributor(bool IsModulePass) {
2047 if (SCC.empty())
2048 return false;
2049
2050 registerAAs(IsModulePass);
2051
2052 ChangeStatus Changed = A.run();
2053
2054 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2055 << " functions, result: " << Changed << ".\n");
2056
2057 if (Changed == ChangeStatus::CHANGED)
2058 OMPInfoCache.invalidateAnalyses();
2059
2060 return Changed == ChangeStatus::CHANGED;
2061 }
2062
2063 void registerFoldRuntimeCall(RuntimeFunction RF);
2064
2065 /// Populate the Attributor with abstract attribute opportunities in the
2066 /// functions.
2067 void registerAAs(bool IsModulePass);
2068
2069public:
2070 /// Callback to register AAs for live functions, including internal functions
2071 /// marked live during the traversal.
2072 static void registerAAsForFunction(Attributor &A, const Function &F);
2073};
2074
2075Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2076 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2077 !OMPInfoCache.CGSCC->contains(&F))
2078 return nullptr;
2079
2080 // Use a scope to keep the lifetime of the CachedKernel short.
2081 {
2082 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2083 if (CachedKernel)
2084 return *CachedKernel;
2085
2086 // TODO: We should use an AA to create an (optimistic and callback
2087 // call-aware) call graph. For now we stick to simple patterns that
2088 // are less powerful, basically the worst fixpoint.
2089 if (isOpenMPKernel(F)) {
2090 CachedKernel = Kernel(&F);
2091 return *CachedKernel;
2092 }
2093
2094 CachedKernel = nullptr;
2095 if (!F.hasLocalLinkage()) {
2096
2097 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2098 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2099 return ORA << "Potentially unknown OpenMP target region caller.";
2100 };
2101 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2102
2103 return nullptr;
2104 }
2105 }
2106
2107 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2108 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2109 // Allow use in equality comparisons.
2110 if (Cmp->isEquality())
2111 return getUniqueKernelFor(*Cmp);
2112 return nullptr;
2113 }
2114 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2115 // Allow direct calls.
2116 if (CB->isCallee(&U))
2117 return getUniqueKernelFor(*CB);
2118
2119 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2120 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2121 // Allow the use in __kmpc_parallel_51 calls.
2122 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2123 return getUniqueKernelFor(*CB);
2124 return nullptr;
2125 }
2126 // Disallow every other use.
2127 return nullptr;
2128 };
2129
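 // Each use contributes either a kernel or nullptr ("unknown") to the set
 // below; a unique kernel is reported only if all uses agree on the same
 // non-null kernel.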
2130 // TODO: In the future we want to track more than just a unique kernel.
2131 SmallPtrSet<Kernel, 2> PotentialKernels;
2132 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2133 PotentialKernels.insert(GetUniqueKernelForUse(U));
2134 });
2135
2136 Kernel K = nullptr;
2137 if (PotentialKernels.size() == 1)
2138 K = *PotentialKernels.begin();
2139
2140 // Cache the result.
2141 UniqueKernelMap[&F] = K;
2142
2143 return K;
2144}
2145
2146bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2147 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2148 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2149
2150 bool Changed = false;
2151 if (!KernelParallelRFI)
2152 return Changed;
2153
2154 // If we have disabled state machine changes, exit.
2155 if (DisableOpenMPOptStateMachineRewrite)
2156 return Changed;
2157
2158 for (Function *F : SCC) {
2159
2160 // Check if the function is a use in a __kmpc_parallel_51 call at
2161 // all.
2162 bool UnknownUse = false;
2163 bool KernelParallelUse = false;
2164 unsigned NumDirectCalls = 0;
2165
2166 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2167 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2168 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2169 if (CB->isCallee(&U)) {
2170 ++NumDirectCalls;
2171 return;
2172 }
2173
2174 if (isa<ICmpInst>(U.getUser())) {
2175 ToBeReplacedStateMachineUses.push_back(&U);
2176 return;
2177 }
2178
2179 // Find wrapper functions that represent parallel kernels.
2180 CallInst *CI =
2181 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2182 const unsigned int WrapperFunctionArgNo = 6;
2183 if (!KernelParallelUse && CI &&
2184 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2185 KernelParallelUse = true;
2186 ToBeReplacedStateMachineUses.push_back(&U);
2187 return;
2188 }
2189 UnknownUse = true;
2190 });
2191
2192 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2193 // use.
2194 if (!KernelParallelUse)
2195 continue;
2196
2197 // If this ever hits, we should investigate.
2198 // TODO: Checking the number of uses is not a necessary restriction and
2199 // should be lifted.
2200 if (UnknownUse || NumDirectCalls != 1 ||
2201 ToBeReplacedStateMachineUses.size() > 2) {
2202 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2203 return ORA << "Parallel region is used in "
2204 << (UnknownUse ? "unknown" : "unexpected")
2205 << " ways. Will not attempt to rewrite the state machine.";
2206 };
2207 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2208 continue;
2209 }
2210
2211 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2212 // up if the function is not called from a unique kernel.
2213 Kernel K = getUniqueKernelFor(*F);
2214 if (!K) {
2215 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2216 return ORA << "Parallel region is not called from a unique kernel. "
2217 "Will not attempt to rewrite the state machine.";
2218 };
2219 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2220 continue;
2221 }
2222
2223 // We now know F is a parallel body function called only from the kernel K.
2224 // We also identified the state machine uses in which we replace the
2225 // function pointer by a new global symbol for identification purposes. This
2226 // ensures only direct calls to the function are left.
2227
2228 Module &M = *F->getParent();
2229 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2230
2231 auto *ID = new GlobalVariable(
2232 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2233 UndefValue::get(Int8Ty), F->getName() + ".ID");
2234
2235 for (Use *U : ToBeReplacedStateMachineUses)
2236 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2237 ID, U->get()->getType()));
2238
2239 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2240
2241 Changed = true;
2242 }
2243
2244 return Changed;
2245}
2246
2247/// Abstract Attribute for tracking ICV values.
2248struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2249 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2250 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2251
2252 /// Returns true if value is assumed to be tracked.
2253 bool isAssumedTracked() const { return getAssumed(); }
2254
2255 /// Returns true if value is known to be tracked.
2256 bool isKnownTracked() const { return getAssumed(); }
2257
2258 /// Create an abstract attribute view for the position \p IRP.
2259 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2260
2261 /// Return the value with which \p I can be replaced for specific \p ICV.
2262 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2263 const Instruction *I,
2264 Attributor &A) const {
2265 return std::nullopt;
2266 }
2267
2268 /// Return an assumed unique ICV value if a single candidate is found. If
2269 /// there cannot be one, return a nullptr. If it is not clear yet, return
2270 /// std::nullopt.
2271 virtual std::optional<Value *>
2272 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2273
2274 // Currently only nthreads is being tracked.
2275 // This array will only grow with time.
2276 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2277
2278 /// See AbstractAttribute::getName()
2279 const std::string getName() const override { return "AAICVTracker"; }
2280
2281 /// See AbstractAttribute::getIdAddr()
2282 const char *getIdAddr() const override { return &ID; }
2283
2284 /// This function should return true if the type of the \p AA is AAICVTracker
2285 static bool classof(const AbstractAttribute *AA) {
2286 return (AA->getIdAddr() == &ID);
2287 }
2288
2289 static const char ID;
2290};
2291
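// The AAICVTracker interface above is implemented for several IR positions
// below: AAICVTrackerFunction tracks ICV values throughout a function,
// AAICVTrackerFunctionReturned summarizes the value at return instructions,
// AAICVTrackerCallSite replaces ICV getter calls for which a replacement
// value was found, and AAICVTrackerCallSiteReturned tracks the value after a
// call returns.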
2292struct AAICVTrackerFunction : public AAICVTracker {
2293 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2294 : AAICVTracker(IRP, A) {}
2295
2296 // FIXME: come up with better string.
2297 const std::string getAsStr(Attributor *) const override {
2298 return "ICVTrackerFunction";
2299 }
2300
2301 // FIXME: come up with some stats.
2302 void trackStatistics() const override {}
2303
2304 /// We don't manifest anything for this AA.
2305 ChangeStatus manifest(Attributor &A) override {
2306 return ChangeStatus::UNCHANGED;
2307 }
2308
2309 // Map of ICV to their values at specific program point.
2311 InternalControlVar::ICV___last>
2312 ICVReplacementValuesMap;
2313
2314 ChangeStatus updateImpl(Attributor &A) override {
2315 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2316
2317 Function *F = getAnchorScope();
2318
2319 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2320
2321 for (InternalControlVar ICV : TrackableICVs) {
2322 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2323
2324 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2325 auto TrackValues = [&](Use &U, Function &) {
2326 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2327 if (!CI)
2328 return false;
2329
2330 // FIXME: handle setters with more than one argument.
2331 /// Track new value.
2332 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2333 HasChanged = ChangeStatus::CHANGED;
2334
2335 return false;
2336 };
2337
2338 auto CallCheck = [&](Instruction &I) {
2339 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2340 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2341 HasChanged = ChangeStatus::CHANGED;
2342
2343 return true;
2344 };
2345
2346 // Track all changes of an ICV.
2347 SetterRFI.foreachUse(TrackValues, F);
2348
2349 bool UsedAssumedInformation = false;
2350 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2351 UsedAssumedInformation,
2352 /* CheckBBLivenessOnly */ true);
2353
2354 /// TODO: Figure out a way to avoid adding an entry in
2355 /// ICVReplacementValuesMap.
2356 Instruction *Entry = &F->getEntryBlock().front();
2357 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2358 ValuesMap.insert(std::make_pair(Entry, nullptr));
2359 }
2360
2361 return HasChanged;
2362 }
2363
2364 /// Helper to check if \p I is a call and get the value for it if it is
2365 /// unique.
2366 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2367 InternalControlVar &ICV) const {
2368
2369 const auto *CB = dyn_cast<CallBase>(&I);
2370 if (!CB || CB->hasFnAttr("no_openmp") ||
2371 CB->hasFnAttr("no_openmp_routines"))
2372 return std::nullopt;
2373
2374 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2375 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2376 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2377 Function *CalledFunction = CB->getCalledFunction();
2378
2379 // Indirect call, assume ICV changes.
2380 if (CalledFunction == nullptr)
2381 return nullptr;
2382 if (CalledFunction == GetterRFI.Declaration)
2383 return std::nullopt;
2384 if (CalledFunction == SetterRFI.Declaration) {
2385 if (ICVReplacementValuesMap[ICV].count(&I))
2386 return ICVReplacementValuesMap[ICV].lookup(&I);
2387
2388 return nullptr;
2389 }
2390
2391 // Since we don't know, assume it changes the ICV.
2392 if (CalledFunction->isDeclaration())
2393 return nullptr;
2394
2395 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2396 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2397
2398 if (ICVTrackingAA->isAssumedTracked()) {
2399 std::optional<Value *> URV =
2400 ICVTrackingAA->getUniqueReplacementValue(ICV);
2401 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2402 OMPInfoCache)))
2403 return URV;
2404 }
2405
2406 // If we don't know, assume it changes.
2407 return nullptr;
2408 }
2409
2410 // We don't check unique value for a function, so return std::nullopt.
2411 std::optional<Value *>
2412 getUniqueReplacementValue(InternalControlVar ICV) const override {
2413 return std::nullopt;
2414 }
2415
2416 /// Return the value with which \p I can be replaced for specific \p ICV.
2417 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2418 const Instruction *I,
2419 Attributor &A) const override {
2420 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2421 if (ValuesMap.count(I))
2422 return ValuesMap.lookup(I);
2423
2424 SmallVector<const Instruction *, 16> Worklist;
2425 SmallPtrSet<const Instruction *, 16> Visited;
2426 Worklist.push_back(I);
2427
2428 std::optional<Value *> ReplVal;
2429
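 // Walk backwards from \p I: within a block, scan the preceding instructions
 // for a known setter value or a call that may change the ICV; if the top of
 // the block is reached without an answer, continue from the terminators of
 // all predecessors. Conflicting values along different paths yield nullptr
 // ("unknown").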
2430 while (!Worklist.empty()) {
2431 const Instruction *CurrInst = Worklist.pop_back_val();
2432 if (!Visited.insert(CurrInst).second)
2433 continue;
2434
2435 const BasicBlock *CurrBB = CurrInst->getParent();
2436
2437 // Go up and look for all potential setters/calls that might change the
2438 // ICV.
2439 while ((CurrInst = CurrInst->getPrevNode())) {
2440 if (ValuesMap.count(CurrInst)) {
2441 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2442 // Unknown value, track new.
2443 if (!ReplVal) {
2444 ReplVal = NewReplVal;
2445 break;
2446 }
2447
2448 // If we found a new value, we can't know the ICV value anymore.
2449 if (NewReplVal)
2450 if (ReplVal != NewReplVal)
2451 return nullptr;
2452
2453 break;
2454 }
2455
2456 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2457 if (!NewReplVal)
2458 continue;
2459
2460 // Unknown value, track new.
2461 if (!ReplVal) {
2462 ReplVal = NewReplVal;
2463 break;
2464 }
2465
2466 // We found a new value (NewReplVal is set), so we can't know the ICV
2467 // value anymore.
2468 if (ReplVal != NewReplVal)
2469 return nullptr;
2470 }
2471
2472 // If we are in the same BB and we have a value, we are done.
2473 if (CurrBB == I->getParent() && ReplVal)
2474 return ReplVal;
2475
2476 // Go through all predecessors and add terminators for analysis.
2477 for (const BasicBlock *Pred : predecessors(CurrBB))
2478 if (const Instruction *Terminator = Pred->getTerminator())
2479 Worklist.push_back(Terminator);
2480 }
2481
2482 return ReplVal;
2483 }
2484};
2485
2486struct AAICVTrackerFunctionReturned : AAICVTracker {
2487 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2488 : AAICVTracker(IRP, A) {}
2489
2490 // FIXME: come up with better string.
2491 const std::string getAsStr(Attributor *) const override {
2492 return "ICVTrackerFunctionReturned";
2493 }
2494
2495 // FIXME: come up with some stats.
2496 void trackStatistics() const override {}
2497
2498 /// We don't manifest anything for this AA.
2499 ChangeStatus manifest(Attributor &A) override {
2500 return ChangeStatus::UNCHANGED;
2501 }
2502
2503 // Map of ICV to their values at specific program point.
2504 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2505 InternalControlVar::ICV___last>
2506 ICVReplacementValuesMap;
2507
2508 /// Return the value with which \p I can be replaced for specific \p ICV.
2509 std::optional<Value *>
2510 getUniqueReplacementValue(InternalControlVar ICV) const override {
2511 return ICVReplacementValuesMap[ICV];
2512 }
2513
2514 ChangeStatus updateImpl(Attributor &A) override {
2515 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2516 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2517 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2518
2519 if (!ICVTrackingAA->isAssumedTracked())
2520 return indicatePessimisticFixpoint();
2521
2522 for (InternalControlVar ICV : TrackableICVs) {
2523 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2524 std::optional<Value *> UniqueICVValue;
2525
2526 auto CheckReturnInst = [&](Instruction &I) {
2527 std::optional<Value *> NewReplVal =
2528 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2529
2530 // If we found a second ICV value there is no unique returned value.
2531 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2532 return false;
2533
2534 UniqueICVValue = NewReplVal;
2535
2536 return true;
2537 };
2538
2539 bool UsedAssumedInformation = false;
2540 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2541 UsedAssumedInformation,
2542 /* CheckBBLivenessOnly */ true))
2543 UniqueICVValue = nullptr;
2544
2545 if (UniqueICVValue == ReplVal)
2546 continue;
2547
2548 ReplVal = UniqueICVValue;
2549 Changed = ChangeStatus::CHANGED;
2550 }
2551
2552 return Changed;
2553 }
2554};
2555
2556struct AAICVTrackerCallSite : AAICVTracker {
2557 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2558 : AAICVTracker(IRP, A) {}
2559
2560 void initialize(Attributor &A) override {
2561 assert(getAnchorScope() && "Expected anchor function");
2562
2563 // We only initialize this AA for getters, so we need to know which ICV it
2564 // gets.
2565 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2566 for (InternalControlVar ICV : TrackableICVs) {
2567 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2568 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2569 if (Getter.Declaration == getAssociatedFunction()) {
2570 AssociatedICV = ICVInfo.Kind;
2571 return;
2572 }
2573 }
2574
2575 /// Unknown ICV.
2576 indicatePessimisticFixpoint();
2577 }
2578
2579 ChangeStatus manifest(Attributor &A) override {
2580 if (!ReplVal || !*ReplVal)
2581 return ChangeStatus::UNCHANGED;
2582
2583 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2584 A.deleteAfterManifest(*getCtxI());
2585
2586 return ChangeStatus::CHANGED;
2587 }
2588
2589 // FIXME: come up with better string.
2590 const std::string getAsStr(Attributor *) const override {
2591 return "ICVTrackerCallSite";
2592 }
2593
2594 // FIXME: come up with some stats.
2595 void trackStatistics() const override {}
2596
2597 InternalControlVar AssociatedICV;
2598 std::optional<Value *> ReplVal;
2599
2600 ChangeStatus updateImpl(Attributor &A) override {
2601 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2602 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2603
2604 // We don't have any information, so we assume it changes the ICV.
2605 if (!ICVTrackingAA->isAssumedTracked())
2606 return indicatePessimisticFixpoint();
2607
2608 std::optional<Value *> NewReplVal =
2609 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2610
2611 if (ReplVal == NewReplVal)
2612 return ChangeStatus::UNCHANGED;
2613
2614 ReplVal = NewReplVal;
2615 return ChangeStatus::CHANGED;
2616 }
2617
2618 // Return the value with which the associated value can be replaced for the
2619 // specific \p ICV.
2620 std::optional<Value *>
2621 getUniqueReplacementValue(InternalControlVar ICV) const override {
2622 return ReplVal;
2623 }
2624};
2625
2626struct AAICVTrackerCallSiteReturned : AAICVTracker {
2627 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2628 : AAICVTracker(IRP, A) {}
2629
2630 // FIXME: come up with better string.
2631 const std::string getAsStr(Attributor *) const override {
2632 return "ICVTrackerCallSiteReturned";
2633 }
2634
2635 // FIXME: come up with some stats.
2636 void trackStatistics() const override {}
2637
2638 /// We don't manifest anything for this AA.
2639 ChangeStatus manifest(Attributor &A) override {
2640 return ChangeStatus::UNCHANGED;
2641 }
2642
2643 // Map of ICV to their values at specific program point.
2644 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2645 InternalControlVar::ICV___last>
2646 ICVReplacementValuesMap;
2647
2648 /// Return the value with which the associated value can be replaced for the
2649 /// specific \p ICV.
2650 std::optional<Value *>
2651 getUniqueReplacementValue(InternalControlVar ICV) const override {
2652 return ICVReplacementValuesMap[ICV];
2653 }
2654
2655 ChangeStatus updateImpl(Attributor &A) override {
2656 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2657 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2658 *this, IRPosition::returned(*getAssociatedFunction()),
2659 DepClassTy::REQUIRED);
2660
2661 // We don't have any information, so we assume it changes the ICV.
2662 if (!ICVTrackingAA->isAssumedTracked())
2663 return indicatePessimisticFixpoint();
2664
2665 for (InternalControlVar ICV : TrackableICVs) {
2666 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2667 std::optional<Value *> NewReplVal =
2668 ICVTrackingAA->getUniqueReplacementValue(ICV);
2669
2670 if (ReplVal == NewReplVal)
2671 continue;
2672
2673 ReplVal = NewReplVal;
2674 Changed = ChangeStatus::CHANGED;
2675 }
2676 return Changed;
2677 }
2678};
2679
2680/// Determines if \p BB exits the function unconditionally itself or reaches a
2681/// block that does so through only unique successors.
2682static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2683 if (succ_empty(BB))
2684 return true;
2685 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2686 if (!Successor)
2687 return false;
2688 return hasFunctionEndAsUniqueSuccessor(Successor);
2689}
2690
2691struct AAExecutionDomainFunction : public AAExecutionDomain {
2692 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2693 : AAExecutionDomain(IRP, A) {}
2694
2695 ~AAExecutionDomainFunction() { delete RPOT; }
2696
2697 void initialize(Attributor &A) override {
2698 Function *F = getAnchorScope();
2699 assert(F && "Expected anchor function");
2700 RPOT = new ReversePostOrderTraversal<Function *>(F);
2701 }
2702
2703 const std::string getAsStr(Attributor *) const override {
2704 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2705 for (auto &It : BEDMap) {
2706 if (!It.getFirst())
2707 continue;
2708 TotalBlocks++;
2709 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2710 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2711 It.getSecond().IsReachingAlignedBarrierOnly;
2712 }
2713 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2714 std::to_string(AlignedBlocks) + " of " +
2715 std::to_string(TotalBlocks) +
2716 " executed by initial thread / aligned";
2717 }
2718
2719 /// See AbstractAttribute::trackStatistics().
2720 void trackStatistics() const override {}
2721
2722 ChangeStatus manifest(Attributor &A) override {
2723 LLVM_DEBUG({
2724 for (const BasicBlock &BB : *getAnchorScope()) {
2725 if (!isExecutedByInitialThreadOnly(BB))
2726 continue;
2727 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2728 << BB.getName() << " is executed by a single thread.\n";
2729 }
2730 });
2731
2732 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733
2734 if (!isValidState())
2735 return Changed;
2736
2737 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2738 auto HandleAlignedBarrier = [&](CallBase *CB) {
2739 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2740 if (!ED.IsReachedFromAlignedBarrierOnly ||
2741 ED.EncounteredNonLocalSideEffect)
2742 return;
2743 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2744 return;
2745
2746 // We can remove this barrier, if it is one, or aligned barriers reaching
2747 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2748 // end should only be removed if the kernel end is their unique successor;
2749 // otherwise, they may have side-effects that aren't accounted for in the
2750 // kernel end in their other successors. If those barriers have other
2751 // barriers reaching them, those can be transitively removed as well as
2752 // long as the kernel end is also their unique successor.
2753 if (CB) {
2754 DeletedBarriers.insert(CB);
2755 A.deleteAfterManifest(*CB);
2756 ++NumBarriersEliminated;
2757 Changed = ChangeStatus::CHANGED;
2758 } else if (!ED.AlignedBarriers.empty()) {
2759 Changed = ChangeStatus::CHANGED;
2760 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2761 ED.AlignedBarriers.end());
2762 SmallSetVector<CallBase *, 16> Visited;
2763 while (!Worklist.empty()) {
2764 CallBase *LastCB = Worklist.pop_back_val();
2765 if (!Visited.insert(LastCB))
2766 continue;
2767 if (LastCB->getFunction() != getAnchorScope())
2768 continue;
2769 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2770 continue;
2771 if (!DeletedBarriers.count(LastCB)) {
2772 ++NumBarriersEliminated;
2773 A.deleteAfterManifest(*LastCB);
2774 continue;
2775 }
2776 // The final aligned barrier (LastCB) reaching the kernel end was
2777 // removed already. This means we can go one step further and remove
2778 // the barriers last encountered before LastCB.
2779 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2780 Worklist.append(LastED.AlignedBarriers.begin(),
2781 LastED.AlignedBarriers.end());
2782 }
2783 }
2784
2785 // If we actually eliminated a barrier we need to eliminate the associated
2786 // llvm.assumes as well to avoid creating UB.
2787 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2788 for (auto *AssumeCB : ED.EncounteredAssumes)
2789 A.deleteAfterManifest(*AssumeCB);
2790 };
2791
2792 for (auto *CB : AlignedBarriers)
2793 HandleAlignedBarrier(CB);
2794
2795 // Handle the "kernel end barrier" for kernels too.
2796 if (omp::isOpenMPKernel(*getAnchorScope()))
2797 HandleAlignedBarrier(nullptr);
2798
2799 return Changed;
2800 }
2801
2802 bool isNoOpFence(const FenceInst &FI) const override {
2803 return getState().isValidState() && !NonNoOpFences.count(&FI);
2804 }
2805
2806 /// Merge barrier and assumption information from \p PredED into the successor
2807 /// \p ED.
2808 void
2809 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2810 const ExecutionDomainTy &PredED);
2811
2812 /// Merge all information from \p PredED into the successor \p ED. If
2813 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2814 /// represented by \p ED from this predecessor.
2815 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2816 const ExecutionDomainTy &PredED,
2817 bool InitialEdgeOnly = false);
2818
2819 /// Accumulate information for the entry block in \p EntryBBED.
2820 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2821
2822 /// See AbstractAttribute::updateImpl.
2823 ChangeStatus updateImpl(Attributor &A) override;
2824
2825 /// Query interface, see AAExecutionDomain
2826 ///{
2827 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2828 if (!isValidState())
2829 return false;
2830 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2831 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2832 }
2833
2834 bool isExecutedInAlignedRegion(Attributor &A,
2835 const Instruction &I) const override {
2836 assert(I.getFunction() == getAnchorScope() &&
2837 "Instruction is out of scope!");
2838 if (!isValidState())
2839 return false;
2840
2841 bool ForwardIsOk = true;
2842 const Instruction *CurI;
2843
2844 // Check forward until a call or the block end is reached.
2845 CurI = &I;
2846 do {
2847 auto *CB = dyn_cast<CallBase>(CurI);
2848 if (!CB)
2849 continue;
2850 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2851 return true;
2852 const auto &It = CEDMap.find({CB, PRE});
2853 if (It == CEDMap.end())
2854 continue;
2855 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2856 ForwardIsOk = false;
2857 break;
2858 } while ((CurI = CurI->getNextNonDebugInstruction()));
2859
2860 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2861 ForwardIsOk = false;
2862
2863 // Check backward until a call or the block beginning is reached.
2864 CurI = &I;
2865 do {
2866 auto *CB = dyn_cast<CallBase>(CurI);
2867 if (!CB)
2868 continue;
2869 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2870 return true;
2871 const auto &It = CEDMap.find({CB, POST});
2872 if (It == CEDMap.end())
2873 continue;
2874 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2875 break;
2876 return false;
2877 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2878
2879 // Delayed decision on the forward pass to allow aligned barrier detection
2880 // in the backwards traversal.
2881 if (!ForwardIsOk)
2882 return false;
2883
2884 if (!CurI) {
2885 const BasicBlock *BB = I.getParent();
2886 if (BB == &BB->getParent()->getEntryBlock())
2887 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2888 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2889 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2890 })) {
2891 return false;
2892 }
2893 }
2894
2895 // On neither traversal did we find anything but aligned barriers.
2896 return true;
2897 }
2898
2899 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2900 assert(isValidState() &&
2901 "No request should be made against an invalid state!");
2902 return BEDMap.lookup(&BB);
2903 }
2904 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2905 getExecutionDomain(const CallBase &CB) const override {
2906 assert(isValidState() &&
2907 "No request should be made against an invalid state!");
2908 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2909 }
2910 ExecutionDomainTy getFunctionExecutionDomain() const override {
2911 assert(isValidState() &&
2912 "No request should be made against an invalid state!");
2913 return InterProceduralED;
2914 }
2915 ///}
2916
2917 // Check if the edge into the successor block contains a condition that only
2918 // lets the main thread execute it.
2919 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2920 BasicBlock &SuccessorBB) {
2921 if (!Edge || !Edge->isConditional())
2922 return false;
2923 if (Edge->getSuccessor(0) != &SuccessorBB)
2924 return false;
2925
2926 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2927 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2928 return false;
2929
2930 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2931 if (!C)
2932 return false;
2933
2934 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2935 if (C->isAllOnesValue()) {
2936 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2937 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2938 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2939 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2940 if (!CB)
2941 return false;
2942 ConstantStruct *KernelEnvC =
2944 ConstantInt *ExecModeC =
2945 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2946 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2947 }
2948
2949 if (C->isZero()) {
2950 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2951 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2952 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2953 return true;
2954
2955 // Match: 0 == llvm.amdgcn.workitem.id.x()
2956 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2957 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2958 return true;
2959 }
2960
2961 return false;
2962 };
2963
2964 /// Mapping containing information about the function for other AAs.
2965 ExecutionDomainTy InterProceduralED;
2966
2967 enum Direction { PRE = 0, POST = 1 };
2968 /// Mapping containing information per block.
2971 CEDMap;
2972 SmallSetVector<CallBase *, 16> AlignedBarriers;
2973
2975
2976 /// Set \p R to \p V and report true if that changed \p R.
2977 static bool setAndRecord(bool &R, bool V) {
2978 bool Eq = (R == V);
2979 R = V;
2980 return !Eq;
2981 }
2982
2983 /// Collection of fences known to be non-no-ops. All fences not in this set
2984 /// can be assumed to be no-ops.
2985 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2986};
2987
2988void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2989 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2990 for (auto *EA : PredED.EncounteredAssumes)
2991 ED.addAssumeInst(A, *EA);
2992
2993 for (auto *AB : PredED.AlignedBarriers)
2994 ED.addAlignedBarrier(A, *AB);
2995}
2996
2997bool AAExecutionDomainFunction::mergeInPredecessor(
2998 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2999 bool InitialEdgeOnly) {
3000
3001 bool Changed = false;
3002 Changed |=
3003 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3004 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3005 ED.IsExecutedByInitialThreadOnly));
3006
3007 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3008 ED.IsReachedFromAlignedBarrierOnly &&
3009 PredED.IsReachedFromAlignedBarrierOnly);
3010 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3011 ED.EncounteredNonLocalSideEffect |
3012 PredED.EncounteredNonLocalSideEffect);
3013 // Do not track assumptions and barriers as part of Changed.
3014 if (ED.IsReachedFromAlignedBarrierOnly)
3015 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3016 else
3017 ED.clearAssumeInstAndAlignedBarriers();
3018 return Changed;
3019}
3020
3021bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3022 ExecutionDomainTy &EntryBBED) {
3023 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3024 auto PredForCallSite = [&](AbstractCallSite ACS) {
3025 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3026 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3027 DepClassTy::OPTIONAL);
3028 if (!EDAA || !EDAA->getState().isValidState())
3029 return false;
3030 CallSiteEDs.emplace_back(
3031 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3032 return true;
3033 };
3034
3035 ExecutionDomainTy ExitED;
3036 bool AllCallSitesKnown;
3037 if (A.checkForAllCallSites(PredForCallSite, *this,
3038 /* RequiresAllCallSites */ true,
3039 AllCallSitesKnown)) {
3040 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3041 mergeInPredecessor(A, EntryBBED, CSInED);
3042 ExitED.IsReachingAlignedBarrierOnly &=
3043 CSOutED.IsReachingAlignedBarrierOnly;
3044 }
3045
3046 } else {
3047 // We could not find all predecessors, so this is either a kernel or a
3048 // function with external linkage (or with some other weird uses).
3049 if (omp::isOpenMPKernel(*getAnchorScope())) {
3050 EntryBBED.IsExecutedByInitialThreadOnly = false;
3051 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3052 EntryBBED.EncounteredNonLocalSideEffect = false;
3053 ExitED.IsReachingAlignedBarrierOnly = false;
3054 } else {
3055 EntryBBED.IsExecutedByInitialThreadOnly = false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3057 EntryBBED.EncounteredNonLocalSideEffect = true;
3058 ExitED.IsReachingAlignedBarrierOnly = false;
3059 }
3060 }
3061
3062 bool Changed = false;
3063 auto &FnED = BEDMap[nullptr];
3064 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3065 FnED.IsReachedFromAlignedBarrierOnly &
3066 EntryBBED.IsReachedFromAlignedBarrierOnly);
3067 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3068 FnED.IsReachingAlignedBarrierOnly &
3069 ExitED.IsReachingAlignedBarrierOnly);
3070 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3071 EntryBBED.IsExecutedByInitialThreadOnly);
3072 return Changed;
3073}
3074
3075ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3076
3077 bool Changed = false;
3078
3079 // Helper to deal with an aligned barrier encountered during the forward
3080 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3081 // it was encountered.
3082 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3083 Changed |= AlignedBarriers.insert(&CB);
3084 // First, update the barrier ED kept in the separate CEDMap.
3085 auto &CallInED = CEDMap[{&CB, PRE}];
3086 Changed |= mergeInPredecessor(A, CallInED, ED);
3087 CallInED.IsReachingAlignedBarrierOnly = true;
3088 // Next adjust the ED we use for the traversal.
3089 ED.EncounteredNonLocalSideEffect = false;
3090 ED.IsReachedFromAlignedBarrierOnly = true;
3091 // Aligned barrier collection has to come last.
3092 ED.clearAssumeInstAndAlignedBarriers();
3093 ED.addAlignedBarrier(A, CB);
3094 auto &CallOutED = CEDMap[{&CB, POST}];
3095 Changed |= mergeInPredecessor(A, CallOutED, ED);
3096 };
3097
3098 auto *LivenessAA =
3099 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3100
3101 Function *F = getAnchorScope();
3102 BasicBlock &EntryBB = F->getEntryBlock();
3103 bool IsKernel = omp::isOpenMPKernel(*F);
3104
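 // The update is a forward pass in reverse post-order: for each block we
 // merge the execution domains of its (live) predecessors, then walk the
 // instructions and record, per call site, the domain on entry (PRE) and on
 // exit (POST). Sync instructions that may break alignment are collected in
 // SyncInstWorklist and their effect is propagated backwards afterwards.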
3105 SmallVector<Instruction *> SyncInstWorklist;
3106 for (auto &RIt : *RPOT) {
3107 BasicBlock &BB = *RIt;
3108
3109 bool IsEntryBB = &BB == &EntryBB;
3110 // TODO: We use local reasoning since we don't have a divergence analysis
3111 // running as well. We could basically allow uniform branches here.
3112 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3113 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3114 ExecutionDomainTy ED;
3115 // Propagate "incoming edges" into information about this block.
3116 if (IsEntryBB) {
3117 Changed |= handleCallees(A, ED);
3118 } else {
3119 // For live non-entry blocks we only propagate
3120 // information via live edges.
3121 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3122 continue;
3123
3124 for (auto *PredBB : predecessors(&BB)) {
3125 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3126 continue;
3127 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3128 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3129 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3130 }
3131 }
3132
3133 // Now we traverse the block, accumulate effects in ED and attach
3134 // information to calls.
3135 for (Instruction &I : BB) {
3136 bool UsedAssumedInformation;
3137 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3138 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3139 /* CheckForDeadStore */ true))
3140 continue;
3141
3142 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3143 // first; the former are collected, the latter are ignored.
3144 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3145 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3146 ED.addAssumeInst(A, *AI);
3147 continue;
3148 }
3149 // TODO: Should we also collect and delete lifetime markers?
3150 if (II->isAssumeLikeIntrinsic())
3151 continue;
3152 }
3153
3154 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3155 if (!ED.EncounteredNonLocalSideEffect) {
3156 // An aligned fence without non-local side-effects is a no-op.
3157 if (ED.IsReachedFromAlignedBarrierOnly)
3158 continue;
3159 // A non-aligned fence without non-local side-effects is a no-op
3160 // if the ordering only publishes non-local side-effects (or less).
3161 switch (FI->getOrdering()) {
3162 case AtomicOrdering::NotAtomic:
3163 continue;
3164 case AtomicOrdering::Unordered:
3165 continue;
3166 case AtomicOrdering::Monotonic:
3167 continue;
3168 case AtomicOrdering::Acquire:
3169 break;
3170 case AtomicOrdering::Release:
3171 continue;
3172 case AtomicOrdering::AcquireRelease:
3173 break;
3174 case AtomicOrdering::SequentiallyConsistent:
3175 break;
3176 };
3177 }
3178 NonNoOpFences.insert(FI);
3179 }
3180
3181 auto *CB = dyn_cast<CallBase>(&I);
3182 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3183 bool IsAlignedBarrier =
3184 !IsNoSync && CB &&
3185 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3186
3187 AlignedBarrierLastInBlock &= IsNoSync;
3188 IsExplicitlyAligned &= IsNoSync;
3189
3190 // Next we check for calls. Aligned barriers are handled
3191 // explicitly, everything else is kept for the backward traversal and will
3192 // also affect our state.
3193 if (CB) {
3194 if (IsAlignedBarrier) {
3195 HandleAlignedBarrier(*CB, ED);
3196 AlignedBarrierLastInBlock = true;
3197 IsExplicitlyAligned = true;
3198 continue;
3199 }
3200
3201 // Check the pointer(s) of a memory intrinsic explicitly.
3202 if (isa<MemIntrinsic>(&I)) {
3203 if (!ED.EncounteredNonLocalSideEffect &&
3204 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3205 ED.EncounteredNonLocalSideEffect = true;
3206 if (!IsNoSync) {
3207 ED.IsReachedFromAlignedBarrierOnly = false;
3208 SyncInstWorklist.push_back(&I);
3209 }
3210 continue;
3211 }
3212
3213 // Record how we entered the call, then accumulate the effect of the
3214 // call in ED for potential use by the callee.
3215 auto &CallInED = CEDMap[{CB, PRE}];
3216 Changed |= mergeInPredecessor(A, CallInED, ED);
3217
3218 // If we have a sync-definition we can check if it starts/ends in an
3219 // aligned barrier. If we are unsure we assume any sync breaks
3220 // alignment.
3221 Function *Callee = CB->getCalledFunction();
3222 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3223 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3224 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3225 if (EDAA && EDAA->getState().isValidState()) {
3226 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3227 ED.IsReachedFromAlignedBarrierOnly =
3228 CalleeED.IsReachedFromAlignedBarrierOnly;
3229 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3230 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3231 ED.EncounteredNonLocalSideEffect |=
3232 CalleeED.EncounteredNonLocalSideEffect;
3233 else
3234 ED.EncounteredNonLocalSideEffect =
3235 CalleeED.EncounteredNonLocalSideEffect;
3236 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3237 Changed |=
3238 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3239 SyncInstWorklist.push_back(&I);
3240 }
3241 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3242 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3243 auto &CallOutED = CEDMap[{CB, POST}];
3244 Changed |= mergeInPredecessor(A, CallOutED, ED);
3245 continue;
3246 }
3247 }
3248 if (!IsNoSync) {
3249 ED.IsReachedFromAlignedBarrierOnly = false;
3250 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3251 SyncInstWorklist.push_back(&I);
3252 }
3253 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3254 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3255 auto &CallOutED = CEDMap[{CB, POST}];
3256 Changed |= mergeInPredecessor(A, CallOutED, ED);
3257 }
3258
3259 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3260 continue;
3261
3262 // If we have a callee we try to use fine-grained information to
3263 // determine local side-effects.
3264 if (CB) {
3265 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3266 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3267
3268 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3269 AAMemoryLocation::AccessKind,
3270 AAMemoryLocation::MemoryLocationsKind) {
3271 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3272 };
3273 if (MemAA && MemAA->getState().isValidState() &&
3274 MemAA->checkForAllAccessesToMemoryKind(
3275 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3276 continue;
3277 }
3278
3279 auto &InfoCache = A.getInfoCache();
3280 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3281 continue;
3282
3283 if (auto *LI = dyn_cast<LoadInst>(&I))
3284 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3285 continue;
3286
3287 if (!ED.EncounteredNonLocalSideEffect &&
3288 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3289 ED.EncounteredNonLocalSideEffect = true;
3290 }
3291
3292 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3293 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3294 !BB.getTerminator()->getNumSuccessors()) {
3295
3296 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3297
3298 auto &FnED = BEDMap[nullptr];
3299 if (IsKernel && !IsExplicitlyAligned)
3300 FnED.IsReachingAlignedBarrierOnly = false;
3301 Changed |= mergeInPredecessor(A, FnED, ED);
3302
3303 if (!FnED.IsReachingAlignedBarrierOnly) {
3304 IsEndAndNotReachingAlignedBarriersOnly = true;
3305 SyncInstWorklist.push_back(BB.getTerminator());
3306 auto &BBED = BEDMap[&BB];
3307 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3308 }
3309 }
3310
3311 ExecutionDomainTy &StoredED = BEDMap[&BB];
3312 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3313 !IsEndAndNotReachingAlignedBarriersOnly;
3314
3315 // Check if we computed anything different as part of the forward
3316 // traversal. We do not take assumptions and aligned barriers into account
3317 // as they do not influence the state we iterate. Backward traversal values
3318 // are handled later on.
3319 if (ED.IsExecutedByInitialThreadOnly !=
3320 StoredED.IsExecutedByInitialThreadOnly ||
3321 ED.IsReachedFromAlignedBarrierOnly !=
3322 StoredED.IsReachedFromAlignedBarrierOnly ||
3323 ED.EncounteredNonLocalSideEffect !=
3324 StoredED.EncounteredNonLocalSideEffect)
3325 Changed = true;
3326
3327 // Update the state with the new value.
3328 StoredED = std::move(ED);
3329 }
3330
3331 // Propagate (non-aligned) sync instruction effects backwards until the
3332 // entry is hit or an aligned barrier.
3333 SmallSetVector<BasicBlock *, 16> Visited;
3334 while (!SyncInstWorklist.empty()) {
3335 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3336 Instruction *CurInst = SyncInst;
3337 bool HitAlignedBarrierOrKnownEnd = false;
3338 while ((CurInst = CurInst->getPrevNode())) {
3339 auto *CB = dyn_cast<CallBase>(CurInst);
3340 if (!CB)
3341 continue;
3342 auto &CallOutED = CEDMap[{CB, POST}];
3343 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3344 auto &CallInED = CEDMap[{CB, PRE}];
3345 HitAlignedBarrierOrKnownEnd =
3346 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3347 if (HitAlignedBarrierOrKnownEnd)
3348 break;
3349 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3350 }
3351 if (HitAlignedBarrierOrKnownEnd)
3352 continue;
3353 BasicBlock *SyncBB = SyncInst->getParent();
3354 for (auto *PredBB : predecessors(SyncBB)) {
3355 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3356 continue;
3357 if (!Visited.insert(PredBB))
3358 continue;
3359 auto &PredED = BEDMap[PredBB];
3360 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3361 Changed = true;
3362 SyncInstWorklist.push_back(PredBB->getTerminator());
3363 }
3364 }
3365 if (SyncBB != &EntryBB)
3366 continue;
3367 Changed |=
3368 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3369 }
3370
3371 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3372}
3373
3374/// Try to replace memory allocation calls called by a single thread with a
3375/// static buffer of shared memory.
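/// For example (a sketch of the rewrite performed in manifest() below), a
/// single-thread allocation such as
///   %p = call ptr @__kmpc_alloc_shared(i64 256)
///   ...
///   call void @__kmpc_free_shared(ptr %p, i64 256)
/// is replaced by a static buffer in the Shared address space,
///   @p_shared = internal addrspace(3) global [256 x i8] poison
/// the alloc/free calls are deleted, and the buffer inherits the allocation
/// call's return alignment.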
3376 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3377 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3378 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3379
3380 /// Create an abstract attribute view for the position \p IRP.
3381 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3382 Attributor &A);
3383
3384 /// Returns true if HeapToShared conversion is assumed to be possible.
3385 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3386
3387 /// Returns true if HeapToShared conversion is assumed and the CB is a
3388 /// callsite to a free operation to be removed.
3389 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3390
3391 /// See AbstractAttribute::getName().
3392 const std::string getName() const override { return "AAHeapToShared"; }
3393
3394 /// See AbstractAttribute::getIdAddr().
3395 const char *getIdAddr() const override { return &ID; }
3396
3397 /// This function should return true if the type of the \p AA is
3398 /// AAHeapToShared.
3399 static bool classof(const AbstractAttribute *AA) {
3400 return (AA->getIdAddr() == &ID);
3401 }
3402
3403 /// Unique ID (due to the unique address)
3404 static const char ID;
3405};
3406
3407struct AAHeapToSharedFunction : public AAHeapToShared {
3408 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3409 : AAHeapToShared(IRP, A) {}
3410
3411 const std::string getAsStr(Attributor *) const override {
3412 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3413 " malloc calls eligible.";
3414 }
3415
3416 /// See AbstractAttribute::trackStatistics().
3417 void trackStatistics() const override {}
3418
3419 /// This function finds free calls that will be removed by the
3420 /// HeapToShared transformation.
3421 void findPotentialRemovedFreeCalls(Attributor &A) {
3422 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3423 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3424
3425 PotentialRemovedFreeCalls.clear();
3426 // Update free call users of found malloc calls.
3427 for (CallBase *CB : MallocCalls) {
3428 SmallVector<CallBase *, 4> FreeCalls;
3429 for (auto *U : CB->users()) {
3430 CallBase *C = dyn_cast<CallBase>(U);
3431 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3432 FreeCalls.push_back(C);
3433 }
3434
3435 if (FreeCalls.size() != 1)
3436 continue;
3437
3438 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3439 }
3440 }
3441
3442 void initialize(Attributor &A) override {
3443 if (DisableOpenMPOptDeglobalization) {
3444 indicatePessimisticFixpoint();
3445 return;
3446 }
3447
3448 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3449 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3450 if (!RFI.Declaration)
3451 return;
3452
3453 Attributor::SimplifictionCallbackTy SCB =
3454 [](const IRPosition &, const AbstractAttribute *,
3455 bool &) -> std::optional<Value *> { return nullptr; };
3456
3457 Function *F = getAnchorScope();
3458 for (User *U : RFI.Declaration->users())
3459 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3460 if (CB->getFunction() != F)
3461 continue;
3462 MallocCalls.insert(CB);
3463 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3464 SCB);
3465 }
3466
3467 findPotentialRemovedFreeCalls(A);
3468 }
3469
3470 bool isAssumedHeapToShared(CallBase &CB) const override {
3471 return isValidState() && MallocCalls.count(&CB);
3472 }
3473
3474 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3475 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3476 }
3477
3478 ChangeStatus manifest(Attributor &A) override {
3479 if (MallocCalls.empty())
3480 return ChangeStatus::UNCHANGED;
3481
3482 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3483 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3484
3485 Function *F = getAnchorScope();
3486 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3487 DepClassTy::OPTIONAL);
3488
3489 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3490 for (CallBase *CB : MallocCalls) {
3491 // Skip replacing this if HeapToStack has already claimed it.
3492 if (HS && HS->isAssumedHeapToStack(*CB))
3493 continue;
3494
3495 // Find the unique free call to remove it.
3496 SmallVector<CallBase *, 4> FreeCalls;
3497 for (auto *U : CB->users()) {
3498 CallBase *C = dyn_cast<CallBase>(U);
3499 if (C && C->getCalledFunction() == FreeCall.Declaration)
3500 FreeCalls.push_back(C);
3501 }
3502 if (FreeCalls.size() != 1)
3503 continue;
3504
3505 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3506
3507 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3508 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3509 << " with shared memory."
3510 << " Shared memory usage is limited to "
3511 << SharedMemoryLimit << " bytes\n");
3512 continue;
3513 }
3514
3515 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3516 << " with " << AllocSize->getZExtValue()
3517 << " bytes of shared memory\n");
3518
3519 // Create a new shared memory buffer of the same size as the allocation
3520 // and replace all the uses of the original allocation with it.
3521 Module *M = CB->getModule();
3522 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3523 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3524 auto *SharedMem = new GlobalVariable(
3525 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3526 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3527 GlobalValue::NotThreadLocal,
3528 static_cast<unsigned>(AddressSpace::Shared));
3529 auto *NewBuffer =
3530 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3531
3532 auto Remark = [&](OptimizationRemark OR) {
3533 return OR << "Replaced globalized variable with "
3534 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3535 << (AllocSize->isOne() ? " byte " : " bytes ")
3536 << "of shared memory.";
3537 };
3538 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3539
3540 MaybeAlign Alignment = CB->getRetAlign();
3541 assert(Alignment &&
3542 "HeapToShared on allocation without alignment attribute");
3543 SharedMem->setAlignment(*Alignment);
3544
3545 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3546 A.deleteAfterManifest(*CB);
3547 A.deleteAfterManifest(*FreeCalls.front());
3548
3549 SharedMemoryUsed += AllocSize->getZExtValue();
3550 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3551 Changed = ChangeStatus::CHANGED;
3552 }
3553
3554 return Changed;
3555 }
3556
3557 ChangeStatus updateImpl(Attributor &A) override {
3558 if (MallocCalls.empty())
3559 return indicatePessimisticFixpoint();
3560 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3561 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3562 if (!RFI.Declaration)
3563 return ChangeStatus::UNCHANGED;
3564
3565 Function *F = getAnchorScope();
3566
3567 auto NumMallocCalls = MallocCalls.size();
3568
3569 // Only consider malloc calls executed by a single thread with a constant.
3570 for (User *U : RFI.Declaration->users()) {
3571 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3572 if (CB->getCaller() != F)
3573 continue;
3574 if (!MallocCalls.count(CB))
3575 continue;
3576 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3577 MallocCalls.remove(CB);
3578 continue;
3579 }
3580 const auto *ED = A.getAAFor<AAExecutionDomain>(
3581 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3582 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3583 MallocCalls.remove(CB);
3584 }
3585 }
3586
3587 findPotentialRemovedFreeCalls(A);
3588
3589 if (NumMallocCalls != MallocCalls.size())
3590 return ChangeStatus::CHANGED;
3591
3592 return ChangeStatus::UNCHANGED;
3593 }
3594
3595 /// Collection of all malloc calls in a function.
3596 SmallSetVector<CallBase *, 4> MallocCalls;
3597 /// Collection of potentially removed free calls in a function.
3598 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3599 /// The total amount of shared memory that has been used for HeapToShared.
3600 unsigned SharedMemoryUsed = 0;
3601};
3602
3603 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3604 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3605 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3606
3607 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3608 /// unknown callees.
3609 static bool requiresCalleeForCallBase() { return false; }
3610
3611 /// Statistics are tracked as part of manifest for now.
3612 void trackStatistics() const override {}
3613
3614 /// See AbstractAttribute::getAsStr()
3615 const std::string getAsStr(Attributor *) const override {
3616 if (!isValidState())
3617 return "<invalid>";
3618 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3619 : "generic") +
3620 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3621 : "") +
3622 std::string(" #PRs: ") +
3623 (ReachedKnownParallelRegions.isValidState()
3624 ? std::to_string(ReachedKnownParallelRegions.size())
3625 : "<invalid>") +
3626 ", #Unknown PRs: " +
3627 (ReachedUnknownParallelRegions.isValidState()
3628 ? std::to_string(ReachedUnknownParallelRegions.size())
3629 : "<invalid>") +
3630 ", #Reaching Kernels: " +
3631 (ReachingKernelEntries.isValidState()
3632 ? std::to_string(ReachingKernelEntries.size())
3633 : "<invalid>") +
3634 ", #ParLevels: " +
3635 (ParallelLevels.isValidState()
3636 ? std::to_string(ParallelLevels.size())
3637 : "<invalid>") +
3638 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3639 }
3640
3641 /// Create an abstract attribute view for the position \p IRP.
3642 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3643
3644 /// See AbstractAttribute::getName()
3645 const std::string getName() const override { return "AAKernelInfo"; }
3646
3647 /// See AbstractAttribute::getIdAddr()
3648 const char *getIdAddr() const override { return &ID; }
3649
3650 /// This function should return true if the type of the \p AA is AAKernelInfo
3651 static bool classof(const AbstractAttribute *AA) {
3652 return (AA->getIdAddr() == &ID);
3653 }
3654
3655 static const char ID;
3656};
3657
3658/// The function kernel info abstract attribute, basically, what can we say
3659/// about a function with regards to the KernelInfoState.
3660struct AAKernelInfoFunction : AAKernelInfo {
3661 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3662 : AAKernelInfo(IRP, A) {}
3663
3664 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3665
3666 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3667 return GuardedInstructions;
3668 }
3669
3670 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3671 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3672 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3673 assert(NewKernelEnvC && "Failed to create new kernel environment");
3674 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3675 }
3676
3677#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3678 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3679 ConstantStruct *ConfigC = \
3680 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3681 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3682 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3683 assert(NewConfigC && "Failed to create new configuration environment"); \
3684 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3685 }
3686
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3694
3695#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
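// The macro above generates setters such as
//   setExecModeOfKernelEnvironment(ConstantInt *NewVal)
// which constant-fold the new value into the configuration struct of the
// kernel environment (KernelEnvC) at the corresponding member index.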
3696
3697 /// See AbstractAttribute::initialize(...).
3698 void initialize(Attributor &A) override {
3699 // This is a high-level transform that might change the constant arguments
3700 // of the init and deinit calls. We need to tell the Attributor about this
3701 // to avoid other parts using the current constant value for simplification.
3702 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3703
3704 Function *Fn = getAnchorScope();
3705
3706 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3707 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3708 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3709 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3710
3711 // For kernels we perform more initialization work, first we find the init
3712 // and deinit calls.
3713 auto StoreCallBase = [](Use &U,
3714 OMPInformationCache::RuntimeFunctionInfo &RFI,
3715 CallBase *&Storage) {
3716 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3717 assert(CB &&
3718 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3719 assert(!Storage &&
3720 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3721 Storage = CB;
3722 return false;
3723 };
3724 InitRFI.foreachUse(
3725 [&](Use &U, Function &) {
3726 StoreCallBase(U, InitRFI, KernelInitCB);
3727 return false;
3728 },
3729 Fn);
3730 DeinitRFI.foreachUse(
3731 [&](Use &U, Function &) {
3732 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3733 return false;
3734 },
3735 Fn);
3736
3737 // Ignore kernels without initializers such as global constructors.
3738 if (!KernelInitCB || !KernelDeinitCB)
3739 return;
3740
3741 // Add itself to the reaching kernel and set IsKernelEntry.
3742 ReachingKernelEntries.insert(Fn);
3743 IsKernelEntry = true;
3744
3745 KernelEnvC =
3746 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3747 GlobalVariable *KernelEnvGV =
3748 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3749
3750 Attributor::GlobalVariableSimplifictionCallbackTy
3751 KernelConfigurationSimplifyCB =
3752 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3753 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3754 if (!isAtFixpoint()) {
3755 if (!AA)
3756 return nullptr;
3757 UsedAssumedInformation = true;
3758 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3759 }
3760 return KernelEnvC;
3761 };
3762
3763 A.registerGlobalVariableSimplificationCallback(
3764 *KernelEnvGV, KernelConfigurationSimplifyCB);
3765
3766 // Check if we know we are in SPMD-mode already.
3767 ConstantInt *ExecModeC =
3768 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3769 ConstantInt *AssumedExecModeC = ConstantInt::get(
3770 ExecModeC->getIntegerType(),
3771 ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3772 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3773 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3774 else if (DisableOpenMPOptSPMDization)
3775 // This is a generic region but SPMDization is disabled so stop
3776 // tracking.
3777 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3778 else
3779 setExecModeOfKernelEnvironment(AssumedExecModeC);
3780
3781 const Triple T(Fn->getParent()->getTargetTriple());
3782 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3783 auto [MinThreads, MaxThreads] =
3784 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3785 if (MinThreads)
3786 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3787 if (MaxThreads)
3788 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3789 auto [MinTeams, MaxTeams] =
3790 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3791 if (MinTeams)
3792 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3793 if (MaxTeams)
3794 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3795
3796 ConstantInt *MayUseNestedParallelismC =
3797 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3798 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3799 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3800 setMayUseNestedParallelismOfKernelEnvironment(
3801 AssumedMayUseNestedParallelismC);
3802
3803 if (!DisableOpenMPOptStateMachineRewrite) {
3804 ConstantInt *UseGenericStateMachineC =
3805 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3806 KernelEnvC);
3807 ConstantInt *AssumedUseGenericStateMachineC =
3808 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3809 setUseGenericStateMachineOfKernelEnvironment(
3810 AssumedUseGenericStateMachineC);
3811 }
3812
3813 // Register virtual uses of functions we might need to preserve.
3814 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3815 Attributor::VirtualUseCallbackTy &CB) {
3816 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3817 return;
3818 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3819 };
3820
3821 // Add a dependence to ensure updates if the state changes.
3822 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3823 const AbstractAttribute *QueryingAA) {
3824 if (QueryingAA) {
3825 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3826 }
3827 return true;
3828 };
3829
3830 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3831 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3832 // Whenever we create a custom state machine we will insert calls to
3833 // __kmpc_get_hardware_num_threads_in_block,
3834 // __kmpc_get_warp_size,
3835 // __kmpc_barrier_simple_generic,
3836 // __kmpc_kernel_parallel, and
3837 // __kmpc_kernel_end_parallel.
3838 // Not needed if we are on track for SPMDzation.
3839 if (SPMDCompatibilityTracker.isValidState())
3840 return AddDependence(A, this, QueryingAA);
3841 // Not needed if we can't rewrite due to an invalid state.
3842 if (!ReachedKnownParallelRegions.isValidState())
3843 return AddDependence(A, this, QueryingAA);
3844 return false;
3845 };
3846
3847 // Not needed if we are pre-runtime merge.
3848 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3849 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3850 CustomStateMachineUseCB);
3851 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3852 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3853 CustomStateMachineUseCB);
3854 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3855 CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3857 CustomStateMachineUseCB);
3858 }
3859
3860 // If we do not perform SPMDzation we do not need the virtual uses below.
3861 if (SPMDCompatibilityTracker.isAtFixpoint())
3862 return;
3863
3864 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3865 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3866 // Whenever we perform SPMDzation we will insert
3867 // __kmpc_get_hardware_thread_id_in_block calls.
3868 if (!SPMDCompatibilityTracker.isValidState())
3869 return AddDependence(A, this, QueryingAA);
3870 return false;
3871 };
3872 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3873 HWThreadIdUseCB);
3874
3875 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3876 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3877 // Whenever we perform SPMDzation with guarding we will insert
3878 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3879 // nothing to guard, or there are no parallel regions, we don't need
3880 // the calls.
3881 if (!SPMDCompatibilityTracker.isValidState())
3882 return AddDependence(A, this, QueryingAA);
3883 if (SPMDCompatibilityTracker.empty())
3884 return AddDependence(A, this, QueryingAA);
3885 if (!mayContainParallelRegion())
3886 return AddDependence(A, this, QueryingAA);
3887 return false;
3888 };
3889 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3890 }
3891
3892 /// Sanitize the string \p S such that it is a suitable global symbol name.
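/// For example, "kernel<int>::x" becomes "kernel.int...x".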
3893 static std::string sanitizeForGlobalName(std::string S) {
3894 std::replace_if(
3895 S.begin(), S.end(),
3896 [](const char C) {
3897 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3898 (C >= '0' && C <= '9') || C == '_');
3899 },
3900 '.');
3901 return S;
3902 }
3903
3904 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3905 /// finished now.
3906 ChangeStatus manifest(Attributor &A) override {
3907 // If we are not looking at a kernel with __kmpc_target_init and
3908 // __kmpc_target_deinit call we cannot actually manifest the information.
3909 if (!KernelInitCB || !KernelDeinitCB)
3910 return ChangeStatus::UNCHANGED;
3911
3912 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3913
3914 bool HasBuiltStateMachine = true;
3915 if (!changeToSPMDMode(A, Changed)) {
3916 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3917 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3918 else
3919 HasBuiltStateMachine = false;
3920 }
3921
3922 // We need to reset KernelEnvC if specific rewriting is not done.
3923 ConstantStruct *ExistingKernelEnvC =
3924 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3925 ConstantInt *OldUseGenericStateMachineVal =
3926 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3927 ExistingKernelEnvC);
3928 if (!HasBuiltStateMachine)
3929 setUseGenericStateMachineOfKernelEnvironment(
3930 OldUseGenericStateMachineVal);
3931
3932 // At last, update the KernelEnvC.
3933 GlobalVariable *KernelEnvGV =
3934 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3935 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3936 KernelEnvGV->setInitializer(KernelEnvC);
3937 Changed = ChangeStatus::CHANGED;
3938 }
3939
3940 return Changed;
3941 }
3942
3943 void insertInstructionGuardsHelper(Attributor &A) {
3944 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3945
3946 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3947 Instruction *RegionEndI) {
3948 LoopInfo *LI = nullptr;
3949 DominatorTree *DT = nullptr;
3950 MemorySSAUpdater *MSU = nullptr;
3951 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3952
3953 BasicBlock *ParentBB = RegionStartI->getParent();
3954 Function *Fn = ParentBB->getParent();
3955 Module &M = *Fn->getParent();
3956
3957 // Create all the blocks and logic.
3958 // ParentBB:
3959 // goto RegionCheckTidBB
3960 // RegionCheckTidBB:
3961 // Tid = __kmpc_hardware_thread_id()
3962 // if (Tid != 0)
3963 // goto RegionBarrierBB
3964 // RegionStartBB:
3965 // <execute instructions guarded>
3966 // goto RegionEndBB
3967 // RegionEndBB:
3968 // <store escaping values to shared mem>
3969 // goto RegionBarrierBB
3970 // RegionBarrierBB:
3971 // __kmpc_simple_barrier_spmd()
3972 // // second barrier is omitted if lacking escaping values.
3973 // <load escaping values from shared mem>
3974 // __kmpc_simple_barrier_spmd()
3975 // goto RegionExitBB
3976 // RegionExitBB:
3977 // <execute rest of instructions>
3978
3979 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3980 DT, LI, MSU, "region.guarded.end");
3981 BasicBlock *RegionBarrierBB =
3982 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3983 MSU, "region.barrier");
3984 BasicBlock *RegionExitBB =
3985 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3986 DT, LI, MSU, "region.exit");
3987 BasicBlock *RegionStartBB =
3988 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3989
3990 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3991 "Expected a different CFG");
3992
3993 BasicBlock *RegionCheckTidBB = SplitBlock(
3994 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3995
3996 // Register basic blocks with the Attributor.
3997 A.registerManifestAddedBasicBlock(*RegionEndBB);
3998 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3999 A.registerManifestAddedBasicBlock(*RegionExitBB);
4000 A.registerManifestAddedBasicBlock(*RegionStartBB);
4001 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4002
4003 bool HasBroadcastValues = false;
4004 // Find escaping outputs from the guarded region to outside users and
4005 // broadcast their values to them.
4006 for (Instruction &I : *RegionStartBB) {
4007 SmallVector<Use *, 4> OutsideUses;
4008 for (Use &U : I.uses()) {
4009 Instruction &UsrI = *cast<Instruction>(U.getUser());
4010 if (UsrI.getParent() != RegionStartBB)
4011 OutsideUses.push_back(&U);
4012 }
4013
4014 if (OutsideUses.empty())
4015 continue;
4016
4017 HasBroadcastValues = true;
4018
4019 // Emit a global variable in shared memory to store the broadcasted
4020 // value.
4021 auto *SharedMem = new GlobalVariable(
4022 M, I.getType(), /* IsConstant */ false,
4023 GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4024 sanitizeForGlobalName(
4025 (I.getName() + ".guarded.output.alloc").str()),
4026 nullptr, GlobalValue::NotThreadLocal,
4027 static_cast<unsigned>(AddressSpace::Shared));
4028
4029 // Emit a store instruction to update the value.
4030 new StoreInst(&I, SharedMem,
4031 RegionEndBB->getTerminator()->getIterator());
4032
4033 LoadInst *LoadI = new LoadInst(
4034 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4035 RegionBarrierBB->getTerminator()->getIterator());
4036
4037 // Emit a load instruction and replace uses of the output value.
4038 for (Use *U : OutsideUses)
4039 A.changeUseAfterManifest(*U, *LoadI);
4040 }
4041
4042 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4043
4044 // Go to tid check BB in ParentBB.
4045 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4046 ParentBB->getTerminator()->eraseFromParent();
4047 OpenMPIRBuilder::LocationDescription Loc(
4048 InsertPointTy(ParentBB, ParentBB->end()), DL);
4049 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4050 uint32_t SrcLocStrSize;
4051 auto *SrcLocStr =
4052 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4053 Value *Ident =
4054 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4055 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4056
4057 // Add check for Tid in RegionCheckTidBB
4058 RegionCheckTidBB->getTerminator()->eraseFromParent();
4059 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4060 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4061 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4062 FunctionCallee HardwareTidFn =
4063 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4064 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4065 CallInst *Tid =
4066 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4067 Tid->setDebugLoc(DL);
4068 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4069 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4070 OMPInfoCache.OMPBuilder.Builder
4071 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4072 ->setDebugLoc(DL);
4073
4074 // First barrier for synchronization, ensures main thread has updated
4075 // values.
4076 FunctionCallee BarrierFn =
4077 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4078 M, OMPRTL___kmpc_barrier_simple_spmd);
4079 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4080 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4081 CallInst *Barrier =
4082 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4083 Barrier->setDebugLoc(DL);
4084 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4085
4086 // Second barrier ensures workers have read broadcast values.
4087 if (HasBroadcastValues) {
4088 CallInst *Barrier =
4089 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4090 RegionBarrierBB->getTerminator()->getIterator());
4091 Barrier->setDebugLoc(DL);
4092 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4093 }
4094 };
4095
4096 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4097 SmallPtrSet<BasicBlock *, 8> Visited;
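// Roughly: sink guarded instructions that have no users down to the next
// guarded instruction in their block (recorded in Reorders below) so that
// guarded code forms contiguous regions; effectful instructions that are
// not guarded, or that have users, act as barriers for this reordering.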
4098 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4099 BasicBlock *BB = GuardedI->getParent();
4100 if (!Visited.insert(BB).second)
4101 continue;
4102
4103 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4104 Instruction *LastEffect = nullptr;
4105 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4106 while (++IP != IPEnd) {
4107 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4108 continue;
4109 Instruction *I = &*IP;
4110 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4111 continue;
4112 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4113 LastEffect = nullptr;
4114 continue;
4115 }
4116 if (LastEffect)
4117 Reorders.push_back({I, LastEffect});
4118 LastEffect = &*IP;
4119 }
4120 for (auto &Reorder : Reorders)
4121 Reorder.first->moveBefore(Reorder.second);
4122 }
4123
4124 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4125
4126 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4127 BasicBlock *BB = GuardedI->getParent();
4128 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4129 IRPosition::function(*GuardedI->getFunction()), nullptr,
4130 DepClassTy::NONE);
4131 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4132 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4133 // Continue if instruction is already guarded.
4134 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4135 continue;
4136
4137 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4138 for (Instruction &I : *BB) {
4139 // If instruction I needs to be guarded update the guarded region
4140 // bounds.
4141 if (SPMDCompatibilityTracker.contains(&I)) {
4142 CalleeAAFunction.getGuardedInstructions().insert(&I);
4143 if (GuardedRegionStart)
4144 GuardedRegionEnd = &I;
4145 else
4146 GuardedRegionStart = GuardedRegionEnd = &I;
4147
4148 continue;
4149 }
4150
4151 // Instruction I does not need guarding, store
4152 // any region found and reset bounds.
4153 if (GuardedRegionStart) {
4154 GuardedRegions.push_back(
4155 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4156 GuardedRegionStart = nullptr;
4157 GuardedRegionEnd = nullptr;
4158 }
4159 }
4160 }
4161
4162 for (auto &GR : GuardedRegions)
4163 CreateGuardedRegion(GR.first, GR.second);
4164 }
4165
4166 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4167 // Only allow 1 thread per workgroup to continue executing the user code.
4168 //
4169 // InitCB = __kmpc_target_init(...)
4170 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4171 // if (ThreadIdInBlock != 0) return;
4172 // UserCode:
4173 // // user code
4174 //
4175 auto &Ctx = getAnchorValue().getContext();
4176 Function *Kernel = getAssociatedFunction();
4177 assert(Kernel && "Expected an associated function!");
4178
4179 // Create block for user code to branch to from initial block.
4180 BasicBlock *InitBB = KernelInitCB->getParent();
4181 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4182 KernelInitCB->getNextNode(), "main.thread.user_code");
4183 BasicBlock *ReturnBB =
4184 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4185
4186 // Register blocks with attributor:
4187 A.registerManifestAddedBasicBlock(*InitBB);
4188 A.registerManifestAddedBasicBlock(*UserCodeBB);
4189 A.registerManifestAddedBasicBlock(*ReturnBB);
4190
4191 // Debug location:
4192 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4193 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4194 InitBB->getTerminator()->eraseFromParent();
4195
4196 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4197 Module &M = *Kernel->getParent();
4198 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4199 FunctionCallee ThreadIdInBlockFn =
4200 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4201 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4202
4203 // Get thread ID in block.
4204 CallInst *ThreadIdInBlock =
4205 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4206 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4207 ThreadIdInBlock->setDebugLoc(DLoc);
4208
4209 // Eliminate all threads in the block with ID not equal to 0:
4210 Instruction *IsMainThread =
4211 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4212 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4213 "thread.is_main", InitBB);
4214 IsMainThread->setDebugLoc(DLoc);
4215 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4216 }
4217
4218 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4219 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4220
4221 // We cannot change to SPMD mode if the runtime functions aren't available.
4222 if (!OMPInfoCache.runtimeFnsAvailable(
4223 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4224 OMPRTL___kmpc_barrier_simple_spmd}))
4225 return false;
4226
4227 if (!SPMDCompatibilityTracker.isAssumed()) {
4228 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4229 if (!NonCompatibleI)
4230 continue;
4231
4232 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4233 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4234 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4235 continue;
4236
4237 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4238 ORA << "Value has potential side effects preventing SPMD-mode "
4239 "execution";
4240 if (isa<CallBase>(NonCompatibleI)) {
4241 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4242 "the called function to override";
4243 }
4244 return ORA << ".";
4245 };
4246 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4247 Remark);
4248
4249 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4250 << *NonCompatibleI << "\n");
4251 }
4252
4253 return false;
4254 }
4255
4256 // Get the actual kernel, could be the caller of the anchor scope if we have
4257 // a debug wrapper.
4258 Function *Kernel = getAnchorScope();
4259 if (Kernel->hasLocalLinkage()) {
4260 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4261 auto *CB = cast<CallBase>(Kernel->user_back());
4262 Kernel = CB->getCaller();
4263 }
4264 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4265
4266 // Check if the kernel is already in SPMD mode, if so, return success.
4267 ConstantStruct *ExistingKernelEnvC =
4268 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4269 auto *ExecModeC =
4270 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4271 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4272 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4273 return true;
4274
4275 // We will now unconditionally modify the IR, indicate a change.
4276 Changed = ChangeStatus::CHANGED;
4277
4278 // Do not use instruction guards when no parallel region is present
4279 // inside the target region.
4280 if (mayContainParallelRegion())
4281 insertInstructionGuardsHelper(A);
4282 else
4283 forceSingleThreadPerWorkgroupHelper(A);
4284
4285 // Adjust the global exec mode flag that tells the runtime what mode this
4286 // kernel is executed in.
4287 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4288 "Initially non-SPMD kernel has SPMD exec mode!");
4289 setExecModeOfKernelEnvironment(
4290 ConstantInt::get(ExecModeC->getIntegerType(),
4291 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
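// The kernel's exec mode now reads as OMP_TGT_EXEC_MODE_GENERIC_SPMD, i.e.,
// a generic-mode kernel that is executed in SPMD fashion.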
4292
4293 ++NumOpenMPTargetRegionKernelsSPMD;
4294
4295 auto Remark = [&](OptimizationRemark OR) {
4296 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4297 };
4298 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4299 return true;
4300 };
4301
4302 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4303 // If we have disabled state machine rewrites, don't make a custom one.
4304 if (DisableOpenMPOptStateMachineRewrite)
4305 return false;
4306
4307 // Don't rewrite the state machine if we are not in a valid state.
4308 if (!ReachedKnownParallelRegions.isValidState())
4309 return false;
4310
4311 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4312 if (!OMPInfoCache.runtimeFnsAvailable(
4313 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4314 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4315 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4316 return false;
4317
4318 ConstantStruct *ExistingKernelEnvC =
4319 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4320
4321 // Check if the current configuration is non-SPMD and uses the generic state machine.
4322 // If we already have SPMD mode or a custom state machine we do not need to
4323 // go any further. If it is anything but a constant something is weird and
4324 // we give up.
4325 ConstantInt *UseStateMachineC =
4326 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4327 ExistingKernelEnvC);
4328 ConstantInt *ModeC =
4329 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4330
4331 // If we are stuck with generic mode, try to create a custom device (=GPU)
4332 // state machine which is specialized for the parallel regions that are
4333 // reachable by the kernel.
4334 if (UseStateMachineC->isZero() ||
4335 ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
4336 return false;
4337
4338 Changed = ChangeStatus::CHANGED;
4339
4340 // If not SPMD mode, indicate we use a custom state machine now.
4341 setUseGenericStateMachineOfKernelEnvironment(
4342 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4343
4344 // If we don't actually need a state machine we are done here. This can
4345 // happen if there simply are no parallel regions. In the resulting kernel
4346 // all worker threads will simply exit right away, leaving the main thread
4347 // to do the work alone.
4348 if (!mayContainParallelRegion()) {
4349 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4350
4351 auto Remark = [&](OptimizationRemark OR) {
4352 return OR << "Removing unused state machine from generic-mode kernel.";
4353 };
4354 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4355
4356 return true;
4357 }
4358
4359 // Keep track in the statistics of our new shiny custom state machine.
4360 if (ReachedUnknownParallelRegions.empty()) {
4361 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4362
4363 auto Remark = [&](OptimizationRemark OR) {
4364 return OR << "Rewriting generic-mode kernel with a customized state "
4365 "machine.";
4366 };
4367 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4368 } else {
4369 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4370
4371 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4372 return OR << "Generic-mode kernel is executed with a customized state "
4373 "machine that requires a fallback.";
4374 };
4375 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4376
4377 // Tell the user why we ended up with a fallback.
4378 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4379 if (!UnknownParallelRegionCB)
4380 continue;
4381 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4382 return ORA << "Call may contain unknown parallel regions. Use "
4383 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4384 "override.";
4385 };
4386 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4387 "OMP133", Remark);
4388 }
4389 }
4390
4391 // Create all the blocks:
4392 //
4393 // InitCB = __kmpc_target_init(...)
4394 // BlockHwSize =
4395 // __kmpc_get_hardware_num_threads_in_block();
4396 // WarpSize = __kmpc_get_warp_size();
4397 // BlockSize = BlockHwSize - WarpSize;
4398 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4399 // if (IsWorker) {
4400 // if (InitCB >= BlockSize) return;
4401 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4402 // void *WorkFn;
4403 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4404 // if (!WorkFn) return;
4405 // SMIsActiveCheckBB: if (Active) {
4406 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4407 // ParFn0(...);
4408 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4409 // ParFn1(...);
4410 // ...
4411 // SMIfCascadeCurrentBB: else
4412 // ((WorkFnTy*)WorkFn)(...);
4413 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4414 // }
4415 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4416 // goto SMBeginBB;
4417 // }
4418 // UserCodeEntryBB: // user code
4419 // __kmpc_target_deinit(...)
4420 //
4421 auto &Ctx = getAnchorValue().getContext();
4422 Function *Kernel = getAssociatedFunction();
4423 assert(Kernel && "Expected an associated function!");
4424
4425 BasicBlock *InitBB = KernelInitCB->getParent();
4426 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4427 KernelInitCB->getNextNode(), "thread.user_code.check");
4428 BasicBlock *IsWorkerCheckBB =
4429 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineIfCascadeCurrentBB =
4437 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4438 Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineEndParallelBB =
4440 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4441 Kernel, UserCodeEntryBB);
4442 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4443 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*InitBB);
4445 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4453
4454 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4455 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4456 InitBB->getTerminator()->eraseFromParent();
4457
4458 Instruction *IsWorker =
4459 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4460 ConstantInt::get(KernelInitCB->getType(), -1),
4461 "thread.is_worker", InitBB);
4462 IsWorker->setDebugLoc(DLoc);
4463 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4464
4465 Module &M = *Kernel->getParent();
4466 FunctionCallee BlockHwSizeFn =
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4469 FunctionCallee WarpSizeFn =
4470 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4471 M, OMPRTL___kmpc_get_warp_size);
4472 CallInst *BlockHwSize =
4473 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4475 BlockHwSize->setDebugLoc(DLoc);
4476 CallInst *WarpSize =
4477 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4479 WarpSize->setDebugLoc(DLoc);
4480 Instruction *BlockSize = BinaryOperator::CreateSub(
4481 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4482 BlockSize->setDebugLoc(DLoc);
4483 Instruction *IsMainOrWorker = ICmpInst::Create(
4484 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4485 "thread.is_main_or_worker", IsWorkerCheckBB);
4486 IsMainOrWorker->setDebugLoc(DLoc);
4487 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4488 IsMainOrWorker, IsWorkerCheckBB);
4489
4490 // Create local storage for the work function pointer.
4491 const DataLayout &DL = M.getDataLayout();
4492 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4493 Instruction *WorkFnAI =
4494 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4496 WorkFnAI->setDebugLoc(DLoc);
4497
4498 OMPInfoCache.OMPBuilder.updateToLocation(
4499 OpenMPIRBuilder::LocationDescription(
4500 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4501 StateMachineBeginBB->end()),
4502 DLoc));
4503
4504 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4505 Value *GTid = KernelInitCB;
4506
4507 FunctionCallee BarrierFn =
4508 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4509 M, OMPRTL___kmpc_barrier_simple_generic);
4510 CallInst *Barrier =
4511 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4512 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4513 Barrier->setDebugLoc(DLoc);
4514
4515 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4516 (unsigned int)AddressSpace::Generic) {
4517 WorkFnAI = new AddrSpaceCastInst(
4518 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4519 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4520 WorkFnAI->setDebugLoc(DLoc);
4521 }
4522
4523 FunctionCallee KernelParallelFn =
4524 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4525 M, OMPRTL___kmpc_kernel_parallel);
4526 CallInst *IsActiveWorker = CallInst::Create(
4527 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4528 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4529 IsActiveWorker->setDebugLoc(DLoc);
4530 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4531 StateMachineBeginBB);
4532 WorkFn->setDebugLoc(DLoc);
4533
4534 FunctionType *ParallelRegionFnTy = FunctionType::get(
4535 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4536 false);
4537
4538 Instruction *IsDone =
4539 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4540 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4541 StateMachineBeginBB);
4542 IsDone->setDebugLoc(DLoc);
4543 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4544 IsDone, StateMachineBeginBB)
4545 ->setDebugLoc(DLoc);
4546
4547 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4548 StateMachineDoneBarrierBB, IsActiveWorker,
4549 StateMachineIsActiveCheckBB)
4550 ->setDebugLoc(DLoc);
4551
4552 Value *ZeroArg =
4553 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4554
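// Argument index of the outlined wrapper function in a __kmpc_parallel_51
// call; used below to recover the parallel region for each known call site.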
4555 const unsigned int WrapperFunctionArgNo = 6;
4556
4557 // Now that we have most of the CFG skeleton it is time for the if-cascade
4558 // that checks the function pointer we got from the runtime against the
4559 // parallel regions we expect, if there are any.
4560 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4561 auto *CB = ReachedKnownParallelRegions[I];
4562 auto *ParallelRegion = dyn_cast<Function>(
4563 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4564 BasicBlock *PRExecuteBB = BasicBlock::Create(
4565 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4566 StateMachineEndParallelBB);
4567 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4570 ->setDebugLoc(DLoc);
4571
4572 BasicBlock *PRNextBB =
4573 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4574 Kernel, StateMachineEndParallelBB);
4575 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4576 A.registerManifestAddedBasicBlock(*PRNextBB);
4577
4578 // Check if we need to compare the pointer at all or if we can just
4579 // call the parallel region function.
4580 Value *IsPR;
4581 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4582 Instruction *CmpI = ICmpInst::Create(
4583 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4584 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4585 CmpI->setDebugLoc(DLoc);
4586 IsPR = CmpI;
4587 } else {
4588 IsPR = ConstantInt::getTrue(Ctx);
4589 }
4590
4591 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4592 StateMachineIfCascadeCurrentBB)
4593 ->setDebugLoc(DLoc);
4594 StateMachineIfCascadeCurrentBB = PRNextBB;
4595 }
4596
4597 // At the end of the if-cascade we place the indirect function pointer call
4598 // in case we might need it, that is if there can be parallel regions we
4599 // have not handled in the if-cascade above.
4600 if (!ReachedUnknownParallelRegions.empty()) {
4601 StateMachineIfCascadeCurrentBB->setName(
4602 "worker_state_machine.parallel_region.fallback.execute");
4603 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4604 StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606 }
4607 BranchInst::Create(StateMachineEndParallelBB,
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610
4611 FunctionCallee EndParallelFn =
4612 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4613 M, OMPRTL___kmpc_kernel_end_parallel);
4614 CallInst *EndParallel =
4615 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4616 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4617 EndParallel->setDebugLoc(DLoc);
4618 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4619 ->setDebugLoc(DLoc);
4620
4621 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4624 ->setDebugLoc(DLoc);
4625
4626 return true;
4627 }
4628
4629 /// Fixpoint iteration update function. Will be called every time a dependence
4630 /// changed its state (and in the beginning).
4631 ChangeStatus updateImpl(Attributor &A) override {
4632 KernelInfoState StateBefore = getState();
4633
4634 // When we leave this function this RAII will make sure the member
4635 // KernelEnvC is updated properly depending on the state. That member is
4636 // used for simplification of values and needs to be up to date at all
4637 // times.
4638 struct UpdateKernelEnvCRAII {
4639 AAKernelInfoFunction &AA;
4640
4641 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4642
4643 ~UpdateKernelEnvCRAII() {
4644 if (!AA.KernelEnvC)
4645 return;
4646
4647 ConstantStruct *ExistingKernelEnvC =
4648 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4649
4650 if (!AA.isValidState()) {
4651 AA.KernelEnvC = ExistingKernelEnvC;
4652 return;
4653 }
4654
4655 if (!AA.ReachedKnownParallelRegions.isValidState())
4656 AA.setUseGenericStateMachineOfKernelEnvironment(
4657 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4658 ExistingKernelEnvC));
4659
4660 if (!AA.SPMDCompatibilityTracker.isValidState())
4661 AA.setExecModeOfKernelEnvironment(
4662 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4663
4664 ConstantInt *MayUseNestedParallelismC =
4665 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4666 AA.KernelEnvC);
4667 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4668 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4669 AA.setMayUseNestedParallelismOfKernelEnvironment(
4670 NewMayUseNestedParallelismC);
4671 }
4672 } RAII(*this);
4673
4674 // Callback to check a read/write instruction.
4675 auto CheckRWInst = [&](Instruction &I) {
4676 // We handle calls later.
4677 if (isa<CallBase>(I))
4678 return true;
4679 // We only care about write effects.
4680 if (!I.mayWriteToMemory())
4681 return true;
4682 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4683 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4684 *this, IRPosition::value(*SI->getPointerOperand()),
4685 DepClassTy::OPTIONAL);
4686 auto *HS = A.getAAFor<AAHeapToStack>(
4687 *this, IRPosition::function(*I.getFunction()),
4688 DepClassTy::OPTIONAL);
4689 if (UnderlyingObjsAA &&
4690 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4691 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4692 return true;
4693 // Check for AAHeapToStack moved objects which must not be
4694 // guarded.
4695 auto *CB = dyn_cast<CallBase>(&Obj);
4696 return CB && HS && HS->isAssumedHeapToStack(*CB);
4697 }))
4698 return true;
4699 }
4700
4701 // Insert instruction that needs guarding.
4702 SPMDCompatibilityTracker.insert(&I);
4703 return true;
4704 };
4705
4706 bool UsedAssumedInformationInCheckRWInst = false;
4707 if (!SPMDCompatibilityTracker.isAtFixpoint())
4708 if (!A.checkForAllReadWriteInstructions(
4709 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4710 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4711
4712 bool UsedAssumedInformationFromReachingKernels = false;
4713 if (!IsKernelEntry) {
4714 updateParallelLevels(A);
4715
4716 bool AllReachingKernelsKnown = true;
4717 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4718 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4719
4720 if (!SPMDCompatibilityTracker.empty()) {
4721 if (!ParallelLevels.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else if (!ReachingKernelEntries.isValidState())
4724 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4725 else {
4726 // Check if all reaching kernels agree on the mode as we can otherwise
4727 // not guard instructions. We might not be sure about the mode so
4728 // we cannot fix the internal SPMD-ization state either.
4729 int SPMD = 0, Generic = 0;
4730 for (auto *Kernel : ReachingKernelEntries) {
4731 auto *CBAA = A.getAAFor<AAKernelInfo>(
4732 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4733 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4734 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 ++SPMD;
4736 else
4737 ++Generic;
4738 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4739 UsedAssumedInformationFromReachingKernels = true;
4740 }
4741 if (SPMD != 0 && Generic != 0)
4742 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4743 }
4744 }
4745 }
4746
4747 // Callback to check a call instruction.
4748 bool AllParallelRegionStatesWereFixed = true;
4749 bool AllSPMDStatesWereFixed = true;
4750 auto CheckCallInst = [&](Instruction &I) {
4751 auto &CB = cast<CallBase>(I);
4752 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4754 if (!CBAA)
4755 return false;
4756 getState() ^= CBAA->getState();
4757 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4760 AllParallelRegionStatesWereFixed &=
4761 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 return true;
4763 };
4764
4765 bool UsedAssumedInformationInCheckCallInst = false;
4766 if (!A.checkForAllCallLikeInstructions(
4767 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4768 LLVM_DEBUG(dbgs() << TAG
4769 << "Failed to visit all call-like instructions!\n";);
4770 return indicatePessimisticFixpoint();
4771 }
4772
4773 // If we haven't used any assumed information for the reached parallel
4774 // region states we can fix it.
4775 if (!UsedAssumedInformationInCheckCallInst &&
4776 AllParallelRegionStatesWereFixed) {
4777 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4778 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the SPMD state we can fix
4782 // it.
4783 if (!UsedAssumedInformationInCheckRWInst &&
4784 !UsedAssumedInformationInCheckCallInst &&
4785 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4786 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4787
4788 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4789 : ChangeStatus::CHANGED;
4790 }
4791
4792private:
4793 /// Update info regarding reaching kernels.
4794 void updateReachingKernelEntries(Attributor &A,
4795 bool &AllReachingKernelsKnown) {
4796 auto PredCallSite = [&](AbstractCallSite ACS) {
4797 Function *Caller = ACS.getInstruction()->getFunction();
4798
4799 assert(Caller && "Caller is nullptr");
4800
4801 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4802 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4803 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4804 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4805 return true;
4806 }
4807
4808 // We lost track of the caller of the associated function; any kernel
4809 // could reach it now.
4810 ReachingKernelEntries.indicatePessimisticFixpoint();
4811
4812 return true;
4813 };
4814
4815 if (!A.checkForAllCallSites(PredCallSite, *this,
4816 true /* RequireAllCallSites */,
4817 AllReachingKernelsKnown))
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 }
4820
4821 /// Update info regarding parallel levels.
4822 void updateParallelLevels(Attributor &A) {
4823 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4824 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4825 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4826
4827 auto PredCallSite = [&](AbstractCallSite ACS) {
4828 Function *Caller = ACS.getInstruction()->getFunction();
4829
4830 assert(Caller && "Caller is nullptr");
4831
4832 auto *CAA =
4833 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4834 if (CAA && CAA->ParallelLevels.isValidState()) {
4835        // Any function that is called by `__kmpc_parallel_51` will not be
4836        // folded because the parallel level inside it is updated at runtime.
4837        // Getting this right would make the analysis depend on the runtime
4838        // implementation, and any future change there could silently
4839        // invalidate the analysis. As a consequence, we are simply conservative here.
4840 if (Caller == Parallel51RFI.Declaration) {
4841 ParallelLevels.indicatePessimisticFixpoint();
4842 return true;
4843 }
4844
4845 ParallelLevels ^= CAA->ParallelLevels;
4846
4847 return true;
4848 }
4849
4850      // We lost track of the caller of the associated function; any kernel
4851      // could reach it now.
4852 ParallelLevels.indicatePessimisticFixpoint();
4853
4854 return true;
4855 };
4856
4857 bool AllCallSitesKnown = true;
4858 if (!A.checkForAllCallSites(PredCallSite, *this,
4859 true /* RequireAllCallSites */,
4860 AllCallSitesKnown))
4861 ParallelLevels.indicatePessimisticFixpoint();
4862 }
4863};
4864
4865 /// The call site kernel info abstract attribute; basically, what can we say
4866 /// about a call site with regard to the KernelInfoState. For now this simply
4867 /// forwards the information from the callee.
4868struct AAKernelInfoCallSite : AAKernelInfo {
4869 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4870 : AAKernelInfo(IRP, A) {}
4871
4872 /// See AbstractAttribute::initialize(...).
4873 void initialize(Attributor &A) override {
4874 AAKernelInfo::initialize(A);
4875
4876 CallBase &CB = cast<CallBase>(getAssociatedValue());
4877 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4878 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4879
4880 // Check for SPMD-mode assumptions.
4881 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4882 indicateOptimisticFixpoint();
4883 return;
4884 }
4885
4886    // First weed out calls we do not care about, that is, readonly/readnone
4887    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4888    // parallel region or anything else we are looking for.
4889 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
4893
4894    // Next we check if we know the callee. If it is a known OpenMP function
4895    // we will handle it explicitly in the switch below. If it is not, we
4896 // will use an AAKernelInfo object on the callee to gather information and
4897 // merge that into the current state. The latter happens in the updateImpl.
4898 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4899 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4900 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4901 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4902        // Unknown callees and mere declarations are not analyzable; we give up.
4903 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4904
4905 // Unknown callees might contain parallel regions, except if they have
4906 // an appropriate assumption attached.
4907 if (!AssumptionAA ||
4908 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4909 AssumptionAA->hasAssumption("omp_no_parallelism")))
4910 ReachedUnknownParallelRegions.insert(&CB);
4911
4912 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4913 // idea we can run something unknown in SPMD-mode.
4914 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4915 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4916 SPMDCompatibilityTracker.insert(&CB);
4917 }
4918
4919          // We have updated the state for this unknown call properly; there
4920          // won't be any change, so we indicate a fixpoint.
4921 indicateOptimisticFixpoint();
4922 }
4923 // If the callee is known and can be used in IPO, we will update the
4924 // state based on the callee state in updateImpl.
4925 return;
4926 }
4927 if (NumCallees > 1) {
4928 indicatePessimisticFixpoint();
4929 return;
4930 }
4931
4932 RuntimeFunction RF = It->getSecond();
4933 switch (RF) {
4934 // All the functions we know are compatible with SPMD mode.
4935 case OMPRTL___kmpc_is_spmd_exec_mode:
4936 case OMPRTL___kmpc_distribute_static_fini:
4937 case OMPRTL___kmpc_for_static_fini:
4938 case OMPRTL___kmpc_global_thread_num:
4939 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4940 case OMPRTL___kmpc_get_hardware_num_blocks:
4941 case OMPRTL___kmpc_single:
4942 case OMPRTL___kmpc_end_single:
4943 case OMPRTL___kmpc_master:
4944 case OMPRTL___kmpc_end_master:
4945 case OMPRTL___kmpc_barrier:
4946 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4947 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4948 case OMPRTL___kmpc_error:
4949 case OMPRTL___kmpc_flush:
4950 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4951 case OMPRTL___kmpc_get_warp_size:
4952 case OMPRTL_omp_get_thread_num:
4953 case OMPRTL_omp_get_num_threads:
4954 case OMPRTL_omp_get_max_threads:
4955 case OMPRTL_omp_in_parallel:
4956 case OMPRTL_omp_get_dynamic:
4957 case OMPRTL_omp_get_cancellation:
4958 case OMPRTL_omp_get_nested:
4959 case OMPRTL_omp_get_schedule:
4960 case OMPRTL_omp_get_thread_limit:
4961 case OMPRTL_omp_get_supported_active_levels:
4962 case OMPRTL_omp_get_max_active_levels:
4963 case OMPRTL_omp_get_level:
4964 case OMPRTL_omp_get_ancestor_thread_num:
4965 case OMPRTL_omp_get_team_size:
4966 case OMPRTL_omp_get_active_level:
4967 case OMPRTL_omp_in_final:
4968 case OMPRTL_omp_get_proc_bind:
4969 case OMPRTL_omp_get_num_places:
4970 case OMPRTL_omp_get_num_procs:
4971 case OMPRTL_omp_get_place_proc_ids:
4972 case OMPRTL_omp_get_place_num:
4973 case OMPRTL_omp_get_partition_num_places:
4974 case OMPRTL_omp_get_partition_place_nums:
4975 case OMPRTL_omp_get_wtime:
4976 break;
4977 case OMPRTL___kmpc_distribute_static_init_4:
4978 case OMPRTL___kmpc_distribute_static_init_4u:
4979 case OMPRTL___kmpc_distribute_static_init_8:
4980 case OMPRTL___kmpc_distribute_static_init_8u:
4981 case OMPRTL___kmpc_for_static_init_4:
4982 case OMPRTL___kmpc_for_static_init_4u:
4983 case OMPRTL___kmpc_for_static_init_8:
4984 case OMPRTL___kmpc_for_static_init_8u: {
4985 // Check the schedule and allow static schedule in SPMD mode.
4986 unsigned ScheduleArgOpNo = 2;
4987 auto *ScheduleTypeCI =
4988 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4989 unsigned ScheduleTypeVal =
4990 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
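        // Only the static (unordered) and distribute schedule kinds are
        // compatible with SPMD execution; any other schedule kind marks this
        // call as SPMD-incompatible below.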
4991 switch (OMPScheduleType(ScheduleTypeVal)) {
4992 case OMPScheduleType::UnorderedStatic:
4993 case OMPScheduleType::UnorderedStaticChunked:
4994 case OMPScheduleType::OrderedDistribute:
4995 case OMPScheduleType::OrderedDistributeChunked:
4996 break;
4997 default:
4998 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4999 SPMDCompatibilityTracker.insert(&CB);
5000 break;
5001        }
5002 } break;
5003 case OMPRTL___kmpc_target_init:
5004 KernelInitCB = &CB;
5005 break;
5006 case OMPRTL___kmpc_target_deinit:
5007 KernelDeinitCB = &CB;
5008 break;
5009 case OMPRTL___kmpc_parallel_51:
5010 if (!handleParallel51(A, CB))
5011 indicatePessimisticFixpoint();
5012 return;
5013 case OMPRTL___kmpc_omp_task:
5014 // We do not look into tasks right now, just give up.
5015 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5016 SPMDCompatibilityTracker.insert(&CB);
5017 ReachedUnknownParallelRegions.insert(&CB);
5018 break;
5019 case OMPRTL___kmpc_alloc_shared:
5020 case OMPRTL___kmpc_free_shared:
5021 // Return without setting a fixpoint, to be resolved in updateImpl.
5022 return;
5023 default:
5024 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5025 // generally. However, they do not hide parallel regions.
5026 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5027 SPMDCompatibilityTracker.insert(&CB);
5028 break;
5029 }
5030 // All other OpenMP runtime calls will not reach parallel regions so they
5031 // can be safely ignored for now. Since it is a known OpenMP runtime call
5032 // we have now modeled all effects and there is no need for any update.
5033 indicateOptimisticFixpoint();
5034 };
5035
5036 const auto *AACE =
5037 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5038 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5039 CheckCallee(getAssociatedFunction(), 1);
5040 return;
5041 }
5042 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5043 for (auto *Callee : OptimisticEdges) {
5044 CheckCallee(Callee, OptimisticEdges.size());
5045 if (isAtFixpoint())
5046 break;
5047 }
5048 }
5049
5050 ChangeStatus updateImpl(Attributor &A) override {
5051 // TODO: Once we have call site specific value information we can provide
5052 // call site specific liveness information and then it makes
5053    //       sense to specialize attributes for call site arguments instead of
5054 // redirecting requests to the callee argument.
5055 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5056 KernelInfoState StateBefore = getState();
5057
5058 auto CheckCallee = [&](Function *F, int NumCallees) {
5059 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5060
5061 // If F is not a runtime function, propagate the AAKernelInfo of the
5062 // callee.
5063 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5064 const IRPosition &FnPos = IRPosition::function(*F);
5065 auto *FnAA =
5066 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5067 if (!FnAA)
5068 return indicatePessimisticFixpoint();
5069 if (getState() == FnAA->getState())
5070 return ChangeStatus::UNCHANGED;
5071 getState() = FnAA->getState();
5072 return ChangeStatus::CHANGED;
5073 }
5074 if (NumCallees > 1)
5075 return indicatePessimisticFixpoint();
5076
5077 CallBase &CB = cast<CallBase>(getAssociatedValue());
5078 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5079 if (!handleParallel51(A, CB))
5080 return indicatePessimisticFixpoint();
5081 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5082 : ChangeStatus::CHANGED;
5083 }
5084
5085 // F is a runtime function that allocates or frees memory, check
5086 // AAHeapToStack and AAHeapToShared.
5087 assert(
5088 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5089 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5090 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091
5092 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5096
5097 RuntimeFunction RF = It->getSecond();
5098
5099 switch (RF) {
5100 // If neither HeapToStack nor HeapToShared assume the call is removed,
5101 // assume SPMD incompatibility.
5102 case OMPRTL___kmpc_alloc_shared:
5103 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5104 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5105 SPMDCompatibilityTracker.insert(&CB);
5106 break;
5107 case OMPRTL___kmpc_free_shared:
5108 if ((!HeapToStackAA ||
5109 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5110 (!HeapToSharedAA ||
5111 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5112 SPMDCompatibilityTracker.insert(&CB);
5113 break;
5114 default:
5115 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5116 SPMDCompatibilityTracker.insert(&CB);
5117 }
5118 return ChangeStatus::CHANGED;
5119 };
5120
5121 const auto *AACE =
5122 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5123 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5124 if (Function *F = getAssociatedFunction())
5125 CheckCallee(F, /*NumCallees=*/1);
5126 } else {
5127 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5128 for (auto *Callee : OptimisticEdges) {
5129 CheckCallee(Callee, OptimisticEdges.size());
5130 if (isAtFixpoint())
5131 break;
5132 }
5133 }
5134
5135 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5136 : ChangeStatus::CHANGED;
5137 }
5138
5139  /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call
5140  /// was handled; if a problem occurred, false is returned.
5141 bool handleParallel51(Attributor &A, CallBase &CB) {
5142 const unsigned int NonWrapperFunctionArgNo = 5;
5143 const unsigned int WrapperFunctionArgNo = 6;
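    // With SPMD mode assumed the outlined parallel region (argument 5) is
    // invoked directly; in generic mode the state machine calls the wrapper
    // function (argument 6) instead, so track whichever will actually run.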
5144 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5145 ? NonWrapperFunctionArgNo
5146 : WrapperFunctionArgNo;
5147
5148 auto *ParallelRegion = dyn_cast<Function>(
5149 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5150 if (!ParallelRegion)
5151 return false;
5152
5153 ReachedKnownParallelRegions.insert(&CB);
5154    // Check for nested parallelism in the outlined parallel region.
5155 auto *FnAA = A.getAAFor<AAKernelInfo>(
5156 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5157 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5158 !FnAA->ReachedKnownParallelRegions.empty() ||
5159 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5161 !FnAA->ReachedUnknownParallelRegions.empty();
5162 return true;
5163 }
5164};
5165
5166struct AAFoldRuntimeCall
5167 : public StateWrapper<BooleanState, AbstractAttribute> {
5169   using Base = StateWrapper<BooleanState, AbstractAttribute>;
5170 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5171
5172 /// Statistics are tracked as part of manifest for now.
5173 void trackStatistics() const override {}
5174
5175   /// Create an abstract attribute view for the position \p IRP.
5176 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 Attributor &A);
5178
5179 /// See AbstractAttribute::getName()
5180 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5181
5182 /// See AbstractAttribute::getIdAddr()
5183 const char *getIdAddr() const override { return &ID; }
5184
5185 /// This function should return true if the type of the \p AA is
5186 /// AAFoldRuntimeCall
5187 static bool classof(const AbstractAttribute *AA) {
5188 return (AA->getIdAddr() == &ID);
5189 }
5190
5191 static const char ID;
5192};
5193
5194struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5195 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5196 : AAFoldRuntimeCall(IRP, A) {}
5197
5198 /// See AbstractAttribute::getAsStr()
5199 const std::string getAsStr(Attributor *) const override {
5200 if (!isValidState())
5201 return "<invalid>";
5202
5203 std::string Str("simplified value: ");
5204
5205 if (!SimplifiedValue)
5206 return Str + std::string("none");
5207
5208 if (!*SimplifiedValue)
5209 return Str + std::string("nullptr");
5210
5211 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5212 return Str + std::to_string(CI->getSExtValue());
5213
5214 return Str + std::string("unknown");
5215 }
5216
5217   void initialize(Attributor &A) override {
5218     if (DisableOpenMPOptFolding)
5219       indicatePessimisticFixpoint();
5220
5221 Function *Callee = getAssociatedFunction();
5222
5223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5224 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5225 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5226 "Expected a known OpenMP runtime function");
5227
5228 RFKind = It->getSecond();
5229
5230 CallBase &CB = cast<CallBase>(getAssociatedValue());
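    // Register a simplification callback so that queries for the value of
    // this call site return SimplifiedValue; while we are not at a fixpoint,
    // the querying AA is told the answer is still assumed and a dependence
    // is recorded.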
5231     A.registerSimplificationCallback(
5232         IRPosition::callsite_returned(CB),
5233         [&](const IRPosition &IRP, const AbstractAttribute *AA,
5234 bool &UsedAssumedInformation) -> std::optional<Value *> {
5235 assert((isValidState() ||
5236 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5237 "Unexpected invalid state!");
5238
5239 if (!isAtFixpoint()) {
5240 UsedAssumedInformation = true;
5241 if (AA)
5242 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5243 }
5244 return SimplifiedValue;
5245 });
5246 }
5247
5248 ChangeStatus updateImpl(Attributor &A) override {
5249 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5250 switch (RFKind) {
5251 case OMPRTL___kmpc_is_spmd_exec_mode:
5252 Changed |= foldIsSPMDExecMode(A);
5253 break;
5254 case OMPRTL___kmpc_parallel_level:
5255 Changed |= foldParallelLevel(A);
5256 break;
5257 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5258       Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
5259 break;
5260 case OMPRTL___kmpc_get_hardware_num_blocks:
5261       Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
5262 break;
5263 default:
5264 llvm_unreachable("Unhandled OpenMP runtime function!");
5265 }
5266
5267 return Changed;
5268 }
5269
5270 ChangeStatus manifest(Attributor &A) override {
5271 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5272
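    // If a constant replacement was derived, substitute it for the runtime
    // call, delete the call, and optionally emit an optimization remark.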
5273 if (SimplifiedValue && *SimplifiedValue) {
5274 Instruction &I = *getCtxI();
5275 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5276 A.deleteAfterManifest(I);
5277
5278 CallBase *CB = dyn_cast<CallBase>(&I);
5279 auto Remark = [&](OptimizationRemark OR) {
5280 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5281 return OR << "Replacing OpenMP runtime call "
5282 << CB->getCalledFunction()->getName() << " with "
5283 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5284 return OR << "Replacing OpenMP runtime call "
5285 << CB->getCalledFunction()->getName() << ".";
5286 };
5287
5288 if (CB && EnableVerboseRemarks)
5289 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5290
5291 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5292 << **SimplifiedValue << "\n");
5293
5294 Changed = ChangeStatus::CHANGED;
5295 }
5296
5297 return Changed;
5298 }
5299
5300 ChangeStatus indicatePessimisticFixpoint() override {
5301 SimplifiedValue = nullptr;
5302 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5303 }
5304
5305private:
5306 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
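  /// For example, if every kernel that can reach the call is (assumed to be)
  /// executed in SPMD mode, the call is replaced by the i8 constant 1; if all
  /// of them run in generic mode it is replaced by 0.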
5307 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5308 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5309
5310 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5311 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5312 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5313 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5314
5315 if (!CallerKernelInfoAA ||
5316 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5317 return indicatePessimisticFixpoint();
5318
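    // Classify every kernel that can reach this call site as (assumed or
    // known) SPMD or non-SPMD.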
5319 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5320 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5321 DepClassTy::REQUIRED);
5322
5323 if (!AA || !AA->isValidState()) {
5324 SimplifiedValue = nullptr;
5325 return indicatePessimisticFixpoint();
5326 }
5327
5328 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5329 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5330 ++KnownSPMDCount;
5331 else
5332 ++AssumedSPMDCount;
5333 } else {
5334 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5335 ++KnownNonSPMDCount;
5336 else
5337 ++AssumedNonSPMDCount;
5338 }
5339 }
5340
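    // If the reaching kernels disagree on the execution mode, the result is
    // caller dependent and the call cannot be folded.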
5341 if ((AssumedSPMDCount + KnownSPMDCount) &&
5342 (AssumedNonSPMDCount + KnownNonSPMDCount))
5343 return indicatePessimisticFixpoint();
5344
5345 auto &Ctx = getAnchorValue().getContext();
5346 if (KnownSPMDCount || AssumedSPMDCount) {
5347 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5348 "Expected only SPMD kernels!");
5349 // All reaching kernels are in SPMD mode. Update all function calls to
5350 // __kmpc_is_spmd_exec_mode to 1.
5351 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5352 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5353 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5354 "Expected only non-SPMD kernels!");
5355 // All reaching kernels are in non-SPMD mode. Update all function
5356 // calls to __kmpc_is_spmd_exec_mode to 0.
5357 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5358 } else {
5359       // The set of reaching kernels is empty, so we cannot tell whether the
5360       // associated call site can be folded. At this point, SimplifiedValue
5361       // must still be none.
5362 assert(!SimplifiedValue && "SimplifiedValue should be none");
5363 }
5364
5365 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5366 : ChangeStatus::CHANGED;
5367 }
5368
5369 /// Fold __kmpc_parallel_level into a constant if possible.
5370 ChangeStatus foldParallelLevel(Attributor &A) {
5371 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5372
5373 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5374 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5375
5376 if (!CallerKernelInfoAA ||
5377 !CallerKernelInfoAA->ParallelLevels.isValidState())
5378 return indicatePessimisticFixpoint();
5379
5380 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5381 return indicatePessimisticFixpoint();
5382
5383 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5384 assert(!SimplifiedValue &&
5385 "SimplifiedValue should keep none at this point");
5386 return ChangeStatus::UNCHANGED;
5387 }
5388
5389 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5390 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5391 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5392 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5393 DepClassTy::REQUIRED);
5394 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5395 return indicatePessimisticFixpoint();
5396