//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//
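
// Illustrative only (hand-written IR, not taken from a test): deduplication
// of runtime calls rewrites
//
//   %a = call i32 @omp_get_thread_num()
//   ...
//   %b = call i32 @omp_get_thread_num()
//
// into a single call whose result replaces both %a and %b, since the thread
// number cannot change between the two calls within the same region.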

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

#include <algorithm>
#include <optional>
#include <string>

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));

STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumNonOpenMPTargetRegionKernels,
          "Number of non-OpenMP target region kernels identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace KernelInfo {

// struct ConfigurationEnvironmentTy {
//   uint8_t UseGenericStateMachine;
//   uint8_t MayUseNestedParallelism;
//   llvm::omp::OMPTgtExecModeFlags ExecMode;
//   int32_t MinThreads;
//   int32_t MaxThreads;
//   int32_t MinTeams;
//   int32_t MaxTeams;
// };

// struct DynamicEnvironmentTy {
//   uint16_t DebugIndentionLevel;
// };

// struct KernelEnvironmentTy {
//   ConfigurationEnvironmentTy Configuration;
//   IdentTy *Ident;
//   DynamicEnvironmentTy *DynamicEnv;
// };

#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr const unsigned MEMBER##Idx = IDX;

KERNEL_ENVIRONMENT_IDX(Configuration, 0)
KERNEL_ENVIRONMENT_IDX(Ident, 1)

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr const unsigned MEMBER##Idx = IDX;

KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }

KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)

#undef KERNEL_ENVIRONMENT_GETTER
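
// As an illustration, KERNEL_ENVIRONMENT_GETTER(Ident, Constant) above
// expands to:
//
//   Constant *getIdentFromKernelEnvironment(ConstantStruct *KernelEnvC) {
//     return cast<Constant>(KernelEnvC->getAggregateElement(IdentIdx));
//   }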

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER

GlobalVariable *
getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
  constexpr const int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(
      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
          ->stripPointerCasts());
}

ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
  GlobalVariable *KernelEnvGV =
      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
}
} // namespace KernelInfo
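
// Schematically (illustrative IR, names invented), a device kernel passes its
// constant kernel environment global as the first argument of
// __kmpc_target_init:
//
//   @k_env = weak_odr protected constant %struct.KernelEnvironmentTy { ... }
//   define weak_odr void @__omp_offloading_xy_kernel(ptr %dyn) {
//     %tid = call i32 @__kmpc_target_init(ptr @k_env, ptr %dyn)
//     ...
//
// The two helpers above strip pointer casts off that argument and read the
// global's initializer as a ConstantStruct.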

namespace {

struct AAHeapToShared;

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      bool OpenMPPostLink)
      : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
        OpenMPPostLink(OpenMPPostLink) {

    OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL function corresponding to the override clause of this ICV.
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function.
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
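      // E.g., for UV = [A, B, C, D] and ToBeDeleted = [1, 3]: D (index 3) is
      // removed first, then B (index 1) is overwritten by the then-last
      // element C, leaving UV = [A, C].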
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

  public:
    /// Iterators for the uses of this runtime function.
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from function declarations/definitions to their runtime enum type.
  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    // and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto *RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  // Helper function to inherit the calling convention of the function callee.
  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
    if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
      CI->setCallingConv(Fn->getCallingConv());
  }

  // Helper function to determine if it's legal to create a call to the runtime
  // functions.
  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
    // We can always emit calls if we haven't yet linked in the runtime.
    if (!OpenMPPostLink)
      return true;

    // Once the runtime has been linked in, we cannot emit calls to any
    // undefined functions.
    for (RuntimeFunction Fn : Fns) {
      RuntimeFunctionInfo &RFI = RFIs[Fn];

      if (RFI.Declaration && RFI.Declaration->isDeclaration())
        return false;
    }
    return true;
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present.
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().starts_with(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known OpenMP runtime functions.
  DenseSet<const Function *> RTLFunctions;

  /// Indicates if we have already linked in the OpenMP device library.
  bool OpenMPPostLink = false;
};

template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

private:
  /// A set to keep track of elements.
  SetVector<Ty> Set;

public:
  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  typename decltype(Set)::const_iterator end() const { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;

struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The constant kernel environment as taken from and passed to
  /// __kmpc_target_init.
  ConstantStruct *KernelEnvC = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more
  /// than one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track parallel level of the associated
  /// function. We will give up tracking if we encounter unknown caller or the
  /// caller is __kmpc_parallel_51.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;

  /// Abstract State interface
  ///{

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }

  ///}
};

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    // BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }
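
  // E.g., with 8-byte pointers, a store to the alloca itself fills
  // StoredValues[0] and a store at byte offset 16 fills StoredValues[2];
  // getValues only succeeds (see isFilled below) once every slot was written.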

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      // TODO: This should be folded into buildCustomStateMachine.
      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    if (OMPInfoCache.OpenMPPostLink)
      Changed |= removeRuntimeSymbols();

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : SCC) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }
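
  // The emitted analysis remark reads roughly like (exact format depends on
  // the remark printer):
  //   remark: OpenMP ICV nthreads Value: IMPLEMENTATION_DEFINED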

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!omp::isOpenMPKernel(*F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, it has to be the callee; otherwise nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given, it has to
  /// be the callee; otherwise nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", OuterFn->front().begin());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
                                         I.getName() + ".seq.output.load",
                                         UsrI->getIterator());
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
    auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
                     BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g., to
      // include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR << ".";
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are no in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI =
            CallInst::Create(FT, Callee, Args, "", CI->getIterator());
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          // 'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };
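
    // Schematically (illustrative IR, operand names invented), each merged
    //   call void @__kmpc_fork_call(ptr @ident, i32 %n, ptr @cb, ptr %arg)
    // becomes a direct call to the callback inside the outlined region,
    //   call void @cb(ptr %tid.addr, ptr %zero.addr, ptr %arg)
    // with an explicit __kmpc_barrier between neighboring regions to preserve
    // the implicit fork-join barrier semantics.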

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect use for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing parallel region with no side-effects.";
      };
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to remove known runtime symbols that are optional from the module.
  bool removeRuntimeSymbols() {
    // The RPC client symbol is defined in `libc` and indicates that something
    // required an RPC server. If its users were all optimized out then we can
    // safely remove it.
    // TODO: This should be somewhere more common in the future.
    if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
      if (!GV->getType()->isPointerTy())
        return false;

      Constant *C = GV->getInitializer();
      if (!C)
        return false;

      // Check to see if the only user of the RPC client is the external handle.
      GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
      if (!Client || Client->getNumUses() > 1 ||
          Client->user_back() != GV->getInitializer())
        return false;

      Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
      Client->eraseFromParent();

      GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
      GV->eraseFromParent();

      return true;
    }
    return false;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkMissed ORM) {
          return ORM
                 << "Found thread data sharing on the GPU. "
                 << "Expect degraded performance due to data globalization.";
        };
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
      }

      return false;
    };

    RFI.foreachUse(SCC, CheckGlobalization);
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
  }
1685
1686 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1687 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1688 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1689 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1690 // Make it traverse the CFG.
1691
1692 Instruction *CurrentI = &RuntimeCall;
1693 bool IsWorthIt = false;
1694 while ((CurrentI = CurrentI->getNextNode())) {
1695
1696 // TODO: Once we detect the regions to be offloaded we should use the
1697 // alias analysis manager to check if CurrentI may modify one of
1698 // the offloaded regions.
1699 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1700 if (IsWorthIt)
1701 return CurrentI;
1702
1703 return nullptr;
1704 }
1705
1706 // FIXME: For now, moving it over anything without side effects
1707 // is considered worth it.
1708 IsWorthIt = true;
1709 }
1710
1711 // Return end of BasicBlock.
1712 return RuntimeCall.getParent()->getTerminator();
1713 }
1714
1715 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
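/// Illustrative before/after (assumed IR shape, arguments abridged):
///   call void @__tgt_target_data_begin_mapper(%ident, %dev, ...)
/// becomes
///   %handle = alloca %struct.__tgt_async_info
///   call void @__tgt_target_data_begin_mapper_issue(%ident, %dev, ...,
///                                                   ptr %handle)
///   ; ... independent code the "wait" was moved over ...
///   call void @__tgt_target_data_begin_mapper_wait(%dev, ptr %handle)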
1716 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1717 Instruction &WaitMovementPoint) {
1718 // Create a stack-allocated handle (__tgt_async_info) at the beginning of the
1719 // function. It stores information about the async transfer, allowing us to
1720 // wait on it later.
1721 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1722 Function *F = RuntimeCall.getCaller();
1723 BasicBlock &Entry = F->getEntryBlock();
1724 IRBuilder.Builder.SetInsertPoint(&Entry,
1725 Entry.getFirstNonPHIOrDbgOrAlloca());
1726 Value *Handle = IRBuilder.Builder.CreateAlloca(
1727 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1728 Handle =
1729 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1730
1731 // Add "issue" runtime call declaration:
1732 // declare void @__tgt_target_data_begin_mapper_issue(i64, i32,
1733 // i8**, i8**, i64*, i64*, %struct.__tgt_async_info*)
1734 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1735 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1736
1737 // Change RuntimeCall call site for its asynchronous version.
1738 SmallVector<Value *, 16> Args;
1739 for (auto &Arg : RuntimeCall.args())
1740 Args.push_back(Arg.get());
1741 Args.push_back(Handle);
1742
1743 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1744 RuntimeCall.getIterator());
1745 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1746 RuntimeCall.eraseFromParent();
1747
1748 // Add "wait" runtime call declaration:
1749 // declare void @__tgt_target_data_begin_mapper_wait(i64, %struct.__tgt_async_info*)
1750 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1751 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1752
1753 Value *WaitParams[2] = {
1754 IssueCallsite->getArgOperand(
1755 OffloadArray::DeviceIDArgNum), // device_id.
1756 Handle // handle to wait on.
1757 };
1758 CallInst *WaitCallsite = CallInst::Create(
1759 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1760 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1761
1762 return true;
1763 }
1764
1765 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1766 bool GlobalOnly, bool &SingleChoice) {
1767 if (CurrentIdent == NextIdent)
1768 return CurrentIdent;
1769
1770 // TODO: Figure out how to actually combine multiple debug locations. For
1771 // now we just keep an existing one if there is a single choice.
1772 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1773 SingleChoice = !CurrentIdent;
1774 return NextIdent;
1775 }
1776 return nullptr;
1777 }
1778
1779 /// Return a `struct ident_t*` value that represents the ones used in the
1780 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1781 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1782 /// return value we create one from scratch. We also do not yet combine
1783 /// information, e.g., the source locations, see combinedIdentStruct.
1784 Value *
1785 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1786 Function &F, bool GlobalOnly) {
1787 bool SingleChoice = true;
1788 Value *Ident = nullptr;
1789 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1790 CallInst *CI = getCallIfRegularCall(U, &RFI);
1791 if (!CI || &F != &Caller)
1792 return false;
1793 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1794 /* GlobalOnly */ true, SingleChoice);
1795 return false;
1796 };
1797 RFI.foreachUse(SCC, CombineIdentStruct);
1798
1799 if (!Ident || !SingleChoice) {
1800 // The IRBuilder uses the insertion block to get to the module; this is
1801 // unfortunate but we work around it for now.
1802 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1803 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1804 &F.getEntryBlock(), F.getEntryBlock().begin()));
1805 // Create a fallback location if none was found.
1806 // TODO: Use the debug locations of the calls instead.
1807 uint32_t SrcLocStrSize;
1808 Constant *Loc =
1809 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1810 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1811 }
1812 return Ident;
1813 }
1814
1815 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1816 /// \p ReplVal if given.
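/// Illustrative effect (assumed IR shape): a later identical call is folded
/// into the first, dominating one:
///   %tid0 = call i32 @__kmpc_global_thread_num(ptr @ident)
///   ...
///   %tid1 = call i32 @__kmpc_global_thread_num(ptr @ident)
/// After deduplication, all uses of %tid1 are replaced by %tid0 and the
/// second call is erased.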
1817 bool deduplicateRuntimeCalls(Function &F,
1818 OMPInformationCache::RuntimeFunctionInfo &RFI,
1819 Value *ReplVal = nullptr) {
1820 auto *UV = RFI.getUseVector(F);
1821 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1822 return false;
1823
1824 LLVM_DEBUG(
1825 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1826 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1827
1828 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1829 cast<Argument>(ReplVal)->getParent() == &F)) &&
1830 "Unexpected replacement value!");
1831
1832 // TODO: Use dominance to find a good position instead.
1833 auto CanBeMoved = [this](CallBase &CB) {
1834 unsigned NumArgs = CB.arg_size();
1835 if (NumArgs == 0)
1836 return true;
1837 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1838 return false;
1839 for (unsigned U = 1; U < NumArgs; ++U)
1840 if (isa<Instruction>(CB.getArgOperand(U)))
1841 return false;
1842 return true;
1843 };
1844
1845 if (!ReplVal) {
1846 auto *DT =
1847 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1848 if (!DT)
1849 return false;
1850 Instruction *IP = nullptr;
1851 for (Use *U : *UV) {
1852 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1853 if (IP)
1854 IP = DT->findNearestCommonDominator(IP, CI);
1855 else
1856 IP = CI;
1857 if (!CanBeMoved(*CI))
1858 continue;
1859 if (!ReplVal)
1860 ReplVal = CI;
1861 }
1862 }
1863 if (!ReplVal)
1864 return false;
1865 assert(IP && "Expected insertion point!");
1866 cast<Instruction>(ReplVal)->moveBefore(IP);
1867 }
1868
1869 // If we use a call as a replacement value we need to make sure the ident is
1870 // valid at the new location. For now we just pick a global one, either
1871 // existing and used by one of the calls, or created from scratch.
1872 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1873 if (!CI->arg_empty() &&
1874 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1875 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1876 /* GlobalOnly */ true);
1877 CI->setArgOperand(0, Ident);
1878 }
1879 }
1880
1881 bool Changed = false;
1882 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1883 CallInst *CI = getCallIfRegularCall(U, &RFI);
1884 if (!CI || CI == ReplVal || &F != &Caller)
1885 return false;
1886 assert(CI->getCaller() == &F && "Unexpected call!");
1887
1888 auto Remark = [&](OptimizationRemark OR) {
1889 return OR << "OpenMP runtime call "
1890 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1891 };
1892 if (CI->getDebugLoc())
1893 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1894 else
1895 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1896
1897 CI->replaceAllUsesWith(ReplVal);
1898 CI->eraseFromParent();
1899 ++NumOpenMPRuntimeCallsDeduplicated;
1900 Changed = true;
1901 return true;
1902 };
1903 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1904
1905 return Changed;
1906 }
1907
1908 /// Collect arguments that represent the global thread id in \p GTIdArgs.
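/// A sketch of the recognized pattern (illustrative names): the argument of
/// @helper is collected as a GTId because every call site passes a known
/// GTId:
///   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
///   call void @helper(i32 %gtid)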
1909 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1910 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1911 // initialization. We could define an AbstractAttribute instead and
1912 // run the Attributor here once it can be run as an SCC pass.
1913
1914 // Helper to check the argument \p ArgNo at all call sites of \p F for
1915 // a GTId.
1916 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1917 if (!F.hasLocalLinkage())
1918 return false;
1919 for (Use &U : F.uses()) {
1920 if (CallInst *CI = getCallIfRegularCall(U)) {
1921 Value *ArgOp = CI->getArgOperand(ArgNo);
1922 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1923 getCallIfRegularCall(
1924 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1925 continue;
1926 }
1927 return false;
1928 }
1929 return true;
1930 };
1931
1932 // Helper to identify uses of a GTId as GTId arguments.
1933 auto AddUserArgs = [&](Value &GTId) {
1934 for (Use &U : GTId.uses())
1935 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1936 if (CI->isArgOperand(&U))
1937 if (Function *Callee = CI->getCalledFunction())
1938 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1939 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1940 };
1941
1942 // The argument users of __kmpc_global_thread_num calls are GTIds.
1943 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1944 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1945
1946 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1947 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1948 AddUserArgs(*CI);
1949 return false;
1950 });
1951
1952 // Transitively search for more arguments by looking at the users of the
1953 // ones we know already. During the search the GTIdArgs vector is extended
1954 // so we cannot cache the size nor can we use a range based for.
1955 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1956 AddUserArgs(*GTIdArgs[U]);
1957 }
1958
1959 /// Kernel (=GPU) optimizations and utility functions
1960 ///
1961 ///{{
1962
1963 /// Cache to remember the unique kernel for a function.
1964 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1965
1966 /// Find the unique kernel that will execute \p F, if any.
1967 Kernel getUniqueKernelFor(Function &F);
1968
1969 /// Find the unique kernel that will execute \p I, if any.
1970 Kernel getUniqueKernelFor(Instruction &I) {
1971 return getUniqueKernelFor(*I.getFunction());
1972 }
1973
1974 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1975 /// the cases where we can avoid taking the address of a function.
1976 bool rewriteDeviceCodeStateMachine();
1977
1978 ///
1979 ///}}
1980
1981 /// Emit a remark generically
1982 ///
1983 /// This template function can be used to generically emit a remark. The
1984 /// RemarkKind should be one of the following:
1985 /// - OptimizationRemark to indicate a successful optimization attempt
1986 /// - OptimizationRemarkMissed to report a failed optimization attempt
1987 /// - OptimizationRemarkAnalysis to provide additional information about an
1988 /// optimization attempt
1989 ///
1990 /// The remark is built using a callback function provided by the caller that
1991 /// takes a RemarkKind as input and returns a RemarkKind.
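/// Typical usage, mirroring the call sites in this file (simplified):
///   auto Remark = [&](OptimizationRemark OR) {
///     return OR << "OpenMP runtime call deduplicated.";
///   };
///   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);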
1992 template <typename RemarkKind, typename RemarkCallBack>
1993 void emitRemark(Instruction *I, StringRef RemarkName,
1994 RemarkCallBack &&RemarkCB) const {
1995 Function *F = I->getParent()->getParent();
1996 auto &ORE = OREGetter(F);
1997
1998 if (RemarkName.starts_with("OMP"))
1999 ORE.emit([&]() {
2000 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2001 << " [" << RemarkName << "]";
2002 });
2003 else
2004 ORE.emit(
2005 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2006 }
2007
2008 /// Emit a remark on a function.
2009 template <typename RemarkKind, typename RemarkCallBack>
2010 void emitRemark(Function *F, StringRef RemarkName,
2011 RemarkCallBack &&RemarkCB) const {
2012 auto &ORE = OREGetter(F);
2013
2014 if (RemarkName.starts_with("OMP"))
2015 ORE.emit([&]() {
2016 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2017 << " [" << RemarkName << "]";
2018 });
2019 else
2020 ORE.emit(
2021 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2022 }
2023
2024 /// The underlying module.
2025 Module &M;
2026
2027 /// The SCC we are operating on.
2028 SmallVectorImpl<Function *> &SCC;
2029
2030 /// Callback to update the call graph, the first argument is a removed call,
2031 /// the second an optional replacement call.
2032 CallGraphUpdater &CGUpdater;
2033
2034 /// Callback to get an OptimizationRemarkEmitter from a Function *
2035 OptimizationRemarkGetter OREGetter;
2036
2037 /// OpenMP-specific information cache. Also Used for Attributor runs.
2038 OMPInformationCache &OMPInfoCache;
2039
2040 /// Attributor instance.
2041 Attributor &A;
2042
2043 /// Helper function to run Attributor on SCC.
2044 bool runAttributor(bool IsModulePass) {
2045 if (SCC.empty())
2046 return false;
2047
2048 registerAAs(IsModulePass);
2049
2050 ChangeStatus Changed = A.run();
2051
2052 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2053 << " functions, result: " << Changed << ".\n");
2054
2055 if (Changed == ChangeStatus::CHANGED)
2056 OMPInfoCache.invalidateAnalyses();
2057
2058 return Changed == ChangeStatus::CHANGED;
2059 }
2060
2061 void registerFoldRuntimeCall(RuntimeFunction RF);
2062
2063 /// Populate the Attributor with abstract attribute opportunities in the
2064 /// functions.
2065 void registerAAs(bool IsModulePass);
2066
2067public:
2068 /// Callback to register AAs for live functions, including internal functions
2069 /// marked live during the traversal.
2070 static void registerAAsForFunction(Attributor &A, const Function &F);
2071};
2072
2073Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2074 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2075 !OMPInfoCache.CGSCC->contains(&F))
2076 return nullptr;
2077
2078 // Use a scope to keep the lifetime of the CachedKernel short.
2079 {
2080 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2081 if (CachedKernel)
2082 return *CachedKernel;
2083
2084 // TODO: We should use an AA to create an (optimistic and callback
2085 // call-aware) call graph. For now we stick to simple patterns that
2086 // are less powerful, basically the worst fixpoint.
2087 if (isOpenMPKernel(F)) {
2088 CachedKernel = Kernel(&F);
2089 return *CachedKernel;
2090 }
2091
2092 CachedKernel = nullptr;
2093 if (!F.hasLocalLinkage()) {
2094
2095 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2096 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2097 return ORA << "Potentially unknown OpenMP target region caller.";
2098 };
2099 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2100
2101 return nullptr;
2102 }
2103 }
2104
2105 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2106 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2107 // Allow use in equality comparisons.
2108 if (Cmp->isEquality())
2109 return getUniqueKernelFor(*Cmp);
2110 return nullptr;
2111 }
2112 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2113 // Allow direct calls.
2114 if (CB->isCallee(&U))
2115 return getUniqueKernelFor(*CB);
2116
2117 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2118 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2119 // Allow the use in __kmpc_parallel_51 calls.
2120 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2121 return getUniqueKernelFor(*CB);
2122 return nullptr;
2123 }
2124 // Disallow every other use.
2125 return nullptr;
2126 };
2127
2128 // TODO: In the future we want to track more than just a unique kernel.
2129 SmallPtrSet<Kernel, 2> PotentialKernels;
2130 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2131 PotentialKernels.insert(GetUniqueKernelForUse(U));
2132 });
2133
2134 Kernel K = nullptr;
2135 if (PotentialKernels.size() == 1)
2136 K = *PotentialKernels.begin();
2137
2138 // Cache the result.
2139 UniqueKernelMap[&F] = K;
2140
2141 return K;
2142}
2143
2144bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2145 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2146 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2147
2148 bool Changed = false;
2149 if (!KernelParallelRFI)
2150 return Changed;
2151
2152 // If we have disabled state machine changes, exit.
2153 if (DisableOpenMPOptStateMachineRewrite)
2154 return Changed;
2155
2156 for (Function *F : SCC) {
2157
2158 // Check if the function is a use in a __kmpc_parallel_51 call at
2159 // all.
2160 bool UnknownUse = false;
2161 bool KernelParallelUse = false;
2162 unsigned NumDirectCalls = 0;
2163
2164 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2165 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2166 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2167 if (CB->isCallee(&U)) {
2168 ++NumDirectCalls;
2169 return;
2170 }
2171
2172 if (isa<ICmpInst>(U.getUser())) {
2173 ToBeReplacedStateMachineUses.push_back(&U);
2174 return;
2175 }
2176
2177 // Find wrapper functions that represent parallel kernels.
2178 CallInst *CI =
2179 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2180 const unsigned int WrapperFunctionArgNo = 6;
2181 if (!KernelParallelUse && CI &&
2182 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2183 KernelParallelUse = true;
2184 ToBeReplacedStateMachineUses.push_back(&U);
2185 return;
2186 }
2187 UnknownUse = true;
2188 });
2189
2190 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2191 // use.
2192 if (!KernelParallelUse)
2193 continue;
2194
2195 // If this ever hits, we should investigate.
2196 // TODO: Checking the number of uses is not a necessary restriction and
2197 // should be lifted.
2198 if (UnknownUse || NumDirectCalls != 1 ||
2199 ToBeReplacedStateMachineUses.size() > 2) {
2200 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2201 return ORA << "Parallel region is used in "
2202 << (UnknownUse ? "unknown" : "unexpected")
2203 << " ways. Will not attempt to rewrite the state machine.";
2204 };
2205 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2206 continue;
2207 }
2208
2209 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2210 // up if the function is not called from a unique kernel.
2211 Kernel K = getUniqueKernelFor(*F);
2212 if (!K) {
2213 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2214 return ORA << "Parallel region is not called from a unique kernel. "
2215 "Will not attempt to rewrite the state machine.";
2216 };
2217 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2218 continue;
2219 }
2220
2221 // We now know F is a parallel body function called only from the kernel K.
2222 // We also identified the state machine uses in which we replace the
2223 // function pointer by a new global symbol for identification purposes. This
2224 // ensures only direct calls to the function are left.
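// Illustrative result (assumed IR shape, names abridged): the state machine
// compares the work-function pointer against the new identifier instead of
// the parallel region's function address:
//   @__omp_outlined__.ID = private constant i8 undef
//   %match = icmp eq ptr %work_fn, @__omp_outlined__.ID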
2225
2226 Module &M = *F->getParent();
2227 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2228
2229 auto *ID = new GlobalVariable(
2230 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2231 UndefValue::get(Int8Ty), F->getName() + ".ID");
2232
2233 for (Use *U : ToBeReplacedStateMachineUses)
2234 U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2235 ID, U->get()->getType()));
2236
2237 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2238
2239 Changed = true;
2240 }
2241
2242 return Changed;
2243}
2244
2245/// Abstract Attribute for tracking ICV values.
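/// As an illustrative example for the nthreads ICV (the only one currently
/// tracked below):
///   call void @omp_set_num_threads(i32 4)  ; setter, ICV value becomes 4
///   %n = call i32 @omp_get_max_threads()   ; getter, replaceable by 4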
2246 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2247 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2248 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2249
2250 /// Returns true if value is assumed to be tracked.
2251 bool isAssumedTracked() const { return getAssumed(); }
2252
2253 /// Returns true if value is known to be tracked.
2254 bool isKnownTracked() const { return getAssumed(); }
2255
2256 /// Create an abstract attribute view for the position \p IRP.
2257 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2258
2259 /// Return the value with which \p I can be replaced for specific \p ICV.
2260 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2261 const Instruction *I,
2262 Attributor &A) const {
2263 return std::nullopt;
2264 }
2265
2266 /// Return an assumed unique ICV value if a single candidate is found. If
2267 /// there cannot be one, return nullptr. If it is not clear yet, return
2268 /// std::nullopt.
2269 virtual std::optional<Value *>
2270 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2271
2272 // Currently only nthreads is being tracked;
2273 // this array will only grow with time.
2274 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2275
2276 /// See AbstractAttribute::getName()
2277 const std::string getName() const override { return "AAICVTracker"; }
2278
2279 /// See AbstractAttribute::getIdAddr()
2280 const char *getIdAddr() const override { return &ID; }
2281
2282 /// This function should return true if the type of the \p AA is AAICVTracker
2283 static bool classof(const AbstractAttribute *AA) {
2284 return (AA->getIdAddr() == &ID);
2285 }
2286
2287 static const char ID;
2288};
2289
2290struct AAICVTrackerFunction : public AAICVTracker {
2291 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2292 : AAICVTracker(IRP, A) {}
2293
2294 // FIXME: come up with better string.
2295 const std::string getAsStr(Attributor *) const override {
2296 return "ICVTrackerFunction";
2297 }
2298
2299 // FIXME: come up with some stats.
2300 void trackStatistics() const override {}
2301
2302 /// We don't manifest anything for this AA.
2303 ChangeStatus manifest(Attributor &A) override {
2304 return ChangeStatus::UNCHANGED;
2305 }
2306
2307 // Map of ICV to their values at specific program point.
2308 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2309 InternalControlVar::ICV___last>
2310 ICVReplacementValuesMap;
2311
2312 ChangeStatus updateImpl(Attributor &A) override {
2313 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2314
2315 Function *F = getAnchorScope();
2316
2317 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2318
2319 for (InternalControlVar ICV : TrackableICVs) {
2320 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2321
2322 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2323 auto TrackValues = [&](Use &U, Function &) {
2324 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2325 if (!CI)
2326 return false;
2327
2328 // FIXME: handle setters with more than one argument.
2329 /// Track new value.
2330 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2331 HasChanged = ChangeStatus::CHANGED;
2332
2333 return false;
2334 };
2335
2336 auto CallCheck = [&](Instruction &I) {
2337 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2338 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2339 HasChanged = ChangeStatus::CHANGED;
2340
2341 return true;
2342 };
2343
2344 // Track all changes of an ICV.
2345 SetterRFI.foreachUse(TrackValues, F);
2346
2347 bool UsedAssumedInformation = false;
2348 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2349 UsedAssumedInformation,
2350 /* CheckBBLivenessOnly */ true);
2351
2352 /// TODO: Figure out a way to avoid adding entry in
2353 /// ICVReplacementValuesMap
2354 Instruction *Entry = &F->getEntryBlock().front();
2355 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2356 ValuesMap.insert(std::make_pair(Entry, nullptr));
2357 }
2358
2359 return HasChanged;
2360 }
2361
2362 /// Helper to check if \p I is a call and get the value for it if it is
2363 /// unique.
2364 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2365 InternalControlVar &ICV) const {
2366
2367 const auto *CB = dyn_cast<CallBase>(&I);
2368 if (!CB || CB->hasFnAttr("no_openmp") ||
2369 CB->hasFnAttr("no_openmp_routines"))
2370 return std::nullopt;
2371
2372 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2373 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2374 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2375 Function *CalledFunction = CB->getCalledFunction();
2376
2377 // Indirect call, assume ICV changes.
2378 if (CalledFunction == nullptr)
2379 return nullptr;
2380 if (CalledFunction == GetterRFI.Declaration)
2381 return std::nullopt;
2382 if (CalledFunction == SetterRFI.Declaration) {
2383 if (ICVReplacementValuesMap[ICV].count(&I))
2384 return ICVReplacementValuesMap[ICV].lookup(&I);
2385
2386 return nullptr;
2387 }
2388
2389 // Since we don't know, assume it changes the ICV.
2390 if (CalledFunction->isDeclaration())
2391 return nullptr;
2392
2393 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2394 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2395
2396 if (ICVTrackingAA->isAssumedTracked()) {
2397 std::optional<Value *> URV =
2398 ICVTrackingAA->getUniqueReplacementValue(ICV);
2399 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2400 OMPInfoCache)))
2401 return URV;
2402 }
2403
2404 // If we don't know, assume it changes.
2405 return nullptr;
2406 }
2407
2408 // We don't check for a unique value for a function, so return std::nullopt.
2409 std::optional<Value *>
2410 getUniqueReplacementValue(InternalControlVar ICV) const override {
2411 return std::nullopt;
2412 }
2413
2414 /// Return the value with which \p I can be replaced for specific \p ICV.
2415 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2416 const Instruction *I,
2417 Attributor &A) const override {
2418 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2419 if (ValuesMap.count(I))
2420 return ValuesMap.lookup(I);
2421
2422 SmallVector<const Instruction *, 16> Worklist;
2423 SmallPtrSet<const Instruction *, 16> Visited;
2424 Worklist.push_back(I);
2425
2426 std::optional<Value *> ReplVal;
2427
2428 while (!Worklist.empty()) {
2429 const Instruction *CurrInst = Worklist.pop_back_val();
2430 if (!Visited.insert(CurrInst).second)
2431 continue;
2432
2433 const BasicBlock *CurrBB = CurrInst->getParent();
2434
2435 // Go up and look for all potential setters/calls that might change the
2436 // ICV.
2437 while ((CurrInst = CurrInst->getPrevNode())) {
2438 if (ValuesMap.count(CurrInst)) {
2439 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2440 // Unknown value, track new.
2441 if (!ReplVal) {
2442 ReplVal = NewReplVal;
2443 break;
2444 }
2445
2446 // If we found a new value, we can't know the ICV value anymore.
2447 if (NewReplVal)
2448 if (ReplVal != NewReplVal)
2449 return nullptr;
2450
2451 break;
2452 }
2453
2454 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2455 if (!NewReplVal)
2456 continue;
2457
2458 // Unknown value, track new.
2459 if (!ReplVal) {
2460 ReplVal = NewReplVal;
2461 break;
2462 }
2463
2465 // We found a new value, we can't know the ICV value anymore.
2466 if (ReplVal != NewReplVal)
2467 return nullptr;
2468 }
2469
2470 // If we are in the same BB and we have a value, we are done.
2471 if (CurrBB == I->getParent() && ReplVal)
2472 return ReplVal;
2473
2474 // Go through all predecessors and add terminators for analysis.
2475 for (const BasicBlock *Pred : predecessors(CurrBB))
2476 if (const Instruction *Terminator = Pred->getTerminator())
2477 Worklist.push_back(Terminator);
2478 }
2479
2480 return ReplVal;
2481 }
2482};
2483
2484struct AAICVTrackerFunctionReturned : AAICVTracker {
2485 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2486 : AAICVTracker(IRP, A) {}
2487
2488 // FIXME: come up with better string.
2489 const std::string getAsStr(Attributor *) const override {
2490 return "ICVTrackerFunctionReturned";
2491 }
2492
2493 // FIXME: come up with some stats.
2494 void trackStatistics() const override {}
2495
2496 /// We don't manifest anything for this AA.
2497 ChangeStatus manifest(Attributor &A) override {
2498 return ChangeStatus::UNCHANGED;
2499 }
2500
2501 // Map of ICV to their values at specific program point.
2502 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2503 InternalControlVar::ICV___last>
2504 ICVReplacementValuesMap;
2505
2506 /// Return the unique value \p ICV is assumed to take, if any.
2507 std::optional<Value *>
2508 getUniqueReplacementValue(InternalControlVar ICV) const override {
2509 return ICVReplacementValuesMap[ICV];
2510 }
2511
2512 ChangeStatus updateImpl(Attributor &A) override {
2513 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2514 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2515 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2516
2517 if (!ICVTrackingAA->isAssumedTracked())
2518 return indicatePessimisticFixpoint();
2519
2520 for (InternalControlVar ICV : TrackableICVs) {
2521 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2522 std::optional<Value *> UniqueICVValue;
2523
2524 auto CheckReturnInst = [&](Instruction &I) {
2525 std::optional<Value *> NewReplVal =
2526 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2527
2528 // If we found a second ICV value there is no unique returned value.
2529 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2530 return false;
2531
2532 UniqueICVValue = NewReplVal;
2533
2534 return true;
2535 };
2536
2537 bool UsedAssumedInformation = false;
2538 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2539 UsedAssumedInformation,
2540 /* CheckBBLivenessOnly */ true))
2541 UniqueICVValue = nullptr;
2542
2543 if (UniqueICVValue == ReplVal)
2544 continue;
2545
2546 ReplVal = UniqueICVValue;
2547 Changed = ChangeStatus::CHANGED;
2548 }
2549
2550 return Changed;
2551 }
2552};
2553
2554struct AAICVTrackerCallSite : AAICVTracker {
2555 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2556 : AAICVTracker(IRP, A) {}
2557
2558 void initialize(Attributor &A) override {
2559 assert(getAnchorScope() && "Expected anchor function");
2560
2561 // We only initialize this AA for getters, so we need to know which ICV it
2562 // gets.
2563 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2564 for (InternalControlVar ICV : TrackableICVs) {
2565 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2566 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2567 if (Getter.Declaration == getAssociatedFunction()) {
2568 AssociatedICV = ICVInfo.Kind;
2569 return;
2570 }
2571 }
2572
2573 /// Unknown ICV.
2574 indicatePessimisticFixpoint();
2575 }
2576
2577 ChangeStatus manifest(Attributor &A) override {
2578 if (!ReplVal || !*ReplVal)
2579 return ChangeStatus::UNCHANGED;
2580
2581 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2582 A.deleteAfterManifest(*getCtxI());
2583
2584 return ChangeStatus::CHANGED;
2585 }
2586
2587 // FIXME: come up with better string.
2588 const std::string getAsStr(Attributor *) const override {
2589 return "ICVTrackerCallSite";
2590 }
2591
2592 // FIXME: come up with some stats.
2593 void trackStatistics() const override {}
2594
2595 InternalControlVar AssociatedICV;
2596 std::optional<Value *> ReplVal;
2597
2598 ChangeStatus updateImpl(Attributor &A) override {
2599 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2600 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2601
2602 // We don't have any information, so we assume it changes the ICV.
2603 if (!ICVTrackingAA->isAssumedTracked())
2604 return indicatePessimisticFixpoint();
2605
2606 std::optional<Value *> NewReplVal =
2607 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2608
2609 if (ReplVal == NewReplVal)
2610 return ChangeStatus::UNCHANGED;
2611
2612 ReplVal = NewReplVal;
2613 return ChangeStatus::CHANGED;
2614 }
2615
2616 // Return the value with which the associated value can be replaced for the
2617 // specific \p ICV.
2618 std::optional<Value *>
2619 getUniqueReplacementValue(InternalControlVar ICV) const override {
2620 return ReplVal;
2621 }
2622};
2623
2624struct AAICVTrackerCallSiteReturned : AAICVTracker {
2625 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2626 : AAICVTracker(IRP, A) {}
2627
2628 // FIXME: come up with better string.
2629 const std::string getAsStr(Attributor *) const override {
2630 return "ICVTrackerCallSiteReturned";
2631 }
2632
2633 // FIXME: come up with some stats.
2634 void trackStatistics() const override {}
2635
2636 /// We don't manifest anything for this AA.
2637 ChangeStatus manifest(Attributor &A) override {
2638 return ChangeStatus::UNCHANGED;
2639 }
2640
2641 // Map of ICV to their values at specific program point.
2642 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2643 InternalControlVar::ICV___last>
2644 ICVReplacementValuesMap;
2645
2646 /// Return the value with which the associated value can be replaced for the
2647 /// specific \p ICV.
2648 std::optional<Value *>
2649 getUniqueReplacementValue(InternalControlVar ICV) const override {
2650 return ICVReplacementValuesMap[ICV];
2651 }
2652
2653 ChangeStatus updateImpl(Attributor &A) override {
2654 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2655 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2656 *this, IRPosition::returned(*getAssociatedFunction()),
2657 DepClassTy::REQUIRED);
2658
2659 // We don't have any information, so we assume it changes the ICV.
2660 if (!ICVTrackingAA->isAssumedTracked())
2661 return indicatePessimisticFixpoint();
2662
2663 for (InternalControlVar ICV : TrackableICVs) {
2664 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2665 std::optional<Value *> NewReplVal =
2666 ICVTrackingAA->getUniqueReplacementValue(ICV);
2667
2668 if (ReplVal == NewReplVal)
2669 continue;
2670
2671 ReplVal = NewReplVal;
2672 Changed = ChangeStatus::CHANGED;
2673 }
2674 return Changed;
2675 }
2676};
2677
2678/// Determines if \p BB exits the function unconditionally itself or reaches a
2679/// block that does so through only unique successors.
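/// For example (illustrative): a block ending in a return qualifies, as does
/// a chain BB -> B1 -> B2 -> return in which every block has a unique
/// successor; a block with two successors does not.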
2680static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2681 if (succ_empty(BB))
2682 return true;
2683 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2684 if (!Successor)
2685 return false;
2686 return hasFunctionEndAsUniqueSuccessor(Successor);
2687}
2688
2689struct AAExecutionDomainFunction : public AAExecutionDomain {
2690 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2691 : AAExecutionDomain(IRP, A) {}
2692
2693 ~AAExecutionDomainFunction() { delete RPOT; }
2694
2695 void initialize(Attributor &A) override {
2696 Function *F = getAnchorScope();
2697 assert(F && "Expected anchor function");
2698 RPOT = new ReversePostOrderTraversal<Function *>(F);
2699 }
2700
2701 const std::string getAsStr(Attributor *) const override {
2702 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2703 for (auto &It : BEDMap) {
2704 if (!It.getFirst())
2705 continue;
2706 TotalBlocks++;
2707 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2708 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2709 It.getSecond().IsReachingAlignedBarrierOnly;
2710 }
2711 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2712 std::to_string(AlignedBlocks) + " of " +
2713 std::to_string(TotalBlocks) +
2714 " executed by initial thread / aligned";
2715 }
2716
2717 /// See AbstractAttribute::trackStatistics().
2718 void trackStatistics() const override {}
2719
2720 ChangeStatus manifest(Attributor &A) override {
2721 LLVM_DEBUG({
2722 for (const BasicBlock &BB : *getAnchorScope()) {
2723 if (!isExecutedByInitialThreadOnly(BB))
2724 continue;
2725 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2726 << BB.getName() << " is executed by a single thread.\n";
2727 }
2728 });
2729
2730 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2731
2732 if (!isValidState())
2733 return Changed;
2734
2735 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2736 auto HandleAlignedBarrier = [&](CallBase *CB) {
2737 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2738 if (!ED.IsReachedFromAlignedBarrierOnly ||
2739 ED.EncounteredNonLocalSideEffect)
2740 return;
2741 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2742 return;
2743
2744 // We can remove this barrier, if it is one, or aligned barriers reaching
2745 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2746 // end should only be removed if the kernel end is their unique successor;
2747 // otherwise, they may have side-effects that aren't accounted for in the
2748 // kernel end in their other successors. If those barriers have other
2749 // barriers reaching them, those can be transitively removed as well as
2750 // long as the kernel end is also their unique successor.
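// Illustrative (assumed IR shape): the second of two adjacent aligned
// barriers with no interleaved side effects is redundant:
//   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid)
//   call void @__kmpc_barrier_simple_spmd(ptr @ident, i32 %tid) ; removable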
2751 if (CB) {
2752 DeletedBarriers.insert(CB);
2753 A.deleteAfterManifest(*CB);
2754 ++NumBarriersEliminated;
2755 Changed = ChangeStatus::CHANGED;
2756 } else if (!ED.AlignedBarriers.empty()) {
2757 Changed = ChangeStatus::CHANGED;
2758 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2759 ED.AlignedBarriers.end());
2760 SmallSetVector<CallBase *, 16> Visited;
2761 while (!Worklist.empty()) {
2762 CallBase *LastCB = Worklist.pop_back_val();
2763 if (!Visited.insert(LastCB))
2764 continue;
2765 if (LastCB->getFunction() != getAnchorScope())
2766 continue;
2767 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2768 continue;
2769 if (!DeletedBarriers.count(LastCB)) {
2770 ++NumBarriersEliminated;
2771 A.deleteAfterManifest(*LastCB);
2772 continue;
2773 }
2774 // The final aligned barrier (LastCB) reaching the kernel end was
2775 // removed already. This means we can go one step further and remove
2776 // the barriers encountered last before it (LastCB).
2777 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2778 Worklist.append(LastED.AlignedBarriers.begin(),
2779 LastED.AlignedBarriers.end());
2780 }
2781 }
2782
2783 // If we actually eliminated a barrier we need to eliminate the associated
2784 // llvm.assumes as well to avoid creating UB.
2785 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2786 for (auto *AssumeCB : ED.EncounteredAssumes)
2787 A.deleteAfterManifest(*AssumeCB);
2788 };
2789
2790 for (auto *CB : AlignedBarriers)
2791 HandleAlignedBarrier(CB);
2792
2793 // Handle the "kernel end barrier" for kernels too.
2794 if (omp::isOpenMPKernel(*getAnchorScope()))
2795 HandleAlignedBarrier(nullptr);
2796
2797 return Changed;
2798 }
2799
2800 bool isNoOpFence(const FenceInst &FI) const override {
2801 return getState().isValidState() && !NonNoOpFences.count(&FI);
2802 }
2803
2804 /// Merge barrier and assumption information from \p PredED into the successor
2805 /// \p ED.
2806 void
2807 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2808 const ExecutionDomainTy &PredED);
2809
2810 /// Merge all information from \p PredED into the successor \p ED. If
2811 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2812 /// represented by \p ED from this predecessor.
2813 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2814 const ExecutionDomainTy &PredED,
2815 bool InitialEdgeOnly = false);
2816
2817 /// Accumulate information for the entry block in \p EntryBBED.
2818 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2819
2820 /// See AbstractAttribute::updateImpl.
2821 ChangeStatus updateImpl(Attributor &A) override;
2822
2823 /// Query interface, see AAExecutionDomain
2824 ///{
2825 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2826 if (!isValidState())
2827 return false;
2828 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2829 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2830 }
2831
2832 bool isExecutedInAlignedRegion(Attributor &A,
2833 const Instruction &I) const override {
2834 assert(I.getFunction() == getAnchorScope() &&
2835 "Instruction is out of scope!");
2836 if (!isValidState())
2837 return false;
2838
2839 bool ForwardIsOk = true;
2840 const Instruction *CurI;
2841
2842 // Check forward until a call or the block end is reached.
2843 CurI = &I;
2844 do {
2845 auto *CB = dyn_cast<CallBase>(CurI);
2846 if (!CB)
2847 continue;
2848 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2849 return true;
2850 const auto &It = CEDMap.find({CB, PRE});
2851 if (It == CEDMap.end())
2852 continue;
2853 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2854 ForwardIsOk = false;
2855 break;
2856 } while ((CurI = CurI->getNextNonDebugInstruction()));
2857
2858 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2859 ForwardIsOk = false;
2860
2861 // Check backward until a call or the block beginning is reached.
2862 CurI = &I;
2863 do {
2864 auto *CB = dyn_cast<CallBase>(CurI);
2865 if (!CB)
2866 continue;
2867 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2868 return true;
2869 const auto &It = CEDMap.find({CB, POST});
2870 if (It == CEDMap.end())
2871 continue;
2872 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2873 break;
2874 return false;
2875 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2876
2877 // Delayed decision on the forward pass to allow aligned barrier detection
2878 // in the backwards traversal.
2879 if (!ForwardIsOk)
2880 return false;
2881
2882 if (!CurI) {
2883 const BasicBlock *BB = I.getParent();
2884 if (BB == &BB->getParent()->getEntryBlock())
2885 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2886 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2887 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2888 })) {
2889 return false;
2890 }
2891 }
2892
2893 // On neither traversal did we find anything but aligned barriers.
2894 return true;
2895 }
2896
2897 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2898 assert(isValidState() &&
2899 "No request should be made against an invalid state!");
2900 return BEDMap.lookup(&BB);
2901 }
2902 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2903 getExecutionDomain(const CallBase &CB) const override {
2904 assert(isValidState() &&
2905 "No request should be made against an invalid state!");
2906 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2907 }
2908 ExecutionDomainTy getFunctionExecutionDomain() const override {
2909 assert(isValidState() &&
2910 "No request should be made against an invalid state!");
2911 return InterProceduralED;
2912 }
2913 ///}
2914
2915 // Check if the edge into the successor block contains a condition that only
2916 // lets the main thread execute it.
2917 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2918 BasicBlock &SuccessorBB) {
2919 if (!Edge || !Edge->isConditional())
2920 return false;
2921 if (Edge->getSuccessor(0) != &SuccessorBB)
2922 return false;
2923
2924 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2925 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2926 return false;
2927
2928 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2929 if (!C)
2930 return false;
2931
2932 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2933 if (C->isAllOnesValue()) {
2934 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2935 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2936 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2937 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2938 if (!CB)
2939 return false;
2940 ConstantStruct *KernelEnvC =
2941 KernelInfo::getKernelEnvironementFromKernelInitCB(CB);
2942 ConstantInt *ExecModeC =
2943 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2944 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2945 }
2946
2947 if (C->isZero()) {
2948 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2949 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2950 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2951 return true;
2952
2953 // Match: 0 == llvm.amdgcn.workitem.id.x()
2954 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2955 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2956 return true;
2957 }
2958
2959 return false;
2960 };
2961
2962 /// Mapping containing information about the function for other AAs.
2963 ExecutionDomainTy InterProceduralED;
2964
2965 enum Direction { PRE = 0, POST = 1 };
2966 /// Mapping containing information per block.
2967 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2968 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2969 CEDMap;
2970 SmallSetVector<CallBase *, 16> AlignedBarriers;
2971
2972 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2973
2974 /// Set \p R to \p V and report true if that changed \p R.
2975 static bool setAndRecord(bool &R, bool V) {
2976 bool Eq = (R == V);
2977 R = V;
2978 return !Eq;
2979 }
2980
2981 /// Collection of fences known to be non-no-op. All fences not in this set
2982 /// can be assumed no-op.
2983 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2984};
2985
2986void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2987 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2988 for (auto *EA : PredED.EncounteredAssumes)
2989 ED.addAssumeInst(A, *EA);
2990
2991 for (auto *AB : PredED.AlignedBarriers)
2992 ED.addAlignedBarrier(A, *AB);
2993}
2994
2995bool AAExecutionDomainFunction::mergeInPredecessor(
2996 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2997 bool InitialEdgeOnly) {
2998
2999 bool Changed = false;
3000 Changed |=
3001 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3002 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3003 ED.IsExecutedByInitialThreadOnly));
3004
3005 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3006 ED.IsReachedFromAlignedBarrierOnly &&
3007 PredED.IsReachedFromAlignedBarrierOnly);
3008 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3009 ED.EncounteredNonLocalSideEffect |
3010 PredED.EncounteredNonLocalSideEffect);
3011 // Do not track assumptions and barriers as part of Changed.
3012 if (ED.IsReachedFromAlignedBarrierOnly)
3013 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3014 else
3015 ED.clearAssumeInstAndAlignedBarriers();
3016 return Changed;
3017}
3018
3019bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3020 ExecutionDomainTy &EntryBBED) {
3021 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3022 auto PredForCallSite = [&](AbstractCallSite ACS) {
3023 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3024 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3025 DepClassTy::OPTIONAL);
3026 if (!EDAA || !EDAA->getState().isValidState())
3027 return false;
3028 CallSiteEDs.emplace_back(
3029 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3030 return true;
3031 };
3032
3033 ExecutionDomainTy ExitED;
3034 bool AllCallSitesKnown;
3035 if (A.checkForAllCallSites(PredForCallSite, *this,
3036 /* RequiresAllCallSites */ true,
3037 AllCallSitesKnown)) {
3038 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3039 mergeInPredecessor(A, EntryBBED, CSInED);
3040 ExitED.IsReachingAlignedBarrierOnly &=
3041 CSOutED.IsReachingAlignedBarrierOnly;
3042 }
3043
3044 } else {
3045 // We could not find all predecessors, so this is either a kernel or a
3046 // function with external linkage (or with some other weird uses).
3047 if (omp::isOpenMPKernel(*getAnchorScope())) {
3048 EntryBBED.IsExecutedByInitialThreadOnly = false;
3049 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3050 EntryBBED.EncounteredNonLocalSideEffect = false;
3051 ExitED.IsReachingAlignedBarrierOnly = false;
3052 } else {
3053 EntryBBED.IsExecutedByInitialThreadOnly = false;
3054 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3055 EntryBBED.EncounteredNonLocalSideEffect = true;
3056 ExitED.IsReachingAlignedBarrierOnly = false;
3057 }
3058 }
3059
3060 bool Changed = false;
3061 auto &FnED = BEDMap[nullptr];
3062 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3063 FnED.IsReachedFromAlignedBarrierOnly &
3064 EntryBBED.IsReachedFromAlignedBarrierOnly);
3065 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3066 FnED.IsReachingAlignedBarrierOnly &
3067 ExitED.IsReachingAlignedBarrierOnly);
3068 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3069 EntryBBED.IsExecutedByInitialThreadOnly);
3070 return Changed;
3071}
3072
3073ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3074
3075 bool Changed = false;
3076
3077 // Helper to deal with an aligned barrier encountered during the forward
3078 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3079 // it was encountered.
3080 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3081 Changed |= AlignedBarriers.insert(&CB);
3082 // First, update the barrier ED kept in the separate CEDMap.
3083 auto &CallInED = CEDMap[{&CB, PRE}];
3084 Changed |= mergeInPredecessor(A, CallInED, ED);
3085 CallInED.IsReachingAlignedBarrierOnly = true;
3086 // Next adjust the ED we use for the traversal.
3087 ED.EncounteredNonLocalSideEffect = false;
3088 ED.IsReachedFromAlignedBarrierOnly = true;
3089 // Aligned barrier collection has to come last.
3090 ED.clearAssumeInstAndAlignedBarriers();
3091 ED.addAlignedBarrier(A, CB);
3092 auto &CallOutED = CEDMap[{&CB, POST}];
3093 Changed |= mergeInPredecessor(A, CallOutED, ED);
3094 };
3095
3096 auto *LivenessAA =
3097 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3098
3099 Function *F = getAnchorScope();
3100 BasicBlock &EntryBB = F->getEntryBlock();
3101 bool IsKernel = omp::isOpenMPKernel(*F);
3102
3103 SmallVector<Instruction *> SyncInstWorklist;
3104 for (auto &RIt : *RPOT) {
3105 BasicBlock &BB = *RIt;
3106
3107 bool IsEntryBB = &BB == &EntryBB;
3108 // TODO: We use local reasoning since we don't have a divergence analysis
3109 // running as well. We could basically allow uniform branches here.
3110 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3111 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3112 ExecutionDomainTy ED;
3113 // Propagate "incoming edges" into information about this block.
3114 if (IsEntryBB) {
3115 Changed |= handleCallees(A, ED);
3116 } else {
3117 // For live non-entry blocks we only propagate
3118 // information via live edges.
3119 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3120 continue;
3121
3122 for (auto *PredBB : predecessors(&BB)) {
3123 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3124 continue;
3125 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3126 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3127 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3128 }
3129 }
3130
3131 // Now we traverse the block, accumulate effects in ED and attach
3132 // information to calls.
3133 for (Instruction &I : BB) {
3134 bool UsedAssumedInformation;
3135 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3136 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3137 /* CheckForDeadStore */ true))
3138 continue;
3139
3140 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3141 // first; the former are collected, the latter are ignored.
3142 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3143 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3144 ED.addAssumeInst(A, *AI);
3145 continue;
3146 }
3147 // TODO: Should we also collect and delete lifetime markers?
3148 if (II->isAssumeLikeIntrinsic())
3149 continue;
3150 }
3151
3152 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3153 if (!ED.EncounteredNonLocalSideEffect) {
3154 // An aligned fence without non-local side-effects is a no-op.
3155 if (ED.IsReachedFromAlignedBarrierOnly)
3156 continue;
3157 // A non-aligned fence without non-local side-effects is a no-op
3158 // if the ordering only publishes non-local side-effects (or less).
3159 switch (FI->getOrdering()) {
3160 case AtomicOrdering::NotAtomic:
3161 continue;
3162 case AtomicOrdering::Unordered:
3163 continue;
3164 case AtomicOrdering::Monotonic:
3165 continue;
3166 case AtomicOrdering::Acquire:
3167 break;
3168 case AtomicOrdering::Release:
3169 continue;
3170 case AtomicOrdering::AcquireRelease:
3171 break;
3172 case AtomicOrdering::SequentiallyConsistent:
3173 break;
3174 };
3175 }
3176 NonNoOpFences.insert(FI);
3177 }
3178
3179 auto *CB = dyn_cast<CallBase>(&I);
3180 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3181 bool IsAlignedBarrier =
3182 !IsNoSync && CB &&
3183 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3184
3185 AlignedBarrierLastInBlock &= IsNoSync;
3186 IsExplicitlyAligned &= IsNoSync;
3187
3188 // Next we check for calls. Aligned barriers are handled
3189 // explicitly, everything else is kept for the backward traversal and will
3190 // also affect our state.
3191 if (CB) {
3192 if (IsAlignedBarrier) {
3193 HandleAlignedBarrier(*CB, ED);
3194 AlignedBarrierLastInBlock = true;
3195 IsExplicitlyAligned = true;
3196 continue;
3197 }
3198
3199 // Check the pointer(s) of a memory intrinsic explicitly.
3200 if (isa<MemIntrinsic>(&I)) {
3201 if (!ED.EncounteredNonLocalSideEffect &&
3202 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3203 ED.EncounteredNonLocalSideEffect = true;
3204 if (!IsNoSync) {
3205 ED.IsReachedFromAlignedBarrierOnly = false;
3206 SyncInstWorklist.push_back(&I);
3207 }
3208 continue;
3209 }
3210
3211 // Record how we entered the call, then accumulate the effect of the
3212 // call in ED for potential use by the callee.
3213 auto &CallInED = CEDMap[{CB, PRE}];
3214 Changed |= mergeInPredecessor(A, CallInED, ED);
3215
3216 // If we have a sync-definition we can check if it starts/ends in an
3217 // aligned barrier. If we are unsure we assume any sync breaks
3218 // alignment.
3219 Function *Callee = CB->getCalledFunction();
3220 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3221 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3222 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3223 if (EDAA && EDAA->getState().isValidState()) {
3224 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3225 ED.IsReachedFromAlignedBarrierOnly =
3226 CalleeED.IsReachedFromAlignedBarrierOnly;
3227 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3228 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3229 ED.EncounteredNonLocalSideEffect |=
3230 CalleeED.EncounteredNonLocalSideEffect;
3231 else
3232 ED.EncounteredNonLocalSideEffect =
3233 CalleeED.EncounteredNonLocalSideEffect;
3234 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3235 Changed |=
3236 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3237 SyncInstWorklist.push_back(&I);
3238 }
3239 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3240 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3241 auto &CallOutED = CEDMap[{CB, POST}];
3242 Changed |= mergeInPredecessor(A, CallOutED, ED);
3243 continue;
3244 }
3245 }
3246 if (!IsNoSync) {
3247 ED.IsReachedFromAlignedBarrierOnly = false;
3248 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3249 SyncInstWorklist.push_back(&I);
3250 }
3251 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3252 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3253 auto &CallOutED = CEDMap[{CB, POST}];
3254 Changed |= mergeInPredecessor(A, CallOutED, ED);
3255 }
3256
3257 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3258 continue;
3259
3260 // If we have a callee we try to use fine-grained information to
3261 // determine local side-effects.
3262 if (CB) {
3263 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3264 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3265
3266 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3267 AAMemoryLocation::AccessKind,
3268 AAMemoryLocation::MemoryLocationsKind) {
3269 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3270 };
3271 if (MemAA && MemAA->getState().isValidState() &&
3272 MemAA->checkForAllAccessesToMemoryKind(
3273 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3274 continue;
3275 }
3276
3277 auto &InfoCache = A.getInfoCache();
3278 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3279 continue;
3280
3281 if (auto *LI = dyn_cast<LoadInst>(&I))
3282 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3283 continue;
3284
3285 if (!ED.EncounteredNonLocalSideEffect &&
3286 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3287 ED.EncounteredNonLocalSideEffect = true;
3288 }
3289
3290 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3291 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3292 !BB.getTerminator()->getNumSuccessors()) {
3293
3294 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3295
3296 auto &FnED = BEDMap[nullptr];
3297 if (IsKernel && !IsExplicitlyAligned)
3298 FnED.IsReachingAlignedBarrierOnly = false;
3299 Changed |= mergeInPredecessor(A, FnED, ED);
3300
3301 if (!FnED.IsReachingAlignedBarrierOnly) {
3302 IsEndAndNotReachingAlignedBarriersOnly = true;
3303 SyncInstWorklist.push_back(BB.getTerminator());
3304 auto &BBED = BEDMap[&BB];
3305 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3306 }
3307 }
3308
3309 ExecutionDomainTy &StoredED = BEDMap[&BB];
3310 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3311 !IsEndAndNotReachingAlignedBarriersOnly;
3312
3313 // Check if we computed anything different as part of the forward
3314 // traversal. We do not take assumptions and aligned barriers into account
3315 // as they do not influence the state we iterate. Backward traversal values
3316 // are handled later on.
3317 if (ED.IsExecutedByInitialThreadOnly !=
3318 StoredED.IsExecutedByInitialThreadOnly ||
3319 ED.IsReachedFromAlignedBarrierOnly !=
3320 StoredED.IsReachedFromAlignedBarrierOnly ||
3321 ED.EncounteredNonLocalSideEffect !=
3322 StoredED.EncounteredNonLocalSideEffect)
3323 Changed = true;
3324
3325 // Update the state with the new value.
3326 StoredED = std::move(ED);
3327 }
3328
3329 // Propagate (non-aligned) sync instruction effects backwards until the
3330 // entry is hit or an aligned barrier is found.
3331 SmallSetVector<BasicBlock *, 16> Visited;
3332 while (!SyncInstWorklist.empty()) {
3333 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3334 Instruction *CurInst = SyncInst;
3335 bool HitAlignedBarrierOrKnownEnd = false;
3336 while ((CurInst = CurInst->getPrevNode())) {
3337 auto *CB = dyn_cast<CallBase>(CurInst);
3338 if (!CB)
3339 continue;
3340 auto &CallOutED = CEDMap[{CB, POST}];
3341 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3342 auto &CallInED = CEDMap[{CB, PRE}];
3343 HitAlignedBarrierOrKnownEnd =
3344 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3345 if (HitAlignedBarrierOrKnownEnd)
3346 break;
3347 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3348 }
3349 if (HitAlignedBarrierOrKnownEnd)
3350 continue;
3351 BasicBlock *SyncBB = SyncInst->getParent();
3352 for (auto *PredBB : predecessors(SyncBB)) {
3353 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3354 continue;
3355 if (!Visited.insert(PredBB))
3356 continue;
3357 auto &PredED = BEDMap[PredBB];
3358 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3359 Changed = true;
3360 SyncInstWorklist.push_back(PredBB->getTerminator());
3361 }
3362 }
3363 if (SyncBB != &EntryBB)
3364 continue;
3365 Changed |=
3366 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3367 }
3368
3369 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3370}
3371
3372/// Try to replace memory allocation calls executed by a single thread with a
3373/// static buffer of shared memory.
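/// As an illustrative sketch (not literal pass output), a globalized
/// allocation such as
///   %p = call ptr @__kmpc_alloc_shared(i64 4)
///   ...
///   call void @__kmpc_free_shared(ptr %p, i64 4)
/// is rewritten to use an internal shared-memory global along the lines of
///   @p_shared = internal addrspace(3) global [4 x i8] poison, align 4
/// with the alloc/free calls deleted, as done in
/// AAHeapToSharedFunction::manifest below.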
3374struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3375 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3376 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3377
3378 /// Create an abstract attribute view for the position \p IRP.
3379 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3380 Attributor &A);
3381
3382 /// Returns true if HeapToShared conversion is assumed to be possible.
3383 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3384
3385 /// Returns true if HeapToShared conversion is assumed and the CB is a
3386 /// callsite to a free operation to be removed.
3387 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3388
3389 /// See AbstractAttribute::getName().
3390 const std::string getName() const override { return "AAHeapToShared"; }
3391
3392 /// See AbstractAttribute::getIdAddr().
3393 const char *getIdAddr() const override { return &ID; }
3394
3395 /// This function should return true if the type of the \p AA is
3396 /// AAHeapToShared.
3397 static bool classof(const AbstractAttribute *AA) {
3398 return (AA->getIdAddr() == &ID);
3399 }
3400
3401 /// Unique ID (due to the unique address)
3402 static const char ID;
3403};
3404
3405struct AAHeapToSharedFunction : public AAHeapToShared {
3406 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3407 : AAHeapToShared(IRP, A) {}
3408
3409 const std::string getAsStr(Attributor *) const override {
3410 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3411 " malloc calls eligible.";
3412 }
3413
3414 /// See AbstractAttribute::trackStatistics().
3415 void trackStatistics() const override {}
3416
3417 /// This function finds free calls that will be removed by the
3418 /// HeapToShared transformation.
3419 void findPotentialRemovedFreeCalls(Attributor &A) {
3420 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3421 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3422
3423 PotentialRemovedFreeCalls.clear();
3424 // Update free call users of found malloc calls.
3425 for (CallBase *CB : MallocCalls) {
3426 SmallVector<CallBase *, 4> FreeCalls;
3427 for (auto *U : CB->users()) {
3428 CallBase *C = dyn_cast<CallBase>(U);
3429 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3430 FreeCalls.push_back(C);
3431 }
3432
3433 if (FreeCalls.size() != 1)
3434 continue;
3435
3436 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3437 }
3438 }
3439
3440 void initialize(Attributor &A) override {
3441 if (DisableOpenMPOptDeglobalization) {
3442 indicatePessimisticFixpoint();
3443 return;
3444 }
3445
3446 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3447 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3448 if (!RFI.Declaration)
3449 return;
3450
3451 Attributor::SimplifictionCallbackTy SCB =
3452 [](const IRPosition &, const AbstractAttribute *,
3453 bool &) -> std::optional<Value *> { return nullptr; };
3454
3455 Function *F = getAnchorScope();
3456 for (User *U : RFI.Declaration->users())
3457 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3458 if (CB->getFunction() != F)
3459 continue;
3460 MallocCalls.insert(CB);
3461 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3462 SCB);
3463 }
3464
3465 findPotentialRemovedFreeCalls(A);
3466 }
3467
3468 bool isAssumedHeapToShared(CallBase &CB) const override {
3469 return isValidState() && MallocCalls.count(&CB);
3470 }
3471
3472 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3473 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3474 }
3475
3476 ChangeStatus manifest(Attributor &A) override {
3477 if (MallocCalls.empty())
3478 return ChangeStatus::UNCHANGED;
3479
3480 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3481 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3482
3483 Function *F = getAnchorScope();
3484 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3485 DepClassTy::OPTIONAL);
3486
3487 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3488 for (CallBase *CB : MallocCalls) {
3489 // Skip replacing this if HeapToStack has already claimed it.
3490 if (HS && HS->isAssumedHeapToStack(*CB))
3491 continue;
3492
3493 // Find the unique free call to remove it.
3494 SmallVector<CallBase *, 4> FreeCalls;
3495 for (auto *U : CB->users()) {
3496 CallBase *C = dyn_cast<CallBase>(U);
3497 if (C && C->getCalledFunction() == FreeCall.Declaration)
3498 FreeCalls.push_back(C);
3499 }
3500 if (FreeCalls.size() != 1)
3501 continue;
3502
3503 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3504
3505 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3506 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3507 << " with shared memory."
3508 << " Shared memory usage is limited to "
3509 << SharedMemoryLimit << " bytes\n");
3510 continue;
3511 }
3512
3513 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3514 << " with " << AllocSize->getZExtValue()
3515 << " bytes of shared memory\n");
3516
3517 // Create a new shared memory buffer of the same size as the allocation
3518 // and replace all the uses of the original allocation with it.
3519 Module *M = CB->getModule();
3520 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3521 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3522 auto *SharedMem = new GlobalVariable(
3523 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3524 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3525 GlobalValue::NotThreadLocal,
3526 static_cast<unsigned>(AddressSpace::Shared));
3527 auto *NewBuffer =
3528 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3529
3530 auto Remark = [&](OptimizationRemark OR) {
3531 return OR << "Replaced globalized variable with "
3532 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3533 << (AllocSize->isOne() ? " byte " : " bytes ")
3534 << "of shared memory.";
3535 };
3536 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3537
3538 MaybeAlign Alignment = CB->getRetAlign();
3539 assert(Alignment &&
3540 "HeapToShared on allocation without alignment attribute");
3541 SharedMem->setAlignment(*Alignment);
3542
3543 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3544 A.deleteAfterManifest(*CB);
3545 A.deleteAfterManifest(*FreeCalls.front());
3546
3547 SharedMemoryUsed += AllocSize->getZExtValue();
3548 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3549 Changed = ChangeStatus::CHANGED;
3550 }
3551
3552 return Changed;
3553 }
3554
3555 ChangeStatus updateImpl(Attributor &A) override {
3556 if (MallocCalls.empty())
3557 return indicatePessimisticFixpoint();
3558 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3559 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3560 if (!RFI.Declaration)
3561 return ChangeStatus::UNCHANGED;
3562
3563 Function *F = getAnchorScope();
3564
3565 auto NumMallocCalls = MallocCalls.size();
3566
3567 // Only consider malloc calls executed by a single thread with a constant size.
3568 for (User *U : RFI.Declaration->users()) {
3569 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3570 if (CB->getCaller() != F)
3571 continue;
3572 if (!MallocCalls.count(CB))
3573 continue;
3574 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3575 MallocCalls.remove(CB);
3576 continue;
3577 }
3578 const auto *ED = A.getAAFor<AAExecutionDomain>(
3579 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3580 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3581 MallocCalls.remove(CB);
3582 }
3583 }
3584
3585 findPotentialRemovedFreeCalls(A);
3586
3587 if (NumMallocCalls != MallocCalls.size())
3588 return ChangeStatus::CHANGED;
3589
3590 return ChangeStatus::UNCHANGED;
3591 }
3592
3593 /// Collection of all malloc calls in a function.
3594 SmallSetVector<CallBase *, 4> MallocCalls;
3595 /// Collection of potentially removed free calls in a function.
3596 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3597 /// The total amount of shared memory that has been used for HeapToShared.
3598 unsigned SharedMemoryUsed = 0;
3599};
3600
3601struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3602 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3603 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3604
3605 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3606 /// unknown callees.
3607 static bool requiresCalleeForCallBase() { return false; }
3608
3609 /// Statistics are tracked as part of manifest for now.
3610 void trackStatistics() const override {}
3611
3612 /// See AbstractAttribute::getAsStr()
3613 const std::string getAsStr(Attributor *) const override {
3614 if (!isValidState())
3615 return "<invalid>";
3616 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3617 : "generic") +
3618 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3619 : "") +
3620 std::string(" #PRs: ") +
3621 (ReachedKnownParallelRegions.isValidState()
3622 ? std::to_string(ReachedKnownParallelRegions.size())
3623 : "<invalid>") +
3624 ", #Unknown PRs: " +
3625 (ReachedUnknownParallelRegions.isValidState()
3626 ? std::to_string(ReachedUnknownParallelRegions.size())
3627 : "<invalid>") +
3628 ", #Reaching Kernels: " +
3629 (ReachingKernelEntries.isValidState()
3630 ? std::to_string(ReachingKernelEntries.size())
3631 : "<invalid>") +
3632 ", #ParLevels: " +
3633 (ParallelLevels.isValidState()
3634 ? std::to_string(ParallelLevels.size())
3635 : "<invalid>") +
3636 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3637 }
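  // For illustration, getAsStr() above might produce, for a kernel fixed in
  // SPMD mode with two known parallel regions and one reaching kernel:
  //   "SPMD [FIX] #PRs: 2, #Unknown PRs: 0, #Reaching Kernels: 1, #ParLevels: 1, NestedPar: no"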
3638
3639 /// Create an abstract attribute view for the position \p IRP.
3640 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3641
3642 /// See AbstractAttribute::getName()
3643 const std::string getName() const override { return "AAKernelInfo"; }
3644
3645 /// See AbstractAttribute::getIdAddr()
3646 const char *getIdAddr() const override { return &ID; }
3647
3648 /// This function should return true if the type of the \p AA is AAKernelInfo
3649 static bool classof(const AbstractAttribute *AA) {
3650 return (AA->getIdAddr() == &ID);
3651 }
3652
3653 static const char ID;
3654};
3655
3656/// The function kernel info abstract attribute, basically, what can we say
3657/// about a function with regards to the KernelInfoState.
3658struct AAKernelInfoFunction : AAKernelInfo {
3659 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3660 : AAKernelInfo(IRP, A) {}
3661
3662 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3663
3664 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3665 return GuardedInstructions;
3666 }
3667
3668 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3669 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3670 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3671 assert(NewKernelEnvC && "Failed to create new kernel environment");
3672 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3673 }
3674
3675#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3676 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3677 ConstantStruct *ConfigC = \
3678 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3679 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3680 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3681 assert(NewConfigC && "Failed to create new configuration environment"); \
3682 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3683 }
3684
3685 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3686 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3692
3693#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
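// As a sketch of what the macro above generates (asserts elided): for the
// ExecMode member it defines a setter equivalent to
//   void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
//     ConstantStruct *ConfigC =
//         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
//     Constant *NewConfigC = ConstantFoldInsertValueInstruction(
//         ConfigC, NewVal, {KernelInfo::ExecModeIdx});
//     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
//   }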
3694
3695 /// See AbstractAttribute::initialize(...).
3696 void initialize(Attributor &A) override {
3697 // This is a high-level transform that might change the constant arguments
3698 // of the init and deinit calls. We need to tell the Attributor about this
3699 // to avoid other parts using the current constant value for simplification.
3700 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3701
3702 Function *Fn = getAnchorScope();
3703
3704 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3705 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3706 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3707 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3708
3709 // For kernels we perform more initialization work; first we find the init
3710 // and deinit calls.
3711 auto StoreCallBase = [](Use &U,
3712 OMPInformationCache::RuntimeFunctionInfo &RFI,
3713 CallBase *&Storage) {
3714 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3715 assert(CB &&
3716 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3717 assert(!Storage &&
3718 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3719 Storage = CB;
3720 return false;
3721 };
3722 InitRFI.foreachUse(
3723 [&](Use &U, Function &) {
3724 StoreCallBase(U, InitRFI, KernelInitCB);
3725 return false;
3726 },
3727 Fn);
3728 DeinitRFI.foreachUse(
3729 [&](Use &U, Function &) {
3730 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3731 return false;
3732 },
3733 Fn);
3734
3735 // Ignore kernels without initializers such as global constructors.
3736 if (!KernelInitCB || !KernelDeinitCB)
3737 return;
3738
3739 // Add itself to the reaching kernel and set IsKernelEntry.
3740 ReachingKernelEntries.insert(Fn);
3741 IsKernelEntry = true;
3742
3743 KernelEnvC =
3744 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3745 GlobalVariable *KernelEnvGV =
3746 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3747
3748 Attributor::GlobalVariableSimplifictionCallbackTy
3749 KernelConfigurationSimplifyCB =
3750 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3751 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3752 if (!isAtFixpoint()) {
3753 if (!AA)
3754 return nullptr;
3755 UsedAssumedInformation = true;
3756 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3757 }
3758 return KernelEnvC;
3759 };
3760
3761 A.registerGlobalVariableSimplificationCallback(
3762 *KernelEnvGV, KernelConfigurationSimplifyCB);
3763
3764 // Check if we know we are in SPMD-mode already.
3765 ConstantInt *ExecModeC =
3766 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3767 ConstantInt *AssumedExecModeC = ConstantInt::get(
3768 ExecModeC->getIntegerType(),
3769 ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3770 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3771 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3772 else if (DisableOpenMPOptSPMDization)
3773 // This is a generic region but SPMDization is disabled so stop
3774 // tracking.
3775 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3776 else
3777 setExecModeOfKernelEnvironment(AssumedExecModeC);
3778
3779 const Triple T(Fn->getParent()->getTargetTriple());
3780 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3781 auto [MinThreads, MaxThreads] =
3782 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3783 if (MinThreads)
3784 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3785 if (MaxThreads)
3786 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3787 auto [MinTeams, MaxTeams] =
3788 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3789 if (MinTeams)
3790 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3791 if (MaxTeams)
3792 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3793
3794 ConstantInt *MayUseNestedParallelismC =
3795 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3796 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3797 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3798 setMayUseNestedParallelismOfKernelEnvironment(
3799 AssumedMayUseNestedParallelismC);
3800
3801 if (!DisableOpenMPOptStateMachineRewrite) {
3802 ConstantInt *UseGenericStateMachineC =
3803 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3804 KernelEnvC);
3805 ConstantInt *AssumedUseGenericStateMachineC =
3806 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3807 setUseGenericStateMachineOfKernelEnvironment(
3808 AssumedUseGenericStateMachineC);
3809 }
3810
3811 // Register virtual uses of functions we might need to preserve.
3812 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3813 Attributor::VirtualUseCallbackTy &CB) {
3814 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3815 return;
3816 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3817 };
3818
3819 // Add a dependence to ensure updates if the state changes.
3820 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3821 const AbstractAttribute *QueryingAA) {
3822 if (QueryingAA) {
3823 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3824 }
3825 return true;
3826 };
3827
3828 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3829 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3830 // Whenever we create a custom state machine we will insert calls to
3831 // __kmpc_get_hardware_num_threads_in_block,
3832 // __kmpc_get_warp_size,
3833 // __kmpc_barrier_simple_generic,
3834 // __kmpc_kernel_parallel, and
3835 // __kmpc_kernel_end_parallel.
3836 // Not needed if we are on track for SPMDzation.
3837 if (SPMDCompatibilityTracker.isValidState())
3838 return AddDependence(A, this, QueryingAA);
3839 // Not needed if we can't rewrite due to an invalid state.
3840 if (!ReachedKnownParallelRegions.isValidState())
3841 return AddDependence(A, this, QueryingAA);
3842 return false;
3843 };
3844
3845 // Not needed if we are pre-runtime merge.
3846 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3847 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3848 CustomStateMachineUseCB);
3849 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3850 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3851 CustomStateMachineUseCB);
3852 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3853 CustomStateMachineUseCB);
3854 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3855 CustomStateMachineUseCB);
3856 }
3857
3858 // If we do not perform SPMDzation we do not need the virtual uses below.
3859 if (SPMDCompatibilityTracker.isAtFixpoint())
3860 return;
3861
3862 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3863 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3864 // Whenever we perform SPMDzation we will insert
3865 // __kmpc_get_hardware_thread_id_in_block calls.
3866 if (!SPMDCompatibilityTracker.isValidState())
3867 return AddDependence(A, this, QueryingAA);
3868 return false;
3869 };
3870 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3871 HWThreadIdUseCB);
3872
3873 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3874 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3875 // Whenever we perform SPMDzation with guarding we will insert
3876 // __kmpc_barrier_simple_spmd calls. If SPMDzation failed, if there is
3877 // nothing to guard, or if there are no parallel regions, we don't need
3878 // the calls.
3879 if (!SPMDCompatibilityTracker.isValidState())
3880 return AddDependence(A, this, QueryingAA);
3881 if (SPMDCompatibilityTracker.empty())
3882 return AddDependence(A, this, QueryingAA);
3883 if (!mayContainParallelRegion())
3884 return AddDependence(A, this, QueryingAA);
3885 return false;
3886 };
3887 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3888 }
3889
3890 /// Sanitize the string \p S such that it is a suitable global symbol name.
3891 static std::string sanitizeForGlobalName(std::string S) {
3892 std::replace_if(
3893 S.begin(), S.end(),
3894 [](const char C) {
3895 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3896 (C >= '0' && C <= '9') || C == '_');
3897 },
3898 '.');
3899 return S;
3900 }
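  // For example, sanitizeForGlobalName("kernel-name$1") yields
  // "kernel.name.1"; every character outside [a-zA-Z0-9_] is replaced by '.'.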
3901
3902 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3903 /// finished now.
3904 ChangeStatus manifest(Attributor &A) override {
3905 // If we are not looking at a kernel with __kmpc_target_init and
3906 // __kmpc_target_deinit call we cannot actually manifest the information.
3907 if (!KernelInitCB || !KernelDeinitCB)
3908 return ChangeStatus::UNCHANGED;
3909
3910 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3911
3912 bool HasBuiltStateMachine = true;
3913 if (!changeToSPMDMode(A, Changed)) {
3914 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3915 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3916 else
3917 HasBuiltStateMachine = false;
3918 }
3919
3920 // We need to reset KernelEnvC if specific rewriting is not done.
3921 ConstantStruct *ExistingKernelEnvC =
3922 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3923 ConstantInt *OldUseGenericStateMachineVal =
3924 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3925 ExistingKernelEnvC);
3926 if (!HasBuiltStateMachine)
3927 setUseGenericStateMachineOfKernelEnvironment(
3928 OldUseGenericStateMachineVal);
3929
3930 // At last, update the KernelEnvC.
3931 GlobalVariable *KernelEnvGV =
3932 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3933 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3934 KernelEnvGV->setInitializer(KernelEnvC);
3935 Changed = ChangeStatus::CHANGED;
3936 }
3937
3938 return Changed;
3939 }
3940
3941 void insertInstructionGuardsHelper(Attributor &A) {
3942 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3943
3944 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3945 Instruction *RegionEndI) {
3946 LoopInfo *LI = nullptr;
3947 DominatorTree *DT = nullptr;
3948 MemorySSAUpdater *MSU = nullptr;
3949 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3950
3951 BasicBlock *ParentBB = RegionStartI->getParent();
3952 Function *Fn = ParentBB->getParent();
3953 Module &M = *Fn->getParent();
3954
3955 // Create all the blocks and logic.
3956 // ParentBB:
3957 // goto RegionCheckTidBB
3958 // RegionCheckTidBB:
3959 // Tid = __kmpc_hardware_thread_id()
3960 // if (Tid != 0)
3961 // goto RegionBarrierBB
3962 // RegionStartBB:
3963 // <execute instructions guarded>
3964 // goto RegionEndBB
3965 // RegionEndBB:
3966 // <store escaping values to shared mem>
3967 // goto RegionBarrierBB
3968 // RegionBarrierBB:
3969 // __kmpc_barrier_simple_spmd()
3970 // // second barrier is omitted if lacking escaping values.
3971 // <load escaping values from shared mem>
3972 // __kmpc_barrier_simple_spmd()
3973 // goto RegionExitBB
3974 // RegionExitBB:
3975 // <execute rest of instructions>
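      //
      // For illustration (hypothetical guarded code), a value computed in
      // RegionStartBB but used after the region, e.g.
      //   %v = add i32 %a, %b   ; used again in RegionExitBB
      // is stored to a shared-memory slot in RegionEndBB by the main thread
      // and reloaded in RegionBarrierBB between the two barriers, so every
      // thread observes the value computed by thread 0.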
3976
3977 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3978 DT, LI, MSU, "region.guarded.end");
3979 BasicBlock *RegionBarrierBB =
3980 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3981 MSU, "region.barrier");
3982 BasicBlock *RegionExitBB =
3983 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3984 DT, LI, MSU, "region.exit");
3985 BasicBlock *RegionStartBB =
3986 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3987
3988 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3989 "Expected a different CFG");
3990
3991 BasicBlock *RegionCheckTidBB = SplitBlock(
3992 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3993
3994 // Register basic blocks with the Attributor.
3995 A.registerManifestAddedBasicBlock(*RegionEndBB);
3996 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3997 A.registerManifestAddedBasicBlock(*RegionExitBB);
3998 A.registerManifestAddedBasicBlock(*RegionStartBB);
3999 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4000
4001 bool HasBroadcastValues = false;
4002 // Find escaping outputs from the guarded region to outside users and
4003 // broadcast their values to them.
4004 for (Instruction &I : *RegionStartBB) {
4005 SmallVector<Use *, 4> OutsideUses;
4006 for (Use &U : I.uses()) {
4007 Instruction &UsrI = *cast<Instruction>(U.getUser());
4008 if (UsrI.getParent() != RegionStartBB)
4009 OutsideUses.push_back(&U);
4010 }
4011
4012 if (OutsideUses.empty())
4013 continue;
4014
4015 HasBroadcastValues = true;
4016
4017 // Emit a global variable in shared memory to store the broadcasted
4018 // value.
4019 auto *SharedMem = new GlobalVariable(
4020 M, I.getType(), /* IsConstant */ false,
4021 GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4022 sanitizeForGlobalName(
4023 (I.getName() + ".guarded.output.alloc").str()),
4024 nullptr, GlobalValue::NotThreadLocal,
4025 static_cast<unsigned>(AddressSpace::Shared));
4026
4027 // Emit a store instruction to update the value.
4028 new StoreInst(&I, SharedMem,
4029 RegionEndBB->getTerminator()->getIterator());
4030
4031 LoadInst *LoadI = new LoadInst(
4032 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4033 RegionBarrierBB->getTerminator()->getIterator());
4034
4035 // Emit a load instruction and replace uses of the output value.
4036 for (Use *U : OutsideUses)
4037 A.changeUseAfterManifest(*U, *LoadI);
4038 }
4039
4040 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4041
4042 // Go to tid check BB in ParentBB.
4043 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4044 ParentBB->getTerminator()->eraseFromParent();
4045 OpenMPIRBuilder::LocationDescription Loc(
4046 InsertPointTy(ParentBB, ParentBB->end()), DL);
4047 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4048 uint32_t SrcLocStrSize;
4049 auto *SrcLocStr =
4050 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4051 Value *Ident =
4052 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4053 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4054
4055 // Add check for Tid in RegionCheckTidBB
4056 RegionCheckTidBB->getTerminator()->eraseFromParent();
4057 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4058 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4059 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4060 FunctionCallee HardwareTidFn =
4061 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4062 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4063 CallInst *Tid =
4064 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4065 Tid->setDebugLoc(DL);
4066 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4067 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4068 OMPInfoCache.OMPBuilder.Builder
4069 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4070 ->setDebugLoc(DL);
4071
4072 // First barrier for synchronization, ensures main thread has updated
4073 // values.
4074 FunctionCallee BarrierFn =
4075 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4076 M, OMPRTL___kmpc_barrier_simple_spmd);
4077 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4078 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4079 CallInst *Barrier =
4080 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4081 Barrier->setDebugLoc(DL);
4082 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4083
4084 // Second barrier ensures workers have read broadcast values.
4085 if (HasBroadcastValues) {
4086 CallInst *Barrier =
4087 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4088 RegionBarrierBB->getTerminator()->getIterator());
4089 Barrier->setDebugLoc(DL);
4090 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4091 }
4092 };
4093
4094 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4095 SmallPtrSet<BasicBlock *, 8> Visited;
4096 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4097 BasicBlock *BB = GuardedI->getParent();
4098 if (!Visited.insert(BB).second)
4099 continue;
4100
4101 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4102 Instruction *LastEffect = nullptr;
4103 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4104 while (++IP != IPEnd) {
4105 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4106 continue;
4107 Instruction *I = &*IP;
4108 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4109 continue;
4110 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4111 LastEffect = nullptr;
4112 continue;
4113 }
4114 if (LastEffect)
4115 Reorders.push_back({I, LastEffect});
4116 LastEffect = &*IP;
4117 }
4118 for (auto &Reorder : Reorders)
4119 Reorder.first->moveBefore(Reorder.second);
4120 }
4121
4122 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4123
4124 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4125 BasicBlock *BB = GuardedI->getParent();
4126 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4127 IRPosition::function(*GuardedI->getFunction()), nullptr,
4128 DepClassTy::NONE);
4129 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4130 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4131 // Continue if instruction is already guarded.
4132 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4133 continue;
4134
4135 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4136 for (Instruction &I : *BB) {
4137 // If instruction I needs to be guarded update the guarded region
4138 // bounds.
4139 if (SPMDCompatibilityTracker.contains(&I)) {
4140 CalleeAAFunction.getGuardedInstructions().insert(&I);
4141 if (GuardedRegionStart)
4142 GuardedRegionEnd = &I;
4143 else
4144 GuardedRegionStart = GuardedRegionEnd = &I;
4145
4146 continue;
4147 }
4148
4149 // Instruction I does not need guarding; store
4150 // any region found and reset bounds.
4151 if (GuardedRegionStart) {
4152 GuardedRegions.push_back(
4153 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4154 GuardedRegionStart = nullptr;
4155 GuardedRegionEnd = nullptr;
4156 }
4157 }
4158 }
4159
4160 for (auto &GR : GuardedRegions)
4161 CreateGuardedRegion(GR.first, GR.second);
4162 }
4163
4164 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4165 // Only allow 1 thread per workgroup to continue executing the user code.
4166 //
4167 // InitCB = __kmpc_target_init(...)
4168 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4169 // if (ThreadIdInBlock != 0) return;
4170 // UserCode:
4171 // // user code
4172 //
4173 auto &Ctx = getAnchorValue().getContext();
4174 Function *Kernel = getAssociatedFunction();
4175 assert(Kernel && "Expected an associated function!");
4176
4177 // Create block for user code to branch to from initial block.
4178 BasicBlock *InitBB = KernelInitCB->getParent();
4179 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4180 KernelInitCB->getNextNode(), "main.thread.user_code");
4181 BasicBlock *ReturnBB =
4182 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4183
4184 // Register blocks with attributor:
4185 A.registerManifestAddedBasicBlock(*InitBB);
4186 A.registerManifestAddedBasicBlock(*UserCodeBB);
4187 A.registerManifestAddedBasicBlock(*ReturnBB);
4188
4189 // Debug location:
4190 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4191 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4192 InitBB->getTerminator()->eraseFromParent();
4193
4194 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4195 Module &M = *Kernel->getParent();
4196 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4197 FunctionCallee ThreadIdInBlockFn =
4198 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4199 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4200
4201 // Get thread ID in block.
4202 CallInst *ThreadIdInBlock =
4203 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4204 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4205 ThreadIdInBlock->setDebugLoc(DLoc);
4206
4207 // Eliminate all threads in the block with ID not equal to 0:
4208 Instruction *IsMainThread =
4209 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4210 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4211 "thread.is_main", InitBB);
4212 IsMainThread->setDebugLoc(DLoc);
4213 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4214 }
4215
4216 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4217 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4218
4219 // We cannot change to SPMD mode if the runtime functions aren't available.
4220 if (!OMPInfoCache.runtimeFnsAvailable(
4221 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4222 OMPRTL___kmpc_barrier_simple_spmd}))
4223 return false;
4224
4225 if (!SPMDCompatibilityTracker.isAssumed()) {
4226 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4227 if (!NonCompatibleI)
4228 continue;
4229
4230 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4231 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4232 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4233 continue;
4234
4235 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4236 ORA << "Value has potential side effects preventing SPMD-mode "
4237 "execution";
4238 if (isa<CallBase>(NonCompatibleI)) {
4239 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4240 "the called function to override";
4241 }
4242 return ORA << ".";
4243 };
4244 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4245 Remark);
4246
4247 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4248 << *NonCompatibleI << "\n");
4249 }
4250
4251 return false;
4252 }
4253
4254 // Get the actual kernel; it could be the caller of the anchor scope if we
4255 // have a debug wrapper.
4256 Function *Kernel = getAnchorScope();
4257 if (Kernel->hasLocalLinkage()) {
4258 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4259 auto *CB = cast<CallBase>(Kernel->user_back());
4260 Kernel = CB->getCaller();
4261 }
4262 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4263
4264 // Check if the kernel is already in SPMD mode, if so, return success.
4265 ConstantStruct *ExistingKernelEnvC =
4266 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4267 auto *ExecModeC =
4268 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4269 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4270 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4271 return true;
4272
4273 // We will now unconditionally modify the IR, indicate a change.
4274 Changed = ChangeStatus::CHANGED;
4275
4276 // Do not use instruction guards when no parallel region is present inside
4277 // the target region.
4278 if (mayContainParallelRegion())
4279 insertInstructionGuardsHelper(A);
4280 else
4281 forceSingleThreadPerWorkgroupHelper(A);
4282
4283 // Adjust the global exec mode flag that tells the runtime what mode this
4284 // kernel is executed in.
4285 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4286 "Initially non-SPMD kernel has SPMD exec mode!");
4287 setExecModeOfKernelEnvironment(
4288 ConstantInt::get(ExecModeC->getIntegerType(),
4289 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
4290
4291 ++NumOpenMPTargetRegionKernelsSPMD;
4292
4293 auto Remark = [&](OptimizationRemark OR) {
4294 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4295 };
4296 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4297 return true;
4298 };
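  // A sketch of the source-level annotation referenced in the OMP121 remark
  // above (hypothetical user code):
  //   [[omp::assume("ompx_spmd_amenable")]] void leaf(void);
  // Calls into such a function are treated as SPMD-compatible and no longer
  // block the generic-to-SPMD rewrite.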
4299
4300 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4301 // If we have disabled state machine rewrites, don't make a custom one.
4302 if (DisableOpenMPOptStateMachineRewrite)
4303 return false;
4304
4305 // Don't rewrite the state machine if we are not in a valid state.
4306 if (!ReachedKnownParallelRegions.isValidState())
4307 return false;
4308
4309 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4310 if (!OMPInfoCache.runtimeFnsAvailable(
4311 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4312 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4313 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4314 return false;
4315
4316 ConstantStruct *ExistingKernelEnvC =
4317 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4318
4319 // Check if the current configuration is non-SPMD mode with the generic
4320 // state machine. If we already have SPMD mode or a custom state machine we
4321 // do not need to go any further. If it is anything but a constant,
4322 // something is weird and we give up.
4323 ConstantInt *UseStateMachineC =
4324 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4325 ExistingKernelEnvC);
4326 ConstantInt *ModeC =
4327 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4328
4329 // If we are stuck with generic mode, try to create a custom device (=GPU)
4330 // state machine which is specialized for the parallel regions that are
4331 // reachable by the kernel.
4332 if (UseStateMachineC->isZero() ||
4333 (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4334 return false;
4335
4336 Changed = ChangeStatus::CHANGED;
4337
4338 // If not SPMD mode, indicate we use a custom state machine now.
4339 setUseGenericStateMachineOfKernelEnvironment(
4340 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4341
4342 // If we don't actually need a state machine we are done here. This can
4343 // happen if there simply are no parallel regions. In the resulting kernel
4344 // all worker threads will simply exit right away, leaving the main thread
4345 // to do the work alone.
4346 if (!mayContainParallelRegion()) {
4347 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4348
4349 auto Remark = [&](OptimizationRemark OR) {
4350 return OR << "Removing unused state machine from generic-mode kernel.";
4351 };
4352 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4353
4354 return true;
4355 }
4356
4357 // Keep track in the statistics of our new shiny custom state machine.
4358 if (ReachedUnknownParallelRegions.empty()) {
4359 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4360
4361 auto Remark = [&](OptimizationRemark OR) {
4362 return OR << "Rewriting generic-mode kernel with a customized state "
4363 "machine.";
4364 };
4365 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4366 } else {
4367 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4368
4369 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4370 return OR << "Generic-mode kernel is executed with a customized state "
4371 "machine that requires a fallback.";
4372 };
4373 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4374
4375 // Tell the user why we ended up with a fallback.
4376 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4377 if (!UnknownParallelRegionCB)
4378 continue;
4379 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4380 return ORA << "Call may contain unknown parallel regions. Use "
4381 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4382 "override.";
4383 };
4384 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4385 "OMP133", Remark);
4386 }
4387 }
4388
4389 // Create all the blocks:
4390 //
4391 // InitCB = __kmpc_target_init(...)
4392 // BlockHwSize =
4393 // __kmpc_get_hardware_num_threads_in_block();
4394 // WarpSize = __kmpc_get_warp_size();
4395 // BlockSize = BlockHwSize - WarpSize;
4396 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4397 // if (IsWorker) {
4398 // if (InitCB >= BlockSize) return;
4399 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4400 // void *WorkFn;
4401 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4402 // if (!WorkFn) return;
4403 // SMIsActiveCheckBB: if (Active) {
4404 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4405 // ParFn0(...);
4406 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4407 // ParFn1(...);
4408 // ...
4409 // SMIfCascadeCurrentBB: else
4410 // ((WorkFnTy*)WorkFn)(...);
4411 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4412 // }
4413 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4414 // goto SMBeginBB;
4415 // }
4416 // UserCodeEntryBB: // user code
4417 // __kmpc_target_deinit(...)
4418 //
4419 auto &Ctx = getAnchorValue().getContext();
4420 Function *Kernel = getAssociatedFunction();
4421 assert(Kernel && "Expected an associated function!");
4422
4423 BasicBlock *InitBB = KernelInitCB->getParent();
4424 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4425 KernelInitCB->getNextNode(), "thread.user_code.check");
4426 BasicBlock *IsWorkerCheckBB =
4427 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4428 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4429 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIfCascadeCurrentBB =
4435 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4436 Kernel, UserCodeEntryBB);
4437 BasicBlock *StateMachineEndParallelBB =
4438 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4439 Kernel, UserCodeEntryBB);
4440 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4441 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4442 A.registerManifestAddedBasicBlock(*InitBB);
4443 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4445 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4446 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4451
4452 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4453 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4454 InitBB->getTerminator()->eraseFromParent();
4455
4456 Instruction *IsWorker =
4457 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4458 ConstantInt::get(KernelInitCB->getType(), -1),
4459 "thread.is_worker", InitBB);
4460 IsWorker->setDebugLoc(DLoc);
4461 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4462
4463 Module &M = *Kernel->getParent();
4464 FunctionCallee BlockHwSizeFn =
4465 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4466 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4467 FunctionCallee WarpSizeFn =
4468 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4469 M, OMPRTL___kmpc_get_warp_size);
4470 CallInst *BlockHwSize =
4471 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4472 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4473 BlockHwSize->setDebugLoc(DLoc);
4474 CallInst *WarpSize =
4475 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4476 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4477 WarpSize->setDebugLoc(DLoc);
4478 Instruction *BlockSize = BinaryOperator::CreateSub(
4479 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4480 BlockSize->setDebugLoc(DLoc);
4481 Instruction *IsMainOrWorker = ICmpInst::Create(
4482 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4483 "thread.is_main_or_worker", IsWorkerCheckBB);
4484 IsMainOrWorker->setDebugLoc(DLoc);
4485 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4486 IsMainOrWorker, IsWorkerCheckBB);
4487
4488 // Create local storage for the work function pointer.
4489 const DataLayout &DL = M.getDataLayout();
4490 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4491 Instruction *WorkFnAI =
4492 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4493 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4494 WorkFnAI->setDebugLoc(DLoc);
4495
4496 OMPInfoCache.OMPBuilder.updateToLocation(
4497 OpenMPIRBuilder::LocationDescription(
4498 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4499 StateMachineBeginBB->end()),
4500 DLoc));
4501
4502 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4503 Value *GTid = KernelInitCB;
4504
4505 FunctionCallee BarrierFn =
4506 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4507 M, OMPRTL___kmpc_barrier_simple_generic);
4508 CallInst *Barrier =
4509 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4510 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4511 Barrier->setDebugLoc(DLoc);
4512
4513 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4514 (unsigned int)AddressSpace::Generic) {
4515 WorkFnAI = new AddrSpaceCastInst(
4516 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4517 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4518 WorkFnAI->setDebugLoc(DLoc);
4519 }
4520
4521 FunctionCallee KernelParallelFn =
4522 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4523 M, OMPRTL___kmpc_kernel_parallel);
4524 CallInst *IsActiveWorker = CallInst::Create(
4525 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4526 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4527 IsActiveWorker->setDebugLoc(DLoc);
4528 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4529 StateMachineBeginBB);
4530 WorkFn->setDebugLoc(DLoc);
4531
4532 FunctionType *ParallelRegionFnTy = FunctionType::get(
4533 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4534 false);
4535
4536 Instruction *IsDone =
4537 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4538 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4539 StateMachineBeginBB);
4540 IsDone->setDebugLoc(DLoc);
4541 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4542 IsDone, StateMachineBeginBB)
4543 ->setDebugLoc(DLoc);
4544
4545 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4546 StateMachineDoneBarrierBB, IsActiveWorker,
4547 StateMachineIsActiveCheckBB)
4548 ->setDebugLoc(DLoc);
4549
4550 Value *ZeroArg =
4551 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4552
4553 const unsigned int WrapperFunctionArgNo = 6;
4554
4555 // Now that we have most of the CFG skeleton it is time for the if-cascade
4556 // that checks the function pointer we got from the runtime against the
4557 // parallel regions we expect, if there are any.
4558 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4559 auto *CB = ReachedKnownParallelRegions[I];
4560 auto *ParallelRegion = dyn_cast<Function>(
4561 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4562 BasicBlock *PRExecuteBB = BasicBlock::Create(
4563 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4564 StateMachineEndParallelBB);
4565 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4566 ->setDebugLoc(DLoc);
4567 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569
4570 BasicBlock *PRNextBB =
4571 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4572 Kernel, StateMachineEndParallelBB);
4573 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4574 A.registerManifestAddedBasicBlock(*PRNextBB);
4575
4576 // Check if we need to compare the pointer at all or if we can just
4577 // call the parallel region function.
4578 Value *IsPR;
4579 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4580 Instruction *CmpI = ICmpInst::Create(
4581 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4582 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4583 CmpI->setDebugLoc(DLoc);
4584 IsPR = CmpI;
4585 } else {
4586 IsPR = ConstantInt::getTrue(Ctx);
4587 }
4588
4589 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4590 StateMachineIfCascadeCurrentBB)
4591 ->setDebugLoc(DLoc);
4592 StateMachineIfCascadeCurrentBB = PRNextBB;
4593 }
4594
4595 // At the end of the if-cascade we place the indirect function pointer call
4596 // in case we might need it, that is if there can be parallel regions we
4597 // have not handled in the if-cascade above.
4598 if (!ReachedUnknownParallelRegions.empty()) {
4599 StateMachineIfCascadeCurrentBB->setName(
4600 "worker_state_machine.parallel_region.fallback.execute");
4601 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4602 StateMachineIfCascadeCurrentBB)
4603 ->setDebugLoc(DLoc);
4604 }
4605 BranchInst::Create(StateMachineEndParallelBB,
4606 StateMachineIfCascadeCurrentBB)
4607 ->setDebugLoc(DLoc);
4608
4609 FunctionCallee EndParallelFn =
4610 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4611 M, OMPRTL___kmpc_kernel_end_parallel);
4612 CallInst *EndParallel =
4613 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4614 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4615 EndParallel->setDebugLoc(DLoc);
4616 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4617 ->setDebugLoc(DLoc);
4618
4619 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4620 ->setDebugLoc(DLoc);
4621 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623
4624 return true;
4625 }
4626
4627 /// Fixpoint iteration update function. Will be called every time a dependence
4628 /// changed its state (and in the beginning).
4629 ChangeStatus updateImpl(Attributor &A) override {
4630 KernelInfoState StateBefore = getState();
4631
4632 // When we leave this function this RAII will make sure the member
4633 // KernelEnvC is updated properly depending on the state. That member is
4634 // used for simplification of values and needs to be up to date at all
4635 // times.
4636 struct UpdateKernelEnvCRAII {
4637 AAKernelInfoFunction &AA;
4638
4639 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4640
4641 ~UpdateKernelEnvCRAII() {
4642 if (!AA.KernelEnvC)
4643 return;
4644
4645 ConstantStruct *ExistingKernelEnvC =
4646 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4647
4648 if (!AA.isValidState()) {
4649 AA.KernelEnvC = ExistingKernelEnvC;
4650 return;
4651 }
4652
4653 if (!AA.ReachedKnownParallelRegions.isValidState())
4654 AA.setUseGenericStateMachineOfKernelEnvironment(
4655 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4656 ExistingKernelEnvC));
4657
4658 if (!AA.SPMDCompatibilityTracker.isValidState())
4659 AA.setExecModeOfKernelEnvironment(
4660 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4661
4662 ConstantInt *MayUseNestedParallelismC =
4663 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4664 AA.KernelEnvC);
4665 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4666 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4667 AA.setMayUseNestedParallelismOfKernelEnvironment(
4668 NewMayUseNestedParallelismC);
4669 }
4670 } RAII(*this);
4671
4672 // Callback to check a read/write instruction.
4673 auto CheckRWInst = [&](Instruction &I) {
4674 // We handle calls later.
4675 if (isa<CallBase>(I))
4676 return true;
4677 // We only care about write effects.
4678 if (!I.mayWriteToMemory())
4679 return true;
4680 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4681 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4682 *this, IRPosition::value(*SI->getPointerOperand()),
4683 DepClassTy::OPTIONAL);
4684 auto *HS = A.getAAFor<AAHeapToStack>(
4685 *this, IRPosition::function(*I.getFunction()),
4686 DepClassTy::OPTIONAL);
4687 if (UnderlyingObjsAA &&
4688 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4689 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4690 return true;
4691 // Check for AAHeapToStack moved objects which must not be
4692 // guarded.
4693 auto *CB = dyn_cast<CallBase>(&Obj);
4694 return CB && HS && HS->isAssumedHeapToStack(*CB);
4695 }))
4696 return true;
4697 }
4698
4699 // Insert instruction that needs guarding.
4700 SPMDCompatibilityTracker.insert(&I);
4701 return true;
4702 };
4703
4704 bool UsedAssumedInformationInCheckRWInst = false;
4705 if (!SPMDCompatibilityTracker.isAtFixpoint())
4706 if (!A.checkForAllReadWriteInstructions(
4707 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4708 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4709
4710 bool UsedAssumedInformationFromReachingKernels = false;
4711 if (!IsKernelEntry) {
4712 updateParallelLevels(A);
4713
4714 bool AllReachingKernelsKnown = true;
4715 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4716 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4717
4718 if (!SPMDCompatibilityTracker.empty()) {
4719 if (!ParallelLevels.isValidState())
4720 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4721 else if (!ReachingKernelEntries.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else {
4724 // Check if all reaching kernels agree on the mode as we can otherwise
4725 // not guard instructions. We might not be sure about the mode so we
4726 // cannot fix the internal spmd-zation state either.
4727 int SPMD = 0, Generic = 0;
4728 for (auto *Kernel : ReachingKernelEntries) {
4729 auto *CBAA = A.getAAFor<AAKernelInfo>(
4730 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4731 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4732 CBAA->SPMDCompatibilityTracker.isAssumed())
4733 ++SPMD;
4734 else
4735 ++Generic;
4736 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4737 UsedAssumedInformationFromReachingKernels = true;
4738 }
4739 if (SPMD != 0 && Generic != 0)
4740 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4741 }
4742 }
4743 }
4744
4745 // Callback to check a call instruction.
4746 bool AllParallelRegionStatesWereFixed = true;
4747 bool AllSPMDStatesWereFixed = true;
4748 auto CheckCallInst = [&](Instruction &I) {
4749 auto &CB = cast<CallBase>(I);
4750 auto *CBAA = A.getAAFor<AAKernelInfo>(
4751 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4752 if (!CBAA)
4753 return false;
4754 getState() ^= CBAA->getState();
4755 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4756 AllParallelRegionStatesWereFixed &=
4757 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4760 return true;
4761 };
4762
4763 bool UsedAssumedInformationInCheckCallInst = false;
4764 if (!A.checkForAllCallLikeInstructions(
4765 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4766 LLVM_DEBUG(dbgs() << TAG
4767 << "Failed to visit all call-like instructions!\n";);
4768 return indicatePessimisticFixpoint();
4769 }
4770
4771 // If we haven't used any assumed information for the reached parallel
4772 // region states we can fix it.
4773 if (!UsedAssumedInformationInCheckCallInst &&
4774 AllParallelRegionStatesWereFixed) {
4775 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4776 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4777 }
4778
4779 // If we haven't used any assumed information for the SPMD state we can fix
4780 // it.
4781 if (!UsedAssumedInformationInCheckRWInst &&
4782 !UsedAssumedInformationInCheckCallInst &&
4783 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4784 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4785
4786 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4787 : ChangeStatus::CHANGED;
4788 }
4789
4790private:
4791 /// Update info regarding reaching kernels.
4792 void updateReachingKernelEntries(Attributor &A,
4793 bool &AllReachingKernelsKnown) {
4794 auto PredCallSite = [&](AbstractCallSite ACS) {
4795 Function *Caller = ACS.getInstruction()->getFunction();
4796
4797 assert(Caller && "Caller is nullptr");
4798
4799 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4800 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4801 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4802 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4803 return true;
4804 }
4805
4806 // We lost track of the caller of the associated function; any kernel
4807 // could reach it now.
4808 ReachingKernelEntries.indicatePessimisticFixpoint();
4809
4810 return true;
4811 };
4812
4813 if (!A.checkForAllCallSites(PredCallSite, *this,
4814 true /* RequireAllCallSites */,
4815 AllReachingKernelsKnown))
4816 ReachingKernelEntries.indicatePessimisticFixpoint();
4817 }
4818
4819 /// Update info regarding parallel levels.
4820 void updateParallelLevels(Attributor &A) {
4821 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4822 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4823 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4824
4825 auto PredCallSite = [&](AbstractCallSite ACS) {
4826 Function *Caller = ACS.getInstruction()->getFunction();
4827
4828 assert(Caller && "Caller is nullptr");
4829
4830 auto *CAA =
4831 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4832 if (CAA && CAA->ParallelLevels.isValidState()) {
4833 // Any function that is called by `__kmpc_parallel_51` will not be
4834 // folded because the parallel level in that function is updated. To get
4835 // this right, the analysis would have to depend on the runtime
4836 // implementation, and any future change to that implementation could
4837 // silently invalidate it. As a consequence, we are simply conservative here.
4838 if (Caller == Parallel51RFI.Declaration) {
4839 ParallelLevels.indicatePessimisticFixpoint();
4840 return true;
4841 }
4842
4843 ParallelLevels ^= CAA->ParallelLevels;
4844
4845 return true;
4846 }
4847
4848 // We lost track of the caller of the associated function, so any kernel
4849 // could reach it now.
4850 ParallelLevels.indicatePessimisticFixpoint();
4851
4852 return true;
4853 };
4854
4855 bool AllCallSitesKnown = true;
4856 if (!A.checkForAllCallSites(PredCallSite, *this,
4857 true /* RequireAllCallSites */,
4858 AllCallSitesKnown))
4859 ParallelLevels.indicatePessimisticFixpoint();
4860 }
4861};
4862
4863 /// The call site kernel info abstract attribute; basically, what can we say
4864 /// about a call site with regard to the KernelInfoState. For now this simply
4865 /// forwards the information from the callee.
4866struct AAKernelInfoCallSite : AAKernelInfo {
4867 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4868 : AAKernelInfo(IRP, A) {}
4869
4870 /// See AbstractAttribute::initialize(...).
4871 void initialize(Attributor &A) override {
4872 AAKernelInfo::initialize(A);
4873
4874 CallBase &CB = cast<CallBase>(getAssociatedValue());
4875 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4876 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4877
4878 // Check for SPMD-mode assumptions.
4879 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4880 indicateOptimisticFixpoint();
4881 return;
4882 }
4883
4884 // First weed out calls we do not care about, that is, readonly/readnone
4885 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4886 // parallel region or anything else we are looking for.
4887 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4888 indicateOptimisticFixpoint();
4889 return;
4890 }
4891
4892 // Next we check if we know the callee. If it is a known OpenMP function
4893 // we will handle it explicitly in the switch below. If it is not, we
4894 // will use an AAKernelInfo object on the callee to gather information and
4895 // merge that into the current state. The latter happens in the updateImpl.
4896 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4897 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4898 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4899 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4900 // Unknown callees or declarations are not analyzable, we give up.
4901 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4902
4903 // Unknown callees might contain parallel regions, except if they have
4904 // an appropriate assumption attached.
4905 if (!AssumptionAA ||
4906 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4907 AssumptionAA->hasAssumption("omp_no_parallelism")))
4908 ReachedUnknownParallelRegions.insert(&CB);
4909
4910 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4911 // idea we can run something unknown in SPMD-mode.
4912 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4913 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4914 SPMDCompatibilityTracker.insert(&CB);
4915 }
4916
4917 // We have updated the state for this unknown call properly; there
4918 // won't be any further change, so we indicate a fixpoint.
4919 indicateOptimisticFixpoint();
4920 }
4921 // If the callee is known and can be used in IPO, we will update the
4922 // state based on the callee state in updateImpl.
4923 return;
4924 }
4925 if (NumCallees > 1) {
4926 indicatePessimisticFixpoint();
4927 return;
4928 }
4929
4930 RuntimeFunction RF = It->getSecond();
4931 switch (RF) {
4932 // All the functions we know are compatible with SPMD mode.
4933 case OMPRTL___kmpc_is_spmd_exec_mode:
4934 case OMPRTL___kmpc_distribute_static_fini:
4935 case OMPRTL___kmpc_for_static_fini:
4936 case OMPRTL___kmpc_global_thread_num:
4937 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4938 case OMPRTL___kmpc_get_hardware_num_blocks:
4939 case OMPRTL___kmpc_single:
4940 case OMPRTL___kmpc_end_single:
4941 case OMPRTL___kmpc_master:
4942 case OMPRTL___kmpc_end_master:
4943 case OMPRTL___kmpc_barrier:
4944 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4945 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4946 case OMPRTL___kmpc_error:
4947 case OMPRTL___kmpc_flush:
4948 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4949 case OMPRTL___kmpc_get_warp_size:
4950 case OMPRTL_omp_get_thread_num:
4951 case OMPRTL_omp_get_num_threads:
4952 case OMPRTL_omp_get_max_threads:
4953 case OMPRTL_omp_in_parallel:
4954 case OMPRTL_omp_get_dynamic:
4955 case OMPRTL_omp_get_cancellation:
4956 case OMPRTL_omp_get_nested:
4957 case OMPRTL_omp_get_schedule:
4958 case OMPRTL_omp_get_thread_limit:
4959 case OMPRTL_omp_get_supported_active_levels:
4960 case OMPRTL_omp_get_max_active_levels:
4961 case OMPRTL_omp_get_level:
4962 case OMPRTL_omp_get_ancestor_thread_num:
4963 case OMPRTL_omp_get_team_size:
4964 case OMPRTL_omp_get_active_level:
4965 case OMPRTL_omp_in_final:
4966 case OMPRTL_omp_get_proc_bind:
4967 case OMPRTL_omp_get_num_places:
4968 case OMPRTL_omp_get_num_procs:
4969 case OMPRTL_omp_get_place_proc_ids:
4970 case OMPRTL_omp_get_place_num:
4971 case OMPRTL_omp_get_partition_num_places:
4972 case OMPRTL_omp_get_partition_place_nums:
4973 case OMPRTL_omp_get_wtime:
4974 break;
4975 case OMPRTL___kmpc_distribute_static_init_4:
4976 case OMPRTL___kmpc_distribute_static_init_4u:
4977 case OMPRTL___kmpc_distribute_static_init_8:
4978 case OMPRTL___kmpc_distribute_static_init_8u:
4979 case OMPRTL___kmpc_for_static_init_4:
4980 case OMPRTL___kmpc_for_static_init_4u:
4981 case OMPRTL___kmpc_for_static_init_8:
4982 case OMPRTL___kmpc_for_static_init_8u: {
4983 // Check the schedule and allow static schedule in SPMD mode.
4984 unsigned ScheduleArgOpNo = 2;
4985 auto *ScheduleTypeCI =
4986 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4987 unsigned ScheduleTypeVal =
4988 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4989 switch (OMPScheduleType(ScheduleTypeVal)) {
4990 case OMPScheduleType::UnorderedStatic:
4991 case OMPScheduleType::UnorderedStaticChunked:
4992 case OMPScheduleType::OrderedDistribute:
4993 case OMPScheduleType::OrderedDistributeChunked:
4994 break;
4995 default:
4996 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4997 SPMDCompatibilityTracker.insert(&CB);
4998 break;
4999 };
5000 } break;
5001 case OMPRTL___kmpc_target_init:
5002 KernelInitCB = &CB;
5003 break;
5004 case OMPRTL___kmpc_target_deinit:
5005 KernelDeinitCB = &CB;
5006 break;
5007 case OMPRTL___kmpc_parallel_51:
5008 if (!handleParallel51(A, CB))
5009 indicatePessimisticFixpoint();
5010 return;
5011 case OMPRTL___kmpc_omp_task:
5012 // We do not look into tasks right now, just give up.
5013 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5014 SPMDCompatibilityTracker.insert(&CB);
5015 ReachedUnknownParallelRegions.insert(&CB);
5016 break;
5017 case OMPRTL___kmpc_alloc_shared:
5018 case OMPRTL___kmpc_free_shared:
5019 // Return without setting a fixpoint, to be resolved in updateImpl.
5020 return;
5021 default:
5022 // Unknown OpenMP runtime calls generally cannot be executed in
5023 // SPMD-mode. However, they do not hide parallel regions.
5024 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5025 SPMDCompatibilityTracker.insert(&CB);
5026 break;
5027 }
5028 // All other OpenMP runtime calls will not reach parallel regions so they
5029 // can be safely ignored for now. Since it is a known OpenMP runtime call
5030 // we have now modeled all effects and there is no need for any update.
5031 indicateOptimisticFixpoint();
5032 };
5033
5034 const auto *AACE =
5035 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5036 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5037 CheckCallee(getAssociatedFunction(), 1);
5038 return;
5039 }
5040 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5041 for (auto *Callee : OptimisticEdges) {
5042 CheckCallee(Callee, OptimisticEdges.size());
5043 if (isAtFixpoint())
5044 break;
5045 }
5046 }
5047
5048 ChangeStatus updateImpl(Attributor &A) override {
5049 // TODO: Once we have call site specific value information we can provide
5050 // call site specific liveness information and then it makes
5051 // sense to specialize attributes for call sites arguments instead of
5052 // redirecting requests to the callee argument.
5053 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5054 KernelInfoState StateBefore = getState();
5055
5056 auto CheckCallee = [&](Function *F, int NumCallees) {
5057 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5058
5059 // If F is not a runtime function, propagate the AAKernelInfo of the
5060 // callee.
5061 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5062 const IRPosition &FnPos = IRPosition::function(*F);
5063 auto *FnAA =
5064 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5065 if (!FnAA)
5066 return indicatePessimisticFixpoint();
5067 if (getState() == FnAA->getState())
5068 return ChangeStatus::UNCHANGED;
5069 getState() = FnAA->getState();
5070 return ChangeStatus::CHANGED;
5071 }
5072 if (NumCallees > 1)
5073 return indicatePessimisticFixpoint();
5074
5075 CallBase &CB = cast<CallBase>(getAssociatedValue());
5076 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5077 if (!handleParallel51(A, CB))
5078 return indicatePessimisticFixpoint();
5079 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5080 : ChangeStatus::CHANGED;
5081 }
5082
5083 // F is a runtime function that allocates or frees memory, check
5084 // AAHeapToStack and AAHeapToShared.
5085 assert(
5086 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5087 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5088 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5089
5090 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5091 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5092 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094
5095 RuntimeFunction RF = It->getSecond();
5096
5097 switch (RF) {
5098 // If neither HeapToStack nor HeapToShared assumes the call is removed,
5099 // assume SPMD incompatibility.
5100 case OMPRTL___kmpc_alloc_shared:
5101 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5102 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5103 SPMDCompatibilityTracker.insert(&CB);
5104 break;
5105 case OMPRTL___kmpc_free_shared:
5106 if ((!HeapToStackAA ||
5107 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5108 (!HeapToSharedAA ||
5109 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5110 SPMDCompatibilityTracker.insert(&CB);
5111 break;
5112 default:
5113 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5114 SPMDCompatibilityTracker.insert(&CB);
5115 }
5116 return ChangeStatus::CHANGED;
5117 };
5118
5119 const auto *AACE =
5120 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5121 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5122 if (Function *F = getAssociatedFunction())
5123 CheckCallee(F, /*NumCallees=*/1);
5124 } else {
5125 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5126 for (auto *Callee : OptimisticEdges) {
5127 CheckCallee(Callee, OptimisticEdges.size());
5128 if (isAtFixpoint())
5129 break;
5130 }
5131 }
5132
5133 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5134 : ChangeStatus::CHANGED;
5135 }
5136
5137 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call
5138 /// was handled; if a problem occurred, false is returned.
5139 bool handleParallel51(Attributor &A, CallBase &CB) {
5140 const unsigned int NonWrapperFunctionArgNo = 5;
5141 const unsigned int WrapperFunctionArgNo = 6;
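  // For orientation, a sketch of the runtime entry these indices refer to
  // (hedged; the authoritative prototype lives in the OpenMP device runtime):
  //   void __kmpc_parallel_51(ident_t *loc, i32 gtid, i32 if_expr,
  //                           i32 num_threads, i32 proc_bind, void *fn,
  //                           void *wrapper_fn, void **args, i64 nargs)
  // In SPMD mode the outlined parallel region (operand 5) is called directly;
  // in generic mode the state machine invokes the wrapper (operand 6).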
5142 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5143 ? NonWrapperFunctionArgNo
5144 : WrapperFunctionArgNo;
5145
5146 auto *ParallelRegion = dyn_cast<Function>(
5147 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5148 if (!ParallelRegion)
5149 return false;
5150
5151 ReachedKnownParallelRegions.insert(&CB);
5152 /// Check for nested parallelism.
5153 auto *FnAA = A.getAAFor<AAKernelInfo>(
5154 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5155 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5156 !FnAA->ReachedKnownParallelRegions.empty() ||
5157 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5158 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5159 !FnAA->ReachedUnknownParallelRegions.empty();
5160 return true;
5161 }
5162};
5163
5164struct AAFoldRuntimeCall
5165 : public StateWrapper<BooleanState, AbstractAttribute> {
5166 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5167
5168 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5169
5170 /// Statistics are tracked as part of manifest for now.
5171 void trackStatistics() const override {}
5172
5173 /// Create an abstract attribute view for the position \p IRP.
5174 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5175 Attributor &A);
5176
5177 /// See AbstractAttribute::getName()
5178 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5179
5180 /// See AbstractAttribute::getIdAddr()
5181 const char *getIdAddr() const override { return &ID; }
5182
5183 /// This function should return true if the type of the \p AA is
5184 /// AAFoldRuntimeCall
5185 static bool classof(const AbstractAttribute *AA) {
5186 return (AA->getIdAddr() == &ID);
5187 }
5188
5189 static const char ID;
5190};
5191
5192struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5193 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5194 : AAFoldRuntimeCall(IRP, A) {}
5195
5196 /// See AbstractAttribute::getAsStr()
5197 const std::string getAsStr(Attributor *) const override {
5198 if (!isValidState())
5199 return "<invalid>";
5200
5201 std::string Str("simplified value: ");
5202
5203 if (!SimplifiedValue)
5204 return Str + std::string("none");
5205
5206 if (!*SimplifiedValue)
5207 return Str + std::string("nullptr");
5208
5209 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5210 return Str + std::to_string(CI->getSExtValue());
5211
5212 return Str + std::string("unknown");
5213 }
5214
5215 void initialize(Attributor &A) override {
5216 if (DisableOpenMPOptFolding)
5217 indicatePessimisticFixpoint();
5218
5219 Function *Callee = getAssociatedFunction();
5220
5221 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5222 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5223 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5224 "Expected a known OpenMP runtime function");
5225
5226 RFKind = It->getSecond();
5227
5228 CallBase &CB = cast<CallBase>(getAssociatedValue());
5229 A.registerSimplificationCallback(
5230 IRPosition::callsite_returned(CB),
5231 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5232 bool &UsedAssumedInformation) -> std::optional<Value *> {
5233 assert((isValidState() ||
5234 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5235 "Unexpected invalid state!");
5236
5237 if (!isAtFixpoint()) {
5238 UsedAssumedInformation = true;
5239 if (AA)
5240 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241 }
5242 return SimplifiedValue;
5243 });
5244 }
5245
5246 ChangeStatus updateImpl(Attributor &A) override {
5247 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5248 switch (RFKind) {
5249 case OMPRTL___kmpc_is_spmd_exec_mode:
5250 Changed |= foldIsSPMDExecMode(A);
5251 break;
5252 case OMPRTL___kmpc_parallel_level:
5253 Changed |= foldParallelLevel(A);
5254 break;
5255 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5256 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5257 break;
5258 case OMPRTL___kmpc_get_hardware_num_blocks:
5259 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5260 break;
5261 default:
5262 llvm_unreachable("Unhandled OpenMP runtime function!");
5263 }
5264
5265 return Changed;
5266 }
5267
5268 ChangeStatus manifest(Attributor &A) override {
5269 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5270
5271 if (SimplifiedValue && *SimplifiedValue) {
5272 Instruction &I = *getCtxI();
5273 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5274 A.deleteAfterManifest(I);
5275
5276 CallBase *CB = dyn_cast<CallBase>(&I);
5277 auto Remark = [&](OptimizationRemark OR) {
5278 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5279 return OR << "Replacing OpenMP runtime call "
5280 << CB->getCalledFunction()->getName() << " with "
5281 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5282 return OR << "Replacing OpenMP runtime call "
5283 << CB->getCalledFunction()->getName() << ".";
5284 };
5285
5286 if (CB && EnableVerboseRemarks)
5287 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5288
5289 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5290 << **SimplifiedValue << "\n");
5291
5292 Changed = ChangeStatus::CHANGED;
5293 }
5294
5295 return Changed;
5296 }
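  // Example of the (verbose) remark emitted above, assuming the folded call
  // is __kmpc_is_spmd_exec_mode and the folded value is 1 (a hedged
  // illustration of the Remark lambda's output, not a captured log):
  //   remark: Replacing OpenMP runtime call __kmpc_is_spmd_exec_mode with 1.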
5297
5298 ChangeStatus indicatePessimisticFixpoint() override {
5299 SimplifiedValue = nullptr;
5300 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301 }
5302
5303private:
5304 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5305 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5306 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307
5308 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5309 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5310 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5311 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5312
5313 if (!CallerKernelInfoAA ||
5314 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5315 return indicatePessimisticFixpoint();
5316
5317 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5319 DepClassTy::REQUIRED);
5320
5321 if (!AA || !AA->isValidState()) {
5322 SimplifiedValue = nullptr;
5323 return indicatePessimisticFixpoint();
5324 }
5325
5326 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5327 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328 ++KnownSPMDCount;
5329 else
5330 ++AssumedSPMDCount;
5331 } else {
5332 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5333 ++KnownNonSPMDCount;
5334 else
5335 ++AssumedNonSPMDCount;
5336 }
5337 }
5338
5339 if ((AssumedSPMDCount + KnownSPMDCount) &&
5340 (AssumedNonSPMDCount + KnownNonSPMDCount))
5341 return indicatePessimisticFixpoint();
5342
5343 auto &Ctx = getAnchorValue().getContext();
5344 if (KnownSPMDCount || AssumedSPMDCount) {
5345 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5346 "Expected only SPMD kernels!");
5347 // All reaching kernels are in SPMD mode. Update all function calls to
5348 // __kmpc_is_spmd_exec_mode to 1.
5349 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5350 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5351 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5352 "Expected only non-SPMD kernels!");
5353 // All reaching kernels are in non-SPMD mode. Update all function
5354 // calls to __kmpc_is_spmd_exec_mode to 0.
5355 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5356 } else {
5357 // The set of reaching kernels is empty, so we cannot tell whether the
5358 // associated call site can be folded. At this point, SimplifiedValue
5359 // must be none.
5360 assert(!SimplifiedValue && "SimplifiedValue should be none");
5361 }
5362
5363 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5364 : ChangeStatus::CHANGED;
5365 }
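  // Illustrative effect (a hedged example, not taken from a regression test):
  // when every kernel reaching the caller is assumed or known SPMD, a call
  //   %mode = call i8 @__kmpc_is_spmd_exec_mode()
  // folds to the constant i8 1; it folds to i8 0 when only generic-mode
  // kernels reach it, and mixed reachability leaves the call in place.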
5366
5367 /// Fold __kmpc_parallel_level into a constant if possible.
5368 ChangeStatus foldParallelLevel(Attributor &A) {
5369 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370
5371 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5372 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5373
5374 if (!CallerKernelInfoAA ||
5375 !CallerKernelInfoAA->ParallelLevels.isValidState())
5376 return indicatePessimisticFixpoint();
5377
5378 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5379 return indicatePessimisticFixpoint();
5380
5381 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5382 assert(!SimplifiedValue &&
5383 "SimplifiedValue should keep none at this point");
5384 return ChangeStatus::UNCHANGED;
5385 }
5386
5387 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5388 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5389 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5390 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5391 DepClassTy::REQUIRED);
5392 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5393 return indicatePessimisticFixpoint();
5394
5395 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5396 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5397 ++KnownSPMDCount;
5398 else
5399 ++AssumedSPMDCount;
5400 } else {
5401 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5402 ++KnownNonSPMDCount;
5403 else
5404 ++AssumedNonSPMDCount;
5405 }
5406 }
5407
5408 if ((AssumedSPMDCount + KnownSPMDCount) &&
5409 (AssumedNonSPMDCount + KnownNonSPMDCount))
5410 return indicatePessimisticFixpoint();
5411
5412 auto &Ctx = getAnchorValue().getContext();
5413 // If the caller can only be reached by SPMD kernel entries, the parallel
5414 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5415 // kernel entries, it is 0.
5416 if (AssumedSPMDCount || KnownSPMDCount) {
5417 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5418 "Expected only SPMD kernels!");
5419 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5420 } else {
5421 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5422 "Expected only non-SPMD kernels!");
5423 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5424 }
5425 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5426 : ChangeStatus::CHANGED;
5427 }
5428
5429 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5430 // Specialize only if all the calls agree with the attribute constant value
5431 int32_t CurrentAttrValue = -1;
5432 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5433
5434 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5435 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5436
5437 if (!CallerKernelInfoAA ||
5438 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5439 return indicatePessimisticFixpoint();
5440
5441 // Iterate over the kernels that reach this function
5442 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5443 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5444
5445 if (NextAttrVal == -1 ||
5446 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5447 return indicatePessimisticFixpoint();
5448 CurrentAttrValue = NextAttrVal;
5449 }
5450
5451 if (CurrentAttrValue != -1) {
5452 auto &Ctx = getAnchorValue().getContext();
5453 SimplifiedValue =
5454 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5455 }
5456 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5457 : ChangeStatus::CHANGED;
5458 }
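  // Worked example (hedged; the attribute names are the ones updateImpl maps
  // to this fold): if every reaching kernel carries
  // "omp_target_thread_limit"="128", a __kmpc_get_hardware_num_threads_in_block
  // call folds to i32 128. A kernel missing the attribute, or a disagreeing
  // value, blocks the fold via indicatePessimisticFixpoint.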
5459
5460 /// An optional value the associated value is assumed to fold to. That is, we
5461 /// assume the associated value (which is a call) can be replaced by this
5462 /// simplified value.
5463 std::optional<Value *> SimplifiedValue;
5464
5465 /// The runtime function kind of the callee of the associated call site.
5466 RuntimeFunction RFKind;
5467};
5468
5469} // namespace
5470
5471 /// Register a folding AA for every call site of the runtime function \p RF.
5472void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5473 auto &RFI = OMPInfoCache.RFIs[RF];
5474 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5475 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5476 if (!CI)
5477 return false;
5478 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5479 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5480 DepClassTy::NONE, /* ForceUpdate */ false,
5481 /* UpdateAfterInit */ false);
5482 return false;
5483 });
5484}
5485
5486void OpenMPOpt::registerAAs(bool IsModulePass) {
5487 if (SCC.empty())
5488 return;
5489
5490 if (IsModulePass) {
5491 // Ensure we create the AAKernelInfo AAs first and without triggering an
5492 // update. This will make sure we register all value simplification
5493 // callbacks before any other AA has the chance to create an AAValueSimplify
5494 // or similar.
5495 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5496 A.getOrCreateAAFor<AAKernelInfo>(
5497 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5498 DepClassTy::NONE, /* ForceUpdate */ false,
5499 /* UpdateAfterInit */ false);
5500 return false;
5501 };
5502 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5503 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5504 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5505
5506 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5507 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5508 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5510 }
5511
5512 // Create CallSite AA for all Getters.
5513 if (DeduceICVValues) {
5514 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5515 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5516
5517 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5518
5519 auto CreateAA = [&](Use &U, Function &Caller) {
5520 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5521 if (!CI)
5522 return false;
5523
5524 auto &CB = cast<CallBase>(*CI);
5525
5526 IRPosition CBPos = IRPosition::callsite_function(CB);
5527 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5528 return false;
5529 };
5530
5531 GetterRFI.foreachUse(SCC, CreateAA);
5532 }
5533 }
5534
5535 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5536 // every function if there is a device kernel.
5537 if (!isOpenMPDevice(M))
5538 return;
5539
5540 for (auto *F : SCC) {
5541 if (F->isDeclaration())
5542 continue;
5543
5544 // We look at internal functions only on-demand, but if any use is not a
5545 // direct call or is outside the current set of analyzed functions, we have
5546 // to do it eagerly.
5547 if (F->hasLocalLinkage()) {
5548 if (llvm::all_of(F->uses(), [this](const Use &U) {
5549 const auto *CB = dyn_cast<CallBase>(U.getUser());
5550 return CB && CB->isCallee(&U) &&
5551 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5552 }))
5553 continue;
5554 }
5555 registerAAsForFunction(A, *F);
5556 }
5557}
5558
5559void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5560 if (!DisableOpenMPOptDeglobalization)
5561 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5562 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5563 if (!DisableOpenMPOptDeglobalization)
5564 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5565 if (F.hasFnAttribute(Attribute::Convergent))
5566 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5567
5568 for (auto &I : instructions(F)) {
5569 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5570 bool UsedAssumedInformation = false;
5571 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5572 UsedAssumedInformation, AA::Interprocedural);
5573 continue;
5574 }
5575 if (auto *CI = dyn_cast<CallBase>(&I)) {
5576 if (CI->isIndirectCall())
5577 A.getOrCreateAAFor<AAIndirectCallInfo>(
5578 IRPosition::callsite_function(*CI));
5579 }
5580 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5581 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5582 continue;
5583 }
5584 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5585 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5586 continue;
5587 }
5588 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5589 if (II->getIntrinsicID() == Intrinsic::assume) {
5590 A.getOrCreateAAFor<AAPotentialValues>(
5591 IRPosition::value(*II->getArgOperand(0)));
5592 continue;
5593 }
5594 }
5595 }
5596}
5597
5598const char AAICVTracker::ID = 0;
5599const char AAKernelInfo::ID = 0;
5600const char AAExecutionDomain::ID = 0;
5601const char AAHeapToShared::ID = 0;
5602const char AAFoldRuntimeCall::ID = 0;
5603
5604AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5605 Attributor &A) {
5606 AAICVTracker *AA = nullptr;
5607 switch (IRP.getPositionKind()) {
5608 case IRPosition::IRP_INVALID:
5609 case IRPosition::IRP_FLOAT:
5610 case IRPosition::IRP_ARGUMENT:
5611 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5612 llvm_unreachable("ICVTracker can only be created for function position!");
5613 case IRPosition::IRP_RETURNED:
5614 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5615 break;
5616 case IRPosition::IRP_CALL_SITE_RETURNED:
5617 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5618 break;
5619 case IRPosition::IRP_CALL_SITE:
5620 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5621 break;
5622 case IRPosition::IRP_FUNCTION:
5623 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5624 break;
5625 }
5626
5627 return *AA;
5628}
5629
5630 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5631 Attributor &A) {
5632 AAExecutionDomainFunction *AA = nullptr;
5633 switch (IRP.getPositionKind()) {
5634 case IRPosition::IRP_INVALID:
5635 case IRPosition::IRP_FLOAT:
5636 case IRPosition::IRP_ARGUMENT:
5637 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5638 case IRPosition::IRP_RETURNED:
5639 case IRPosition::IRP_CALL_SITE_RETURNED:
5640 case IRPosition::IRP_CALL_SITE:
5641 llvm_unreachable(
5642 "AAExecutionDomain can only be created for function position!");
5643 case IRPosition::IRP_FUNCTION:
5644 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5645 break;
5646 }
5647
5648 return *AA;
5649}
5650
5651AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5652 Attributor &A) {
5653 AAHeapToSharedFunction *AA = nullptr;
5654 switch (IRP.getPositionKind()) {
5655 case IRPosition::IRP_INVALID:
5656 case IRPosition::IRP_FLOAT:
5657 case IRPosition::IRP_ARGUMENT:
5658 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5659 case IRPosition::IRP_RETURNED:
5660 case IRPosition::IRP_CALL_SITE_RETURNED:
5661 case IRPosition::IRP_CALL_SITE:
5662 llvm_unreachable(
5663 "AAHeapToShared can only be created for function position!");
5664 case IRPosition::IRP_FUNCTION:
5665 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5666 break;
5667 }
5668
5669 return *AA;
5670}
5671
5672AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5673 Attributor &A) {
5674 AAKernelInfo *AA = nullptr;
5675 switch (IRP.getPositionKind()) {
5676 case IRPosition::IRP_INVALID:
5677 case IRPosition::IRP_FLOAT:
5678 case IRPosition::IRP_ARGUMENT:
5679 case IRPosition::IRP_RETURNED:
5680 case IRPosition::IRP_CALL_SITE_RETURNED:
5681 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5682 llvm_unreachable("KernelInfo can only be created for function position!");
5683 case IRPosition::IRP_CALL_SITE:
5684 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5685 break;
5686 case IRPosition::IRP_FUNCTION:
5687 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5688 break;
5689 }
5690
5691 return *AA;
5692}
5693
5694AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5695 Attributor &A) {
5696 AAFoldRuntimeCall *AA = nullptr;
5697 switch (IRP.getPositionKind()) {
5698 case IRPosition::IRP_INVALID:
5699 case IRPosition::IRP_FLOAT:
5700 case IRPosition::IRP_ARGUMENT:
5701 case IRPosition::IRP_RETURNED:
5702 case IRPosition::IRP_FUNCTION:
5703 case IRPosition::IRP_CALL_SITE:
5704 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5705 llvm_unreachable("KernelInfo can only be created for call site position!");
5706 case IRPosition::IRP_CALL_SITE_RETURNED:
5707 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5708 break;
5709 }
5710
5711 return *AA;
5712}
5713
5714 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5715 if (!containsOpenMP(M))
5716 return PreservedAnalyses::all();
5717 if (DisableOpenMPOptimizations)
5718 return PreservedAnalyses::all();
5719
5720 FunctionAnalysisManager &FAM =
5721 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5722 KernelSet Kernels = getDeviceKernels(M);
5723
5724 if (PrintModuleBeforeOptimizations)
5725 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5726
5727 auto IsCalled = [&](Function &F) {
5728 if (Kernels.contains(&F))
5729 return true;
5730 for (const User *U : F.users())
5731 if (!isa<BlockAddress>(U))
5732 return true;
5733 return false;
5734 };
5735
5736 auto EmitRemark = [&](Function &F) {
5737 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5738 ORE.emit([&]() {
5739 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5740 return ORA << "Could not internalize function. "
5741 << "Some optimizations may not be possible. [OMP140]";
5742 });
5743 };
5744
5745 bool Changed = false;
5746
5747 // Create internal copies of each function if this is a kernel Module. This
5748 // allows interprocedural passes to see every call edge.
5749 DenseMap<Function *, Function *> InternalizedMap;
5750 if (isOpenMPDevice(M)) {
5751 SmallPtrSet<Function *, 16> InternalizeFns;
5752 for (Function &F : M)
5753 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5754 !DisableInternalization) {
5755 if (Attributor::isInternalizable(F)) {
5756 InternalizeFns.insert(&F);
5757 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5758 EmitRemark(F);
5759 }
5760 }
5761
5762 Changed |=
5763 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5764 }
5765
5766 // Look at every function in the Module unless it was internalized.
5767 SetVector<Function *> Functions;
5768 SmallVector<Function *, 16> SCC;
5769 for (Function &F : M)
5770 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5771 SCC.push_back(&F);
5772 Functions.insert(&F);
5773 }
5774
5775 if (SCC.empty())
5776 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5777
5778 AnalysisGetter AG(FAM);
5779
5780 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5781 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5782 };
5783
5784 BumpPtrAllocator Allocator;
5785 CallGraphUpdater CGUpdater;
5786
5787 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5788 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5789 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5790
5791 unsigned MaxFixpointIterations =
5792 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5793
5794 AttributorConfig AC(CGUpdater);
5795 AC.DefaultInitializeLiveInternals = false;
5796 AC.IsModulePass = true;
5797 AC.RewriteSignatures = false;
5798 AC.MaxFixpointIterations = MaxFixpointIterations;
5799 AC.OREGetter = OREGetter;
5800 AC.PassName = DEBUG_TYPE;
5801 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5802 AC.IPOAmendableCB = [](const Function &F) {
5803 return F.hasFnAttribute("kernel");
5804 };
5805
5806 Attributor A(Functions, InfoCache, AC);
5807
5808 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5809 Changed |= OMPOpt.run(true);
5810
5811 // Optionally inline device functions for potentially better performance.
5812 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5813 for (Function &F : M)
5814 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5815 !F.hasFnAttribute(Attribute::NoInline))
5816 F.addFnAttr(Attribute::AlwaysInline);
5817
5818 if (PrintModuleAfterOptimizations)
5819 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5820
5821 if (Changed)
5822 return PreservedAnalyses::none();
5823
5824 return PreservedAnalyses::all();
5825}
5826
5827 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5828 CGSCCAnalysisManager &AM,
5829 LazyCallGraph &CG,
5830 CGSCCUpdateResult &UR) {
5831 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5832 return PreservedAnalyses::all();
5833 if (DisableOpenMPOptimizations)
5834 return PreservedAnalyses::all();
5835
5836 SmallVector<Function *, 16> SCC;
5837 // If there are kernels in the module, we have to run on all SCC's.
5838 for (LazyCallGraph::Node &N : C) {
5839 Function *Fn = &N.getFunction();
5840 SCC.push_back(Fn);
5841 }
5842
5843 if (SCC.empty())
5844 return PreservedAnalyses::all();
5845
5846 Module &M = *C.begin()->getFunction().getParent();
5847
5848 if (PrintModuleBeforeOptimizations)
5849 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5850
5851 KernelSet Kernels = getDeviceKernels(M);
5852
5853 FunctionAnalysisManager &FAM =
5854 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5855
5856 AnalysisGetter AG(FAM);
5857
5858 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5859 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5860 };
5861
5862 BumpPtrAllocator Allocator;
5863 CallGraphUpdater CGUpdater;
5864 CGUpdater.initialize(CG, C, AM, UR);
5865
5866 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5867 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5868 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5869 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5870 /*CGSCC*/ &Functions, PostLink);
5871
5872 unsigned MaxFixpointIterations =
5873 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5874
5875 AttributorConfig AC(CGUpdater);
5876 AC.DefaultInitializeLiveInternals = false;
5877 AC.IsModulePass = false;
5878 AC.RewriteSignatures = false;
5879 AC.MaxFixpointIterations = MaxFixpointIterations;
5880 AC.OREGetter = OREGetter;
5881 AC.PassName = DEBUG_TYPE;
5882 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5883
5884 Attributor A(Functions, InfoCache, AC);
5885
5886 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5887 bool Changed = OMPOpt.run(false);
5888
5889 if (PrintModuleAfterOptimizations)
5890 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5891
5892 if (Changed)
5893 return PreservedAnalyses::none();
5894
5895 return PreservedAnalyses::all();
5896}
5897
5898 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5899 return Fn.hasFnAttribute("kernel");
5900}
5901
5902 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5903 // TODO: Create a more cross-platform way of determining device kernels.
5904 NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5905 KernelSet Kernels;
5906
5907 if (!MD)
5908 return Kernels;
5909
5910 for (auto *Op : MD->operands()) {
5911 if (Op->getNumOperands() < 2)
5912 continue;
5913 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5914 if (!KindID || KindID->getString() != "kernel")
5915 continue;
5916
5917 Function *KernelFn =
5918 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5919 if (!KernelFn)
5920 continue;
5921
5922 // We are only interested in OpenMP target regions. Others, such as kernels
5923 // generated by CUDA but linked together, are not interesting to this pass.
5924 if (isOpenMPKernel(*KernelFn)) {
5925 ++NumOpenMPTargetRegionKernels;
5926 Kernels.insert(KernelFn);
5927 } else
5928 ++NumNonOpenMPTargetRegionKernels;
5929 }
5930
5931 return Kernels;
5932}
5933
5934 bool llvm::omp::containsOpenMP(Module &M) {
5935 Metadata *MD = M.getModuleFlag("openmp");
5936 if (!MD)
5937 return false;
5938
5939 return true;
5940}
5941
5942 bool llvm::omp::isOpenMPDevice(Module &M) {
5943 Metadata *MD = M.getModuleFlag("openmp-device");
5944 if (!MD)
5945 return false;
5946
5947 return true;
5948}
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:109
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:185
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicible functions on the device."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:241
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:218
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3675
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:210
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:231
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
#define P(N)
if(VerifyEach)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:61
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:253
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:405
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:61
iterator end()
Definition: BasicBlock.h:451
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:438
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:414
reverse_iterator rbegin()
Definition: BasicBlock.h:454
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:202
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:575
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:495
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:169
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:209
reverse_iterator rend()
Definition: BasicBlock.h:456
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:229
Conditional or Unconditional Branch instruction.
bool isConditional() const
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1236
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1527
bool arg_empty() const
Definition: InstrTypes.h:1407
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1465
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:1810
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1476
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1410
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1415
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1401
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1441
unsigned arg_size() const
Definition: InstrTypes.h:1408
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1542
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1594
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1430
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:2061
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:786
@ ICMP_EQ
equal
Definition: InstrTypes.h:778
@ ICMP_NE
not equal
Definition: InstrTypes.h:779
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2215
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2230
This is the shared class of boolean and integer constants.
Definition: Constants.h:81
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:185
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:850
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:206
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:161
This is an important base class in LLVM.
Definition: Constant.h:42
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:370
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Implements a dense probed hash-table based set.
Definition: DenseSet.h:271
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
An instruction for ordering other memory operations.
Definition: Instructions.h:420
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:443
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:168
const BasicBlock & getEntryBlock() const
Definition: Function.h:800
const BasicBlock & front() const
Definition: Function.h:823
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:358
Argument * getArg(unsigned i) const
Definition: Function.h:849
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:719
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:290
bool hasLocalLinkage() const
Definition: GlobalValue.h:528
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:656
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:294
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:485
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Globals.cpp:481
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:254
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1778
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:177
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2137
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2671
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:563
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:466
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:66
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
Definition: Instruction.cpp:92
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:70
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:463
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:174
A single uniqued string.
Definition: Metadata.h:720
StringRef getString() const
Definition: Metadata.cpp:610
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:297
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:307
A tuple of MDNodes.
Definition: Metadata.h:1729
iterator_range< op_iterator > operands()
Definition: Metadata.h:1825
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:466
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:492
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5827
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5714
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1852
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:111
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:114
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:117
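Both run() entry points above report their effect through PreservedAnalyses. The common pattern, sketched with a hypothetical pass class and helper:

  llvm::PreservedAnalyses MyPass::run(llvm::Module &M,
                                      llvm::ModuleAnalysisManager &AM) {
    bool Changed = optimizeModule(M); // hypothetical transformation
    // Invalidate everything on change; otherwise keep all cached results.
    return Changed ? llvm::PreservedAnalyses::none()
                   : llvm::PreservedAnalyses::all();
  }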
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
iterator end()
Get an iterator to the end of the SetVector.
Definition: SetVector.h:113
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
iterator begin()
Get an iterator to the beginning of the SetVector.
Definition: SetVector.h:103
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
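SetVector combines set semantics with deterministic iteration order, which makes it the idiomatic worklist type. A short sketch, with F a placeholder llvm::Function *:

  #include "llvm/ADT/SetVector.h"

  llvm::SetVector<llvm::Function *> Worklist;
  Worklist.insert(F); // true: newly inserted
  Worklist.insert(F); // false: already present, order unchanged
  while (!Worklist.empty()) {
    llvm::Function *Cur = Worklist.back();
    Worklist.pop_back();
    // ... process Cur, possibly inserting callees ...
  }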
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:323
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:412
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:344
iterator begin() const
Definition: SmallPtrSet.h:432
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:479
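SmallPtrSet is the usual visited-set. A sketch, assuming BB is a llvm::BasicBlock *:

  #include "llvm/ADT/SmallPtrSet.h"

  llvm::SmallPtrSet<llvm::BasicBlock *, 8> Visited; // 8 inline slots
  if (Visited.insert(BB).second) {
    // First time BB is seen; insert().second is false on duplicates.
  }
  bool Seen = Visited.count(BB); // 1 if present, 0 otherwise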
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:717
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
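SmallVector stores its first N elements inline, so the common small case never touches the heap. A self-contained sketch of the operations listed above:

  #include "llvm/ADT/SmallVector.h"
  #include <iterator>

  llvm::SmallVector<int, 4> V; // first 4 elements stored inline
  V.push_back(1);
  V.emplace_back(2);           // construct in place
  int More[] = {3, 4, 5};
  V.append(std::begin(More), std::end(More)); // V = {1, 2, 3, 4, 5}
  V.assign(2, 0);                             // V = {0, 0}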
An instruction for storing to memory.
Definition: Instructions.h:290
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:250
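StringRef is non-owning, so prefix tests like the one below are cheap. A sketch, assuming F is a llvm::Function &; "__kmpc_" is the OpenMP runtime's symbol prefix:

  llvm::StringRef Name = F.getName();
  if (Name.starts_with("__kmpc_")) {
    // Name refers to an OpenMP runtime entry point.
  }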
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1833
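The Type factories and the UndefValue/PoisonValue factories are often used together to materialize placeholder values. A minimal sketch:

  #include "llvm/IR/Constants.h"
  #include "llvm/IR/LLVMContext.h"
  #include "llvm/IR/Type.h"

  llvm::LLVMContext Ctx;
  llvm::Type *I32 = llvm::Type::getInt32Ty(Ctx);
  llvm::Value *U = llvm::UndefValue::get(I32);  // legacy placeholder
  llvm::Value *P = llvm::PoisonValue::get(I32); // preferred in new code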
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:694
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:255
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
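replaceAllUsesWith rewrites every user of a value, while User::replaceUsesOfWith patches the operands of a single user. A sketch, with OldV, NewV, and Ptr as placeholder llvm::Value * of matching types:

  if (OldV->hasOneUse())
    OldV->user_back()->replaceUsesOfWith(OldV, NewV); // patch one user
  else
    OldV->replaceAllUsesWith(NewV);                   // patch all users
  const llvm::Value *Base = Ptr->stripPointerCasts(); // peel casts, zero GEPs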
An efficient, type-erasing, non-owning reference to a callable.
const ParentTy * getParent() const
Definition: ilist_node.h:32
self_iterator getIterator()
Definition: ilist_node.h:132
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:353
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:661
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:260
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:267
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
Definition: Attributor.cpp:291
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:890
@ Interprocedural
Definition: Attributor.h:182
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:206
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:181
@ Entry
Definition: COFF.h:811
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:443
std::optional< const char * > toString(const std::optional< DWARFFormValue > &V)
Take an optional DWARFFormValue and try to extract a string value from it.
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5942
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5934
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5902
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5898
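These helpers form the usual gate at the top of the OpenMP passes. A sketch of their combined use inside a module pass's run(), per the semantics documented above:

  #include <cassert>

  if (!llvm::omp::containsOpenMP(M))
    return llvm::PreservedAnalyses::all(); // nothing OpenMP-related to do
  if (llvm::omp::isOpenMPDevice(M))
    for (llvm::Function *Kernel : llvm::omp::getDeviceKernels(M))
      assert(llvm::omp::isOpenMPKernel(*Kernel) && "expected 'kernel' attr");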
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:227
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:236
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:480
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1722
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1680
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2062
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
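The two pointer walkers serve different purposes: one yields a base plus a constant byte offset, the other the underlying allocation. A sketch, assuming Ptr is a llvm::Value * and DL a llvm::DataLayout in scope:

  #include "llvm/Analysis/ValueTracking.h"

  int64_t Offset = 0;
  llvm::Value *Base =
      llvm::GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
  // If no decomposition was possible, Base == Ptr and Offset == 0.
  const llvm::Value *Obj = llvm::getUnderlyingObject(Ptr);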
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6419
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1914
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
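SplitBlock is the standard way to carve a region out of an existing block. A sketch, assuming BB, SplitPt, and a DominatorTree DT exist in the surrounding code:

  #include "llvm/Transforms/Utils/BasicBlockUtils.h"

  // Everything from SplitPt onward moves into the returned block; BB is
  // terminated with an unconditional branch to it.
  llvm::BasicBlock *Tail =
      llvm::SplitBlock(BB, SplitPt->getIterator(), &DT,
                       /*LI=*/nullptr, /*MSSAU=*/nullptr, "split.tail");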
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition: Attributor.h:484
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract attribute for getting assumption information.
Definition: Attributor.h:6154
An abstract state for querying live call edges.
Definition: Attributor.h:5478
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5638
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5630
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5669
An abstract interface for indirect call information interference.
Definition: Attributor.h:6346
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3967
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4696
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4832
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4697
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5714
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6186
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3275
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3390
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3339
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2595
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1122
Configuration for the Attributor.
Definition: Attributor.h:1414
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1445
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1461
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1431
const char * PassName
Definition: Attributor.h:1471
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1467
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1474
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1425
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1435
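These fields are the knobs a pass sets before constructing the Attributor. A hedged configuration sketch; CGUpdater, OREGetter, Functions, and InfoCache are assumed to exist in the surrounding pass, and the iteration cap and pass name are illustrative:

  llvm::AttributorConfig AC(CGUpdater);      // CGUpdater: CallGraphUpdater &
  AC.IsModulePass = true;                    // module-scope deduction
  AC.RewriteSignatures = false;              // keep signatures stable
  AC.DefaultInitializeLiveInternals = false; // only seed requested AAs
  AC.MaxFixpointIterations = 32;             // example iteration cap
  AC.OREGetter = OREGetter;                  // remark emission hook
  AC.PassName = "my-pass";                   // illustrative name
  llvm::Attributor A(Functions, InfoCache, AC);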
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1508
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2024
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2054
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2008
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2879
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:581
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:649
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:631
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:605
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:617
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:595
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:591
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:594
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:592
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:593
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:589
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:596
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:588
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:624
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:877
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:644
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:753
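The IRPosition factories pin an abstract attribute to a concrete IR spot. A sketch, assuming F is a llvm::Function & and CB a llvm::CallBase &:

  const llvm::IRPosition FnPos = llvm::IRPosition::function(F);
  const llvm::IRPosition RetPos = llvm::IRPosition::returned(F);
  const llvm::IRPosition CSPos = llvm::IRPosition::callsite_function(CB);
  if (FnPos.getPositionKind() == llvm::IRPosition::IRP_FUNCTION)
    if (llvm::Function *Scope = FnPos.getAnchorScope())
      (void)Scope; // the function the position is anchored in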
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1198
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2654
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2666
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:220
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of an LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:608
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:3164
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3172
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57