//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpenMP specific optimizations:
//
// - Deduplication of runtime calls, e.g., omp_get_thread_num.
// - Replacing globalized device memory with stack memory.
// - Replacing globalized device memory with shared memory.
// - Parallel region merging.
// - Transforming generic-mode device kernels to SPMD mode.
// - Specializing the state machine for generic-mode device kernels.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/OpenMPOpt.h"

#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/CallGraphSCCPass.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Assumptions.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"

#include <algorithm>
#include <optional>
#include <string>

using namespace llvm;
using namespace omp;

#define DEBUG_TYPE "openmp-opt"

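// The cl::opt flags below toggle individual OpenMP optimizations. They are
// hidden options; as a usage sketch (not defined in this file), they are
// typically reached through the driver via -mllvm, e.g.
// `-mllvm -openmp-opt-disable`, or passed directly to `opt`.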
70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
          "Number of OpenMP runtime calls deduplicated");
STATISTIC(NumOpenMPParallelRegionsDeleted,
          "Number of OpenMP parallel regions deleted");
STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
          "Number of OpenMP runtime functions identified");
STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
          "Number of OpenMP runtime function uses identified");
STATISTIC(NumOpenMPTargetRegionKernels,
          "Number of OpenMP target region entry points (=kernels) identified");
STATISTIC(NumNonOpenMPTargetRegionKernels,
          "Number of non-OpenMP target region kernels identified");
STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "SPMD-mode instead of generic-mode");
STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines with fallback");
STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
          "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode with customized state machines without fallback");
STATISTIC(
    NumOpenMPParallelRegionsReplacedInGPUStateMachine,
    "Number of OpenMP parallel regions replaced with ID in GPU state machines");
STATISTIC(NumOpenMPParallelRegionsMerged,
          "Number of OpenMP parallel regions merged");
STATISTIC(NumBytesMovedToSharedMemory,
          "Amount of memory pushed to shared memory");
STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");

#if !defined(NDEBUG)
static constexpr auto TAG = "[" DEBUG_TYPE "]";
#endif

namespace KernelInfo {

// struct ConfigurationEnvironmentTy {
//   uint8_t UseGenericStateMachine;
//   uint8_t MayUseNestedParallelism;
//   llvm::omp::OMPTgtExecModeFlags ExecMode;
//   int32_t MinThreads;
//   int32_t MaxThreads;
//   int32_t MinTeams;
//   int32_t MaxTeams;
// };

// struct DynamicEnvironmentTy {
//   uint16_t DebugIndentionLevel;
// };

// struct KernelEnvironmentTy {
//   ConfigurationEnvironmentTy Configuration;
//   IdentTy *Ident;
//   DynamicEnvironmentTy *DynamicEnv;
// };
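// The commented structs above mirror the kernel environment layout used by the
// device runtime; the *Idx constants generated below record the position of
// each member inside the corresponding ConstantStruct initializer.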

#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)                                    \
  constexpr const unsigned MEMBER##Idx = IDX;

KERNEL_ENVIRONMENT_IDX(Configuration, 0)
KERNEL_ENVIRONMENT_IDX(Ident, 1)

#undef KERNEL_ENVIRONMENT_IDX

#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)                      \
  constexpr const unsigned MEMBER##Idx = IDX;

KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX

#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)                          \
  RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
    return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx));     \
  }

KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)

#undef KERNEL_ENVIRONMENT_GETTER

#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)                        \
  ConstantInt *get##MEMBER##FromKernelEnvironment(                             \
      ConstantStruct *KernelEnvC) {                                            \
    ConstantStruct *ConfigC =                                                  \
        getConfigurationFromKernelEnvironment(KernelEnvC);                     \
    return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx));   \
  }

KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)

#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER

GlobalVariable *
getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
  constexpr const int InitKernelEnvironmentArgNo = 0;
  return cast<GlobalVariable>(
      KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
          ->stripPointerCasts());
}

ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
  GlobalVariable *KernelEnvGV =
      getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
  return cast<ConstantStruct>(KernelEnvGV->getInitializer());
}
} // namespace KernelInfo
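
// Usage sketch (assuming KernelInitCB is a kernel's __kmpc_target_init call):
//   ConstantStruct *KernelEnvC =
//       KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
//   ConstantInt *UseGenericSM =
//       KernelInfo::getUseGenericStateMachineFromKernelEnvironment(KernelEnvC);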

namespace {

struct AAHeapToShared;

struct AAICVTracker;

/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
/// Attributor runs.
struct OMPInformationCache : public InformationCache {
  OMPInformationCache(Module &M, AnalysisGetter &AG,
                      BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
                      bool OpenMPPostLink)
      : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
        OpenMPPostLink(OpenMPPostLink) {

    OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
    OMPBuilder.initialize();
    initializeRuntimeFunctions(M);
    initializeInternalControlVars();
  }

  /// Generic information that describes an internal control variable.
  struct InternalControlVarInfo {
    /// The kind, as described by InternalControlVar enum.
    InternalControlVar Kind;

    /// The name of the ICV.
    StringRef Name;

    /// Environment variable associated with this ICV.
    StringRef EnvVarName;

    /// Initial value kind.
    ICVInitValue InitKind;

    /// Initial value.
    ConstantInt *InitValue;

    /// Setter RTL function associated with this ICV.
    RuntimeFunction Setter;

    /// Getter RTL function associated with this ICV.
    RuntimeFunction Getter;

    /// RTL Function corresponding to the override clause of this ICV
    RuntimeFunction Clause;
  };

  /// Generic information that describes a runtime function
  struct RuntimeFunctionInfo {

    /// The kind, as described by the RuntimeFunction enum.
    RuntimeFunction Kind;

    /// The name of the function.
    StringRef Name;

    /// Flag to indicate a variadic function.
    bool IsVarArg;

    /// The return type of the function.
    Type *ReturnType;

    /// The argument types of the function.
    SmallVector<Type *, 8> ArgumentTypes;

    /// The declaration if available.
    Function *Declaration = nullptr;

    /// Uses of this runtime function per function containing the use.
    using UseVector = SmallVector<Use *, 16>;

    /// Clear UsesMap for runtime function.
    void clearUsesMap() { UsesMap.clear(); }

    /// Boolean conversion that is true if the runtime function was found.
    operator bool() const { return Declaration; }

    /// Return the vector of uses in function \p F.
    UseVector &getOrCreateUseVector(Function *F) {
      std::shared_ptr<UseVector> &UV = UsesMap[F];
      if (!UV)
        UV = std::make_shared<UseVector>();
      return *UV;
    }

    /// Return the vector of uses in function \p F or `nullptr` if there are
    /// none.
    const UseVector *getUseVector(Function &F) const {
      auto I = UsesMap.find(&F);
      if (I != UsesMap.end())
        return I->second.get();
      return nullptr;
    }

    /// Return how many functions contain uses of this runtime function.
    size_t getNumFunctionsWithUses() const { return UsesMap.size(); }

    /// Return the number of arguments (or the minimal number for variadic
    /// functions).
    size_t getNumArgs() const { return ArgumentTypes.size(); }

    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as second argument.
    void foreachUse(SmallVectorImpl<Function *> &SCC,
                    function_ref<bool(Use &, Function &)> CB) {
      for (Function *F : SCC)
        foreachUse(CB, F);
    }

    /// Run the callback \p CB on each use within the function \p F and forget
    /// the use if the result is true.
    void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
      SmallVector<unsigned, 8> ToBeDeleted;
      ToBeDeleted.clear();

      unsigned Idx = 0;
      UseVector &UV = getOrCreateUseVector(F);

      for (Use *U : UV) {
        if (CB(*U, *F))
          ToBeDeleted.push_back(Idx);
        ++Idx;
      }

      // Remove the to-be-deleted indices in reverse order as prior
      // modifications will not modify the smaller indices.
      while (!ToBeDeleted.empty()) {
        unsigned Idx = ToBeDeleted.pop_back_val();
        UV[Idx] = UV.back();
        UV.pop_back();
      }
    }

  private:
    /// Map from functions to all uses of this runtime function contained in
    /// them.
    DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;

  public:
    /// Iterators for the uses of this runtime function.
    decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
    decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  };

  /// An OpenMP-IR-Builder instance
  OpenMPIRBuilder OMPBuilder;

  /// Map from runtime function kind to the runtime function description.
  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
                  RuntimeFunction::OMPRTL___last>
      RFIs;

  /// Map from function declarations/definitions to their runtime enum type.
  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;

  /// Map from ICV kind to the ICV description.
  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
                  InternalControlVar::ICV___last>
      ICVs;

  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  void initializeInternalControlVars() {
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
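    // Illustration (a sketch of the .def syntax, not reproduced from that
    // file): an OMPKinds.def entry along the lines of
    //   ICV_DATA_ENV(ICV_nthreads, "nthreads", "OMP_NUM_THREADS",
    //                ICV_IMPLEMENTATION_DEFINED)
    // expands via the macros above to fill ICVs[ICV_nthreads] with its name,
    // environment variable, and initial-value kind.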
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  /// Returns true if the function declaration \p F matches the runtime
  /// function types, that is, return type \p RTFRetType, and argument types
  /// \p RTFArgTypes.
  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
                                  SmallVector<Type *, 8> &RTFArgTypes) {
    // TODO: We should output information to the user (under debug output
    //       and via remarks).

    if (!F)
      return false;
    if (F->getReturnType() != RTFRetType)
      return false;
    if (F->arg_size() != RTFArgTypes.size())
      return false;

    auto *RTFTyIt = RTFArgTypes.begin();
    for (Argument &Arg : F->args()) {
      if (Arg.getType() != *RTFTyIt)
        return false;

      ++RTFTyIt;
    }

    return true;
  }

  // Helper to collect all uses of the declaration in the UsesMap.
  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
    unsigned NumUses = 0;
    if (!RFI.Declaration)
      return NumUses;
    OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);

    if (CollectStats) {
      NumOpenMPRuntimeFunctionsIdentified += 1;
      NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
    }

    // TODO: We directly convert uses into proper calls and unknown uses.
    for (Use &U : RFI.Declaration->uses()) {
      if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
        if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
          RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
          ++NumUses;
        }
      } else {
        RFI.getOrCreateUseVector(nullptr).push_back(&U);
        ++NumUses;
      }
    }
    return NumUses;
  }

  // Helper function to recollect uses of a runtime function.
  void recollectUsesForFunction(RuntimeFunction RTF) {
    auto &RFI = RFIs[RTF];
    RFI.clearUsesMap();
    collectUses(RFI, /*CollectStats*/ false);
  }

  // Helper function to recollect uses of all runtime functions.
  void recollectUses() {
    for (int Idx = 0; Idx < RFIs.size(); ++Idx)
      recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  }

  // Helper function to inherit the calling convention of the function callee.
  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
    if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
      CI->setCallingConv(Fn->getCallingConv());
  }

  // Helper function to determine if it's legal to create a call to the runtime
  // functions.
  bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
    // We can always emit calls if we haven't yet linked in the runtime.
    if (!OpenMPPostLink)
      return true;

    // Once the runtime has already been linked in, we cannot emit calls to
    // any undefined functions.
    for (RuntimeFunction Fn : Fns) {
      RuntimeFunctionInfo &RFI = RFIs[Fn];

      if (RFI.Declaration && RFI.Declaration->isDeclaration())
        return false;
    }
    return true;
  }

  /// Helper to initialize all runtime function information for those defined
  /// in OpenMPKinds.def.
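  /// Illustration (a sketch of the .def syntax): an entry such as
  ///   OMP_RTL(OMPRTL___kmpc_barrier, "__kmpc_barrier", false, Void, IdentPtr,
  ///           Int32)
  /// expands below into a type check of the __kmpc_barrier declaration and,
  /// on a match, initialization of the corresponding RuntimeFunctionInfo.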
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present.
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().starts_with(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }

  /// Collection of known OpenMP runtime functions.
  DenseSet<const Function *> RTLFunctions;

  /// Indicates if we have already linked in the OpenMP device library.
  bool OpenMPPostLink = false;
};

template <typename Ty, bool InsertInvalidates = true>
struct BooleanStateWithSetVector : public BooleanState {
  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  bool insert(const Ty &Elem) {
    if (InsertInvalidates)
      BooleanState::indicatePessimisticFixpoint();
    return Set.insert(Elem);
  }

  const Ty &operator[](int Idx) const { return Set[Idx]; }
  bool operator==(const BooleanStateWithSetVector &RHS) const {
    return BooleanState::operator==(RHS) && Set == RHS.Set;
  }
  bool operator!=(const BooleanStateWithSetVector &RHS) const {
    return !(*this == RHS);
  }

  bool empty() const { return Set.empty(); }
  size_t size() const { return Set.size(); }

  /// "Clamp" this state with \p RHS.
  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
    BooleanState::operator^=(RHS);
    Set.insert(RHS.Set.begin(), RHS.Set.end());
    return *this;
  }

private:
  /// A set to keep track of elements.
  SetVector<Ty> Set;

public:
  typename decltype(Set)::iterator begin() { return Set.begin(); }
  typename decltype(Set)::iterator end() { return Set.end(); }
  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  typename decltype(Set)::const_iterator end() const { return Set.end(); }
};

template <typename Ty, bool InsertInvalidates = true>
using BooleanStateWithPtrSetVector =
    BooleanStateWithSetVector<Ty *, InsertInvalidates>;

struct KernelInfoState : AbstractState {
  /// Flag to track if we reached a fixpoint.
  bool IsAtFixpoint = false;

  /// The parallel regions (identified by the outlined parallel functions) that
  /// can be reached from the associated function.
  BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
      ReachedKnownParallelRegions;

  /// State to track what parallel region we might reach.
  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;

  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;

  /// The __kmpc_target_init call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelInitCB = nullptr;

  /// The constant kernel environment as taken from and passed to
  /// __kmpc_target_init.
  ConstantStruct *KernelEnvC = nullptr;

  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
  /// one we abort as the kernel is malformed.
  CallBase *KernelDeinitCB = nullptr;

  /// Flag to indicate if the associated function is a kernel entry.
  bool IsKernelEntry = false;

  /// State to track what kernel entries can reach the associated function.
  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;

  /// State to indicate if we can track parallel level of the associated
  /// function. We will give up tracking if we encounter unknown caller or the
  /// caller is __kmpc_parallel_51.
  BooleanStateWithSetVector<uint8_t> ParallelLevels;

  /// Flag that indicates if the kernel has nested parallelism.
  bool NestedParallelism = false;

  /// Abstract State interface
  ///{

  KernelInfoState() = default;
  KernelInfoState(bool BestState) {
    if (!BestState)
      indicatePessimisticFixpoint();
  }

  /// See AbstractState::isValidState(...)
  bool isValidState() const override { return true; }

  /// See AbstractState::isAtFixpoint(...)
  bool isAtFixpoint() const override { return IsAtFixpoint; }

  /// See AbstractState::indicatePessimisticFixpoint(...)
  ChangeStatus indicatePessimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicatePessimisticFixpoint();
    ReachingKernelEntries.indicatePessimisticFixpoint();
    SPMDCompatibilityTracker.indicatePessimisticFixpoint();
    ReachedKnownParallelRegions.indicatePessimisticFixpoint();
    ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
    NestedParallelism = true;
    return ChangeStatus::CHANGED;
  }

  /// See AbstractState::indicateOptimisticFixpoint(...)
  ChangeStatus indicateOptimisticFixpoint() override {
    IsAtFixpoint = true;
    ParallelLevels.indicateOptimisticFixpoint();
    ReachingKernelEntries.indicateOptimisticFixpoint();
    SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    ReachedKnownParallelRegions.indicateOptimisticFixpoint();
    ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    return ChangeStatus::UNCHANGED;
  }

  /// Return the assumed state
  KernelInfoState &getAssumed() { return *this; }
  const KernelInfoState &getAssumed() const { return *this; }

  bool operator==(const KernelInfoState &RHS) const {
    if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
      return false;
    if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
      return false;
    if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
      return false;
    if (ReachingKernelEntries != RHS.ReachingKernelEntries)
      return false;
    if (ParallelLevels != RHS.ParallelLevels)
      return false;
    if (NestedParallelism != RHS.NestedParallelism)
      return false;
    return true;
  }

  /// Returns true if this kernel contains any OpenMP parallel regions.
  bool mayContainParallelRegion() {
    return !ReachedKnownParallelRegions.empty() ||
           !ReachedUnknownParallelRegions.empty();
  }

  /// Return empty set as the best state of potential values.
  static KernelInfoState getBestState() { return KernelInfoState(true); }

  static KernelInfoState getBestState(KernelInfoState &KIS) {
    return getBestState();
  }

  /// Return full set as the worst state of potential values.
  static KernelInfoState getWorstState() { return KernelInfoState(false); }

  /// "Clamp" this state with \p KIS.
  KernelInfoState operator^=(const KernelInfoState &KIS) {
    // Do not merge two different _init and _deinit call sites.
    if (KIS.KernelInitCB) {
      if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelInitCB = KIS.KernelInitCB;
    }
    if (KIS.KernelDeinitCB) {
      if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelDeinitCB = KIS.KernelDeinitCB;
    }
    if (KIS.KernelEnvC) {
      if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
        llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
                         "assumptions.");
      KernelEnvC = KIS.KernelEnvC;
    }
    SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
    ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
    ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
    NestedParallelism |= KIS.NestedParallelism;
    return *this;
  }

  KernelInfoState operator&=(const KernelInfoState &KIS) {
    return (*this ^= KIS);
  }

  ///}
};
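
// KernelInfoState is the lattice the Attributor iterates on per function:
// states start at getBestState() (optimistic, e.g. assumed SPMD-compatible)
// and are clamped towards getWorstState() as conflicting evidence is found,
// until a fixpoint is reached.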

/// Used to map the values physically (in the IR) stored in an offload
/// array, to a vector in memory.
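/// For intuition: these are the %offload_baseptrs / %offload_ptrs /
/// %offload_sizes allocas emitted in front of a __tgt_target_data_begin_mapper
/// call (see getValuesInOffloadArrays below); getValues() records the last
/// store into each slot of such an array.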
struct OffloadArray {
  /// Physical array (in the IR).
  AllocaInst *Array = nullptr;
  /// Mapped values.
  SmallVector<Value *, 8> StoredValues;
  /// Last stores made in the offload array.
  SmallVector<StoreInst *, 8> LastAccesses;

  OffloadArray() = default;

  /// Initializes the OffloadArray with the values stored in \p Array before
  /// instruction \p Before is reached. Returns false if the initialization
  /// fails.
  /// This MUST be used immediately after the construction of the object.
  bool initialize(AllocaInst &Array, Instruction &Before) {
    if (!Array.getAllocatedType()->isArrayTy())
      return false;

    if (!getValues(Array, Before))
      return false;

    this->Array = &Array;
    return true;
  }

  static const unsigned DeviceIDArgNum = 1;
  static const unsigned BasePtrsArgNum = 3;
  static const unsigned PtrsArgNum = 4;
  static const unsigned SizesArgNum = 5;

private:
  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  /// \p Array, leaving StoredValues with the values stored before the
  /// instruction \p Before is reached.
  bool getValues(AllocaInst &Array, Instruction &Before) {
    // Initialize container.
    const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
    StoredValues.assign(NumValues, nullptr);
    LastAccesses.assign(NumValues, nullptr);

    // TODO: This assumes the instruction \p Before is in the same
    //       BasicBlock as Array. Make it general, for any control flow graph.
    BasicBlock *BB = Array.getParent();
    if (BB != Before.getParent())
      return false;

    const DataLayout &DL = Array.getModule()->getDataLayout();
    const unsigned int PointerSize = DL.getPointerSize();

    for (Instruction &I : *BB) {
      if (&I == &Before)
        break;

      if (!isa<StoreInst>(&I))
        continue;

      auto *S = cast<StoreInst>(&I);
      int64_t Offset = -1;
      auto *Dst =
          GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
      if (Dst == &Array) {
        int64_t Idx = Offset / PointerSize;
        StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
        LastAccesses[Idx] = S;
      }
    }

    return isFilled();
  }

  /// Returns true if all values in StoredValues and
  /// LastAccesses are not nullptrs.
  bool isFilled() {
    const unsigned NumValues = StoredValues.size();
    for (unsigned I = 0; I < NumValues; ++I) {
      if (!StoredValues[I] || !LastAccesses[I])
        return false;
    }

    return true;
  }
};

struct OpenMPOpt {

  using OptimizationRemarkGetter =
      function_ref<OptimizationRemarkEmitter &(Function *)>;

  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
            OptimizationRemarkGetter OREGetter,
            OMPInformationCache &OMPInfoCache, Attributor &A)
      : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
        OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}

  /// Check if any remarks are enabled for openmp-opt
  bool remarksEnabled() {
    auto &Ctx = M.getContext();
    return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  }

  /// Run all OpenMP optimizations on the underlying SCC.
  bool run(bool IsModulePass) {
    if (SCC.empty())
      return false;

    bool Changed = false;

    LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
                      << " functions\n");

    if (IsModulePass) {
      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      // TODO: This should be folded into buildCustomStateMachine.
      Changed |= rewriteDeviceCodeStateMachine();

      if (remarksEnabled())
        analysisGlobalization();
    } else {
      if (PrintICVValues)
        printICVs();
      if (PrintOpenMPKernels)
        printKernels();

      Changed |= runAttributor(IsModulePass);

      // Recollect uses, in case Attributor deleted any.
      OMPInfoCache.recollectUses();

      Changed |= deleteParallelRegions();

      if (HideMemoryTransferLatency)
        Changed |= hideMemTransfersLatency();
      Changed |= deduplicateRuntimeCalls();
      if (EnableParallelRegionMerging) {
        if (mergeParallelRegions()) {
          deduplicateRuntimeCalls();
          Changed = true;
        }
      }
    }

    if (OMPInfoCache.OpenMPPostLink)
      Changed |= removeRuntimeSymbols();

    return Changed;
  }

  /// Print initial ICV values for testing.
  /// FIXME: This should be done from the Attributor once it is added.
  void printICVs() const {
    InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
                                 ICV_proc_bind};

    for (Function *F : SCC) {
      for (auto ICV : ICVs) {
        auto ICVInfo = OMPInfoCache.ICVs[ICV];
        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
                     << " Value: "
                     << (ICVInfo.InitValue
                             ? toString(ICVInfo.InitValue->getValue(), 10, true)
                             : "IMPLEMENTATION_DEFINED");
        };

        emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
      }
    }
  }

  /// Print OpenMP GPU kernels for testing.
  void printKernels() const {
    for (Function *F : SCC) {
      if (!omp::isOpenMPKernel(*F))
        continue;

      auto Remark = [&](OptimizationRemarkAnalysis ORA) {
        return ORA << "OpenMP GPU kernel "
                   << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
      };

      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
    }
  }

  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given it has to be the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(U.getUser());
    if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
  /// the callee or a nullptr is returned.
  static CallInst *getCallIfRegularCall(
      Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && !CI->hasOperandBundles() &&
        (!RFI ||
         (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
      return CI;
    return nullptr;
  }

private:
  /// Merge parallel regions when it is safe.
  bool mergeParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;
    const unsigned CallbackFirstArgOperand = 3;
    using InsertPointTy = OpenMPIRBuilder::InsertPointTy;

    // Check if there are any __kmpc_fork_call calls to merge.
    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    // Unmergable calls that prevent merging a parallel region.
    OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
        OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
    };

    bool Changed = false;
    LoopInfo *LI = nullptr;
    DominatorTree *DT = nullptr;

    SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;

    BasicBlock *StartBB = nullptr, *EndBB = nullptr;
    auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
      BasicBlock *CGStartBB = CodeGenIP.getBlock();
      BasicBlock *CGEndBB =
          SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
      assert(StartBB != nullptr && "StartBB should not be null");
      CGStartBB->getTerminator()->setSuccessor(0, StartBB);
      assert(EndBB != nullptr && "EndBB should not be null");
      EndBB->getTerminator()->setSuccessor(0, CGEndBB);
    };

    auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
                      Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
      ReplacementValue = &Inner;
      return CodeGenIP;
    };

    auto FiniCB = [&](InsertPointTy CodeGenIP) {};

    /// Create a sequential execution region within a merged parallel region,
    /// encapsulated in a master construct with a barrier for synchronization.
    auto CreateSequentialRegion = [&](Function *OuterFn,
                                      BasicBlock *OuterPredBB,
                                      Instruction *SeqStartI,
                                      Instruction *SeqEndI) {
      // Isolate the instructions of the sequential region to a separate
      // block.
      BasicBlock *ParentBB = SeqStartI->getParent();
      BasicBlock *SeqEndBB =
          SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
      BasicBlock *SeqAfterBB =
          SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
      BasicBlock *SeqStartBB =
          SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");

      assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
             "Expected a different CFG");
      const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
      ParentBB->getTerminator()->eraseFromParent();

      auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
        BasicBlock *CGStartBB = CodeGenIP.getBlock();
        BasicBlock *CGEndBB =
            SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
        assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
        CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
        assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
        SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
      };
      auto FiniCB = [&](InsertPointTy CodeGenIP) {};

      // Find outputs from the sequential region to outside users and
      // broadcast their values to them.
      for (Instruction &I : *SeqStartBB) {
        SmallPtrSet<Instruction *, 4> OutsideUsers;
        for (User *Usr : I.users()) {
          Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to LT intrinsics, code extraction for the merged
          // parallel region will fix them.
          if (UsrI.isLifetimeStartOrEnd())
            continue;

          if (UsrI.getParent() != SeqStartBB)
            OutsideUsers.insert(&UsrI);
        }

        if (OutsideUsers.empty())
          continue;

        // Emit an alloca in the outer region to store the broadcasted
        // value.
        const DataLayout &DL = M.getDataLayout();
        AllocaInst *AllocaI = new AllocaInst(
            I.getType(), DL.getAllocaAddrSpace(), nullptr,
            I.getName() + ".seq.output.alloc", OuterFn->front().begin());

        // Emit a store instruction in the sequential BB to update the
        // value.
        new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());

        // Emit a load instruction and replace the use of the output value
        // with it.
        for (Instruction *UsrI : OutsideUsers) {
          LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
                                         I.getName() + ".seq.output.load",
                                         UsrI->getIterator());
          UsrI->replaceUsesOfWith(&I, LoadI);
        }
      }

      OpenMPIRBuilder::LocationDescription Loc(
          InsertPointTy(ParentBB, ParentBB->end()), DL);
      InsertPointTy SeqAfterIP =
          OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);

      OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);

      BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());

      LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
                        << "\n");
    };

    // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
    // contained in BB and only separated by instructions that can be
    // redundantly executed in parallel. The block BB is split before the first
    // call (in MergableCIs) and after the last so the entire region we merge
    // into a single parallel region is contained in a single basic block
    // without any other instructions. We use the OpenMPIRBuilder to outline
    // that block and call the resulting function via __kmpc_fork_call.
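    // Schematically (a sketch, assuming two regions with independent code
    // in-between):
    //   __kmpc_fork_call(..., @outlined1, ...)
    //   <in-between instructions>
    //   __kmpc_fork_call(..., @outlined2, ...)
    // becomes one __kmpc_fork_call of a merged outlined function that invokes
    // @outlined1 and @outlined2 directly, with the in-between instructions
    // guarded by a master construct and explicit barriers replacing the
    // implicit fork-join barriers.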
    auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
                     BasicBlock *BB) {
      // TODO: Change the interface to allow single CIs expanded, e.g, to
      //       include an outer loop.
      assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");

      auto Remark = [&](OptimizationRemark OR) {
        OR << "Parallel region merged with parallel region"
           << (MergableCIs.size() > 2 ? "s" : "") << " at ";
        for (auto *CI : llvm::drop_begin(MergableCIs)) {
          OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
          if (CI != MergableCIs.back())
            OR << ", ";
        }
        return OR << ".";
      };

      emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);

      Function *OriginalFn = BB->getParent();
      LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
                        << " parallel regions in " << OriginalFn->getName()
                        << "\n");

      // Isolate the calls to merge in a separate block.
      EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
      BasicBlock *AfterBB =
          SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
      StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
                           "omp.par.merged");

      assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
      const DebugLoc DL = BB->getTerminator()->getDebugLoc();
      BB->getTerminator()->eraseFromParent();

      // Create sequential regions for sequential instructions that are
      // in-between mergable parallel regions.
      for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
           It != End; ++It) {
        Instruction *ForkCI = *It;
        Instruction *NextForkCI = *(It + 1);

        // Continue if there are not in-between instructions.
        if (ForkCI->getNextNode() == NextForkCI)
          continue;

        CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
                               NextForkCI->getPrevNode());
      }

      OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
                                               DL);
      IRBuilder<>::InsertPoint AllocaIP(
          &OriginalFn->getEntryBlock(),
          OriginalFn->getEntryBlock().getFirstInsertionPt());
      // Create the merged parallel region with default proc binding, to
      // avoid overriding binding settings, and without explicit cancellation.
      InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
          Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
          OMP_PROC_BIND_default, /* IsCancellable */ false);
      BranchInst::Create(AfterBB, AfterIP.getBlock());

      // Perform the actual outlining.
      OMPInfoCache.OMPBuilder.finalize(OriginalFn);

      Function *OutlinedFn = MergableCIs.front()->getCaller();

      // Replace the __kmpc_fork_call calls with direct calls to the outlined
      // callbacks.
      SmallVector<Value *, 8> Args;
      for (auto *CI : MergableCIs) {
        Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
        FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
        Args.clear();
        Args.push_back(OutlinedFn->getArg(0));
        Args.push_back(OutlinedFn->getArg(1));
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          Args.push_back(CI->getArgOperand(U));

        CallInst *NewCI =
            CallInst::Create(FT, Callee, Args, "", CI->getIterator());
        if (CI->getDebugLoc())
          NewCI->setDebugLoc(CI->getDebugLoc());

        // Forward parameter attributes from the callback to the callee.
        for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
             ++U)
          for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
            NewCI->addParamAttr(
                U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);

        // Emit an explicit barrier to replace the implicit fork-join barrier.
        if (CI != MergableCIs.back()) {
          // TODO: Remove barrier if the merged parallel region includes the
          //       'nowait' clause.
          OMPInfoCache.OMPBuilder.createBarrier(
              InsertPointTy(NewCI->getParent(),
                            NewCI->getNextNode()->getIterator()),
              OMPD_parallel);
        }

        CI->eraseFromParent();
      }

      assert(OutlinedFn != OriginalFn && "Outlining failed");
      CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
      CGUpdater.reanalyzeFunction(*OriginalFn);

      NumOpenMPParallelRegionsMerged += MergableCIs.size();

      return true;
    };

    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
    auto DetectPRsCB = [&](Use &U, Function &F) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      BB2PRMap[CI->getParent()].insert(CI);

      return false;
    };

    BB2PRMap.clear();
    RFI.foreachUse(SCC, DetectPRsCB);
    SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is any in-between instructions can safely
    // execute in parallel after merging.
    // TODO: support merging across basic-blocks.
    for (auto &It : BB2PRMap) {
      auto &CIs = It.getSecond();
      if (CIs.size() < 2)
        continue;

      BasicBlock *BB = It.getFirst();
      SmallVector<CallInst *, 4> MergableCIs;

      /// Returns true if the instruction is mergable, false otherwise.
      /// A terminator instruction is unmergable by definition since merging
      /// works within a BB. Instructions before the mergable region are
      /// mergable if they are not calls to OpenMP runtime functions that may
      /// set different execution parameters for subsequent parallel regions.
      /// Instructions in-between parallel regions are mergable if they are not
      /// calls to any non-intrinsic function since that may call a non-mergable
      /// OpenMP runtime function.
      auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
        // We do not merge across BBs, hence return false (unmergable) if the
        // instruction is a terminator.
        if (I.isTerminator())
          return false;

        if (!isa<CallInst>(&I))
          return true;

        CallInst *CI = cast<CallInst>(&I);
        if (IsBeforeMergableRegion) {
          Function *CalledFunction = CI->getCalledFunction();
          if (!CalledFunction)
            return false;
          // Return false (unmergable) if the call before the parallel
          // region calls an explicit affinity (proc_bind) or number of
          // threads (num_threads) compiler-generated function. Those settings
          // may be incompatible with following parallel regions.
          // TODO: ICV tracking to detect compatibility.
          for (const auto &RFI : UnmergableCallsInfo) {
            if (CalledFunction == RFI.Declaration)
              return false;
          }
        } else {
          // Return false (unmergable) if there is a call instruction
          // in-between parallel regions when it is not an intrinsic. It
          // may call an unmergable OpenMP runtime function in its callpath.
          // TODO: Keep track of possible OpenMP calls in the callpath.
          if (!isa<IntrinsicInst>(CI))
            return false;
        }

        return true;
      };
      // Find maximal number of parallel region CIs that are safe to merge.
      for (auto It = BB->begin(), End = BB->end(); It != End;) {
        Instruction &I = *It;
        ++It;

        if (CIs.count(&I)) {
          MergableCIs.push_back(cast<CallInst>(&I));
          continue;
        }

        // Continue expanding if the instruction is mergable.
        if (IsMergable(I, MergableCIs.empty()))
          continue;

        // Forward the instruction iterator to skip the next parallel region
        // since there is an unmergable instruction which can affect it.
        for (; It != End; ++It) {
          Instruction &SkipI = *It;
          if (CIs.count(&SkipI)) {
            LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
                              << " due to " << I << "\n");
            ++It;
            break;
          }
        }

        // Store mergable regions found.
        if (MergableCIs.size() > 1) {
          MergableCIsVector.push_back(MergableCIs);
          LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
                            << " parallel regions in block " << BB->getName()
                            << " of function " << BB->getParent()->getName()
                            << "\n";);
        }

        MergableCIs.clear();
      }

      if (!MergableCIsVector.empty()) {
        Changed = true;

        for (auto &MergableCIs : MergableCIsVector)
          Merge(MergableCIs, BB);
        MergableCIsVector.clear();
      }
    }

    if (Changed) {
      /// Re-collect use for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
      OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
    }

    return Changed;
  }

  /// Try to delete parallel regions if possible.
  bool deleteParallelRegions() {
    const unsigned CallbackCalleeOperand = 2;

    OMPInformationCache::RuntimeFunctionInfo &RFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];

    if (!RFI.Declaration)
      return false;

    bool Changed = false;
    auto DeleteCallCB = [&](Use &U, Function &) {
      CallInst *CI = getCallIfRegularCall(U);
      if (!CI)
        return false;
      auto *Fn = dyn_cast<Function>(
          CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
      if (!Fn)
        return false;
      if (!Fn->onlyReadsMemory())
        return false;
      if (!Fn->hasFnAttribute(Attribute::WillReturn))
        return false;

      LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
                        << CI->getCaller()->getName() << "\n");

      auto Remark = [&](OptimizationRemark OR) {
        return OR << "Removing parallel region with no side-effects.";
      };
      emitRemark<OptimizationRemark>(CI, "OMP160", Remark);

      CGUpdater.removeCallSite(*CI);
      CI->eraseFromParent();
      Changed = true;
      ++NumOpenMPParallelRegionsDeleted;
      return true;
    };

    RFI.foreachUse(SCC, DeleteCallCB);

    return Changed;
  }

  /// Try to eliminate runtime calls by reusing existing ones.
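  /// For example (a sketch): several calls to omp_get_num_threads() in one
  /// function with no intervening change of parallel context can be collapsed
  /// into a single call whose result is reused.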
  bool deduplicateRuntimeCalls() {
    bool Changed = false;

    RuntimeFunction DeduplicableRuntimeCallIDs[] = {
        OMPRTL_omp_get_num_threads,
        OMPRTL_omp_in_parallel,
        OMPRTL_omp_get_cancellation,
        OMPRTL_omp_get_supported_active_levels,
        OMPRTL_omp_get_level,
        OMPRTL_omp_get_ancestor_thread_num,
        OMPRTL_omp_get_team_size,
        OMPRTL_omp_get_active_level,
        OMPRTL_omp_in_final,
        OMPRTL_omp_get_proc_bind,
        OMPRTL_omp_get_num_places,
        OMPRTL_omp_get_num_procs,
        OMPRTL_omp_get_place_num,
        OMPRTL_omp_get_partition_num_places,
        OMPRTL_omp_get_partition_place_nums};

    // Global-tid is handled separately.
    SmallSetVector<Value *, 16> GTIdArgs;
    collectGlobalThreadIdArguments(GTIdArgs);
    LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
                      << " global thread ID arguments\n");

    for (Function *F : SCC) {
      for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
        Changed |= deduplicateRuntimeCalls(
            *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);

      // __kmpc_global_thread_num is special as we can replace it with an
      // argument in enough cases to make it worth trying.
      Value *GTIdArg = nullptr;
      for (Argument &Arg : F->args())
        if (GTIdArgs.count(&Arg)) {
          GTIdArg = &Arg;
          break;
        }
      Changed |= deduplicateRuntimeCalls(
          *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
    }

    return Changed;
  }

  /// Tries to remove known runtime symbols that are optional from the module.
  bool removeRuntimeSymbols() {
    // The RPC client symbol is defined in `libc` and indicates that something
    // required an RPC server. If its users were all optimized out then we can
    // safely remove it.
    // TODO: This should be somewhere more common in the future.
    if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {
      if (!GV->getType()->isPointerTy())
        return false;

      Constant *C = GV->getInitializer();
      if (!C)
        return false;

      // Check to see if the only user of the RPC client is the external handle.
      GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());
      if (!Client || Client->getNumUses() > 1 ||
          Client->user_back() != GV->getInitializer())
        return false;

      Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));
      Client->eraseFromParent();

      GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
      GV->eraseFromParent();

      return true;
    }
    return false;
  }

  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits in the
  /// returned handle for the memory transfer to finish.
  bool hideMemTransfersLatency() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
    bool Changed = false;
    auto SplitMemTransfers = [&](Use &U, Function &Decl) {
      auto *RTCall = getCallIfRegularCall(U, &RFI);
      if (!RTCall)
        return false;

      OffloadArray OffloadArrays[3];
      if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
        return false;

      LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));

      // TODO: Check if can be moved upwards.
      bool WasSplit = false;
      Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
      if (WaitMovementPoint)
        WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);

      Changed |= WasSplit;
      return WasSplit;
    };
    if (OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___tgt_target_data_begin_mapper_issue,
             OMPRTL___tgt_target_data_begin_mapper_wait}))
      RFI.foreachUse(SCC, SplitMemTransfers);

    return Changed;
  }

  void analysisGlobalization() {
    auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];

    auto CheckGlobalization = [&](Use &U, Function &Decl) {
      if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
        auto Remark = [&](OptimizationRemarkMissed ORM) {
          return ORM
                 << "Found thread data sharing on the GPU. "
                 << "Expect degraded performance due to data globalization.";
        };
        emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
      }

      return false;
    };

    RFI.foreachUse(SCC, CheckGlobalization);
  }

  /// Maps the values stored in the offload arrays passed as arguments to
  /// \p RuntimeCall into the offload arrays in \p OAs.
  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
                                MutableArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "Need space for three offload arrays!");

    // A runtime call that involves memory offloading looks something like:
    // call void @__tgt_target_data_begin_mapper(arg0, arg1,
    //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
    // ...)
    // So, the idea is to access the allocas that allocate space for these
    // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
    // Therefore:
    // i8** %offload_baseptrs.
    Value *BasePtrsArg =
        RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
    // i8** %offload_ptrs.
    Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i8** %offload_sizes.
    Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);

    // Get values stored in **offload_baseptrs.
    auto *V = getUnderlyingObject(BasePtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *BasePtrsArray = cast<AllocaInst>(V);
    if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_ptrs.
    V = getUnderlyingObject(PtrsArg);
    if (!isa<AllocaInst>(V))
      return false;
    auto *PtrsArray = cast<AllocaInst>(V);
    if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
      return false;

    // Get values stored in **offload_sizes.
    V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array don't analyze it.
    if (isa<GlobalValue>(V))
      return isa<Constant>(V);
    if (!isa<AllocaInst>(V))
      return false;

    auto *SizesArray = cast<AllocaInst>(V);
    if (!OAs[2].initialize(*SizesArray, RuntimeCall))
      return false;

    return true;
  }

  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  /// For now this is a way to test that the function getValuesInOffloadArrays
  /// is working properly.
  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
    assert(OAs.size() == 3 && "There are three offload arrays to debug!");

    LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
    std::string ValuesStr;
    raw_string_ostream Printer(ValuesStr);
    std::string Separator = " --- ";

    for (auto *BP : OAs[0].StoredValues) {
      BP->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *P : OAs[1].StoredValues) {
      P->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
    ValuesStr.clear();

    for (auto *S : OAs[2].StoredValues) {
      S->print(Printer);
      Printer << Separator;
    }
    LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  }
1686
1687 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1688 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1689 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1690 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1691 // Make it traverse the CFG.
1692
1693 Instruction *CurrentI = &RuntimeCall;
1694 bool IsWorthIt = false;
1695 while ((CurrentI = CurrentI->getNextNode())) {
1696
1697 // TODO: Once we detect the regions to be offloaded we should use the
1698 // alias analysis manager to check if CurrentI may modify one of
1699 // the offloaded regions.
1700 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1701 if (IsWorthIt)
1702 return CurrentI;
1703
1704 return nullptr;
1705 }
1706
1707 // FIXME: For now, moving the call over anything without side effects
1708 // is considered worth it.
1709 IsWorthIt = true;
1710 }
1711
1712 // Return end of BasicBlock.
1713 return RuntimeCall.getParent()->getTerminator();
1714 }
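// Hedged example for canBeMovedDownwards (hypothetical IR): in
//   call void @__tgt_target_data_begin_mapper(...)
//   %x = add i32 %a, %b  ; side-effect free, moving over it is "worth it"
//   store i32 %x, ptr %p ; first instruction that touches memory
// the store is returned as the movement point; if the very next instruction
// already touched memory, nullptr would be returned instead.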
1715
1716 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1717 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1718 Instruction &WaitMovementPoint) {
1719 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1720 // function. It stores information about the async transfer, allowing us to
1721 // wait on it later.
1722 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1723 Function *F = RuntimeCall.getCaller();
1724 BasicBlock &Entry = F->getEntryBlock();
1725 IRBuilder.Builder.SetInsertPoint(&Entry,
1726 Entry.getFirstNonPHIOrDbgOrAlloca());
1727 Value *Handle = IRBuilder.Builder.CreateAlloca(
1728 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1729 Handle =
1730 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1731
1732 // Add "issue" runtime call declaration:
1733 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1734 // i8**, i8**, i64*, i64*)
1735 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1736 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1737
1738 // Change the RuntimeCall call site to its asynchronous version.
1739 SmallVector<Value *, 16> Args;
1740 for (auto &Arg : RuntimeCall.args())
1741 Args.push_back(Arg.get());
1742 Args.push_back(Handle);
1743
1744 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1745 RuntimeCall.getIterator());
1746 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1747 RuntimeCall.eraseFromParent();
1748
1749 // Add "wait" runtime call declaration:
1750 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1751 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1752 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1753
1754 Value *WaitParams[2] = {
1755 IssueCallsite->getArgOperand(
1756 OffloadArray::DeviceIDArgNum), // device_id.
1757 Handle // handle to wait on.
1758 };
1759 CallInst *WaitCallsite = CallInst::Create(
1760 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1761 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1762
1763 return true;
1764 }
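// Hedged before/after sketch of the split (hypothetical value names):
//
//   Before:
//     call void @__tgt_target_data_begin_mapper(...)
//
//   After:
//     %handle = alloca %struct.__tgt_async_info ; in the entry block
//     call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
//     ... ; independent code the transfer can overlap with
//     call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)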
1765
1766 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1767 bool GlobalOnly, bool &SingleChoice) {
1768 if (CurrentIdent == NextIdent)
1769 return CurrentIdent;
1770
1771 // TODO: Figure out how to actually combine multiple debug locations. For
1772 // now we just keep an existing one if there is a single choice.
1773 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1774 SingleChoice = !CurrentIdent;
1775 return NextIdent;
1776 }
1777 return nullptr;
1778 }
1779
1780 /// Return a `struct ident_t*` value that represents the ones used in the
1781 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1782 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1783 /// return value we create one from scratch. We also do not yet combine
1784 /// information, e.g., the source locations, see combinedIdentStruct.
1785 Value *
1786 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 Function &F, bool GlobalOnly) {
1788 bool SingleChoice = true;
1789 Value *Ident = nullptr;
1790 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1791 CallInst *CI = getCallIfRegularCall(U, &RFI);
1792 if (!CI || &F != &Caller)
1793 return false;
1794 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1795 /* GlobalOnly */ true, SingleChoice);
1796 return false;
1797 };
1798 RFI.foreachUse(SCC, CombineIdentStruct);
1799
1800 if (!Ident || !SingleChoice) {
1801 // The IRBuilder uses the insertion block to get to the module; this is
1802 // unfortunate, but we work around it for now.
1803 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1805 &F.getEntryBlock(), F.getEntryBlock().begin()));
1806 // Create a fallback location if none was found.
1807 // TODO: Use the debug locations of the calls instead.
1808 uint32_t SrcLocStrSize;
1809 Constant *Loc =
1810 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1811 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1812 }
1813 return Ident;
1814 }
1815
1816 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1817 /// \p ReplVal if given.
1818 bool deduplicateRuntimeCalls(Function &F,
1819 OMPInformationCache::RuntimeFunctionInfo &RFI,
1820 Value *ReplVal = nullptr) {
1821 auto *UV = RFI.getUseVector(F);
1822 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1823 return false;
1824
1825 LLVM_DEBUG(
1826 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1827 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1828
1829 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1830 cast<Argument>(ReplVal)->getParent() == &F)) &&
1831 "Unexpected replacement value!");
1832
1833 // TODO: Use dominance to find a good position instead.
1834 auto CanBeMoved = [this](CallBase &CB) {
1835 unsigned NumArgs = CB.arg_size();
1836 if (NumArgs == 0)
1837 return true;
1838 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 return false;
1840 for (unsigned U = 1; U < NumArgs; ++U)
1841 if (isa<Instruction>(CB.getArgOperand(U)))
1842 return false;
1843 return true;
1844 };
1845
1846 if (!ReplVal) {
1847 auto *DT =
1848 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1849 if (!DT)
1850 return false;
1851 Instruction *IP = nullptr;
1852 for (Use *U : *UV) {
1853 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1854 if (IP)
1855 IP = DT->findNearestCommonDominator(IP, CI);
1856 else
1857 IP = CI;
1858 if (!CanBeMoved(*CI))
1859 continue;
1860 if (!ReplVal)
1861 ReplVal = CI;
1862 }
1863 }
1864 if (!ReplVal)
1865 return false;
1866 assert(IP && "Expected insertion point!");
1867 cast<Instruction>(ReplVal)->moveBefore(IP);
1868 }
1869
1870 // If we use a call as a replacement value we need to make sure the ident is
1871 // valid at the new location. For now we just pick a global one, either
1872 // existing and used by one of the calls, or created from scratch.
1873 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1874 if (!CI->arg_empty() &&
1875 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1876 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1877 /* GlobalOnly */ true);
1878 CI->setArgOperand(0, Ident);
1879 }
1880 }
1881
1882 bool Changed = false;
1883 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1884 CallInst *CI = getCallIfRegularCall(U, &RFI);
1885 if (!CI || CI == ReplVal || &F != &Caller)
1886 return false;
1887 assert(CI->getCaller() == &F && "Unexpected call!");
1888
1889 auto Remark = [&](OptimizationRemark OR) {
1890 return OR << "OpenMP runtime call "
1891 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1892 };
1893 if (CI->getDebugLoc())
1894 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1895 else
1896 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1897
1898 CGUpdater.removeCallSite(*CI);
1899 CI->replaceAllUsesWith(ReplVal);
1900 CI->eraseFromParent();
1901 ++NumOpenMPRuntimeCallsDeduplicated;
1902 Changed = true;
1903 return true;
1904 };
1905 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1906
1907 return Changed;
1908 }
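// Hedged sketch of the effect (hypothetical IR): given
//   %tid0 = call i32 @__kmpc_global_thread_num(ptr @loc)
//   ...
//   %tid1 = call i32 @__kmpc_global_thread_num(ptr @loc)
// the second call is deleted and its uses are rewired to %tid0, after %tid0's
// call has been moved to a common dominator of all deduplicated call sites.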
1909
1910 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1911 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1912 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1913 // initialization. We could define an AbstractAttribute instead and
1914 // run the Attributor here once it can be run as an SCC pass.
1915
1916 // Helper to check the argument \p ArgNo at all call sites of \p F for
1917 // a GTId.
1918 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1919 if (!F.hasLocalLinkage())
1920 return false;
1921 for (Use &U : F.uses()) {
1922 if (CallInst *CI = getCallIfRegularCall(U)) {
1923 Value *ArgOp = CI->getArgOperand(ArgNo);
1924 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1925 getCallIfRegularCall(
1926 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1927 continue;
1928 }
1929 return false;
1930 }
1931 return true;
1932 };
1933
1934 // Helper to identify uses of a GTId as GTId arguments.
1935 auto AddUserArgs = [&](Value &GTId) {
1936 for (Use &U : GTId.uses())
1937 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1938 if (CI->isArgOperand(&U))
1939 if (Function *Callee = CI->getCalledFunction())
1940 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1941 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1942 };
1943
1944 // The argument users of __kmpc_global_thread_num calls are GTIds.
1945 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1946 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1947
1948 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1949 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1950 AddUserArgs(*CI);
1951 return false;
1952 });
1953
1954 // Transitively search for more arguments by looking at the users of the
1955 // ones we know already. During the search the GTIdArgs vector is extended
1956 // so we cannot cache the size nor can we use a range based for.
1957 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1958 AddUserArgs(*GTIdArgs[U]);
1959 }
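// Hedged example (hypothetical snippet): below, %tid is a GTId by definition,
// and the first parameter of @helper is transitively collected as well, since
// @helper has local linkage and every call site passes a known GTId:
//
//   %tid = call i32 @__kmpc_global_thread_num(ptr @loc)
//   call void @helper(i32 %tid)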
1960
1961 /// Kernel (=GPU) optimizations and utility functions
1962 ///
1963 ///{{
1964
1965 /// Cache to remember the unique kernel for a function.
1966 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1967
1968 /// Find the unique kernel that will execute \p F, if any.
1969 Kernel getUniqueKernelFor(Function &F);
1970
1971 /// Find the unique kernel that will execute \p I, if any.
1972 Kernel getUniqueKernelFor(Instruction &I) {
1973 return getUniqueKernelFor(*I.getFunction());
1974 }
1975
1976 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1977 /// the cases where we can avoid taking the address of a function.
1978 bool rewriteDeviceCodeStateMachine();
1979
1980 ///
1981 ///}}
1982
1983 /// Emit a remark generically
1984 ///
1985 /// This template function can be used to generically emit a remark. The
1986 /// RemarkKind should be one of the following:
1987 /// - OptimizationRemark to indicate a successful optimization attempt
1988 /// - OptimizationRemarkMissed to report a failed optimization attempt
1989 /// - OptimizationRemarkAnalysis to provide additional information about an
1990 /// optimization attempt
1991 ///
1992 /// The remark is built using a callback function provided by the caller that
1993 /// takes a RemarkKind as input and returns a RemarkKind.
1994 template <typename RemarkKind, typename RemarkCallBack>
1995 void emitRemark(Instruction *I, StringRef RemarkName,
1996 RemarkCallBack &&RemarkCB) const {
1997 Function *F = I->getParent()->getParent();
1998 auto &ORE = OREGetter(F);
1999
2000 if (RemarkName.starts_with("OMP"))
2001 ORE.emit([&]() {
2002 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2003 << " [" << RemarkName << "]";
2004 });
2005 else
2006 ORE.emit(
2007 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2008 }
2009
2010 /// Emit a remark on a function.
2011 template <typename RemarkKind, typename RemarkCallBack>
2012 void emitRemark(Function *F, StringRef RemarkName,
2013 RemarkCallBack &&RemarkCB) const {
2014 auto &ORE = OREGetter(F);
2015
2016 if (RemarkName.starts_with("OMP"))
2017 ORE.emit([&]() {
2018 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2019 << " [" << RemarkName << "]";
2020 });
2021 else
2022 ORE.emit(
2023 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2024 }
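  // Usage sketch mirroring the call sites in this file: remark names carrying
  // the "OMP" prefix get the name appended, e.g. "... deduplicated. [OMP170]".
  //
  //   auto Remark = [&](OptimizationRemark OR) {
  //     return OR << "OpenMP runtime call "
  //               << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
  //   };
  //   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);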
2025
2026 /// The underlying module.
2027 Module &M;
2028
2029 /// The SCC we are operating on.
2030 SmallVectorImpl<Function *> &SCC;
2031
2032 /// Callback to update the call graph, the first argument is a removed call,
2033 /// the second an optional replacement call.
2034 CallGraphUpdater &CGUpdater;
2035
2036 /// Callback to get an OptimizationRemarkEmitter from a Function *
2037 OptimizationRemarkGetter OREGetter;
2038
2039 /// OpenMP-specific information cache. Also Used for Attributor runs.
2040 OMPInformationCache &OMPInfoCache;
2041
2042 /// Attributor instance.
2043 Attributor &A;
2044
2045 /// Helper function to run Attributor on SCC.
2046 bool runAttributor(bool IsModulePass) {
2047 if (SCC.empty())
2048 return false;
2049
2050 registerAAs(IsModulePass);
2051
2052 ChangeStatus Changed = A.run();
2053
2054 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2055 << " functions, result: " << Changed << ".\n");
2056
2057 if (Changed == ChangeStatus::CHANGED)
2058 OMPInfoCache.invalidateAnalyses();
2059
2060 return Changed == ChangeStatus::CHANGED;
2061 }
2062
2063 void registerFoldRuntimeCall(RuntimeFunction RF);
2064
2065 /// Populate the Attributor with abstract attribute opportunities in the
2066 /// functions.
2067 void registerAAs(bool IsModulePass);
2068
2069public:
2070 /// Callback to register AAs for live functions, including internal functions
2071 /// marked live during the traversal.
2072 static void registerAAsForFunction(Attributor &A, const Function &F);
2073};
2074
2075Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2076 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2077 !OMPInfoCache.CGSCC->contains(&F))
2078 return nullptr;
2079
2080 // Use a scope to keep the lifetime of the CachedKernel short.
2081 {
2082 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2083 if (CachedKernel)
2084 return *CachedKernel;
2085
2086 // TODO: We should use an AA to create an (optimistic and callback
2087 // call-aware) call graph. For now we stick to simple patterns that
2088 // are less powerful, basically the worst fixpoint.
2089 if (isOpenMPKernel(F)) {
2090 CachedKernel = Kernel(&F);
2091 return *CachedKernel;
2092 }
2093
2094 CachedKernel = nullptr;
2095 if (!F.hasLocalLinkage()) {
2096
2097 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2098 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2099 return ORA << "Potentially unknown OpenMP target region caller.";
2100 };
2101 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2102
2103 return nullptr;
2104 }
2105 }
2106
2107 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2108 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2109 // Allow use in equality comparisons.
2110 if (Cmp->isEquality())
2111 return getUniqueKernelFor(*Cmp);
2112 return nullptr;
2113 }
2114 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2115 // Allow direct calls.
2116 if (CB->isCallee(&U))
2117 return getUniqueKernelFor(*CB);
2118
2119 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2120 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2121 // Allow the use in __kmpc_parallel_51 calls.
2122 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2123 return getUniqueKernelFor(*CB);
2124 return nullptr;
2125 }
2126 // Disallow every other use.
2127 return nullptr;
2128 };
2129
2130 // TODO: In the future we want to track more than just a unique kernel.
2131 SmallPtrSet<Kernel, 2> PotentialKernels;
2132 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2133 PotentialKernels.insert(GetUniqueKernelForUse(U));
2134 });
2135
2136 Kernel K = nullptr;
2137 if (PotentialKernels.size() == 1)
2138 K = *PotentialKernels.begin();
2139
2140 // Cache the result.
2141 UniqueKernelMap[&F] = K;
2142
2143 return K;
2144}
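// Hedged illustration (hypothetical names): a parallel body wrapper that is
// only reachable via
//   call void @__kmpc_parallel_51(..., ptr @__omp_outlined__wrapper, ...)
// inside one kernel yields that kernel; if two kernels, or any caller we
// cannot see through, may reach it, nullptr is returned and cached.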
2145
2146bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2147 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2148 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2149
2150 bool Changed = false;
2151 if (!KernelParallelRFI)
2152 return Changed;
2153
2154 // If we have disabled state machine changes, exit.
2155 if (DisableOpenMPOptStateMachineRewrite)
2156 return Changed;
2157
2158 for (Function *F : SCC) {
2159
2160 // Check if the function is a use in a __kmpc_parallel_51 call at
2161 // all.
2162 bool UnknownUse = false;
2163 bool KernelParallelUse = false;
2164 unsigned NumDirectCalls = 0;
2165
2166 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2167 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2168 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2169 if (CB->isCallee(&U)) {
2170 ++NumDirectCalls;
2171 return;
2172 }
2173
2174 if (isa<ICmpInst>(U.getUser())) {
2175 ToBeReplacedStateMachineUses.push_back(&U);
2176 return;
2177 }
2178
2179 // Find wrapper functions that represent parallel kernels.
2180 CallInst *CI =
2181 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2182 const unsigned int WrapperFunctionArgNo = 6;
2183 if (!KernelParallelUse && CI &&
2184 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2185 KernelParallelUse = true;
2186 ToBeReplacedStateMachineUses.push_back(&U);
2187 return;
2188 }
2189 UnknownUse = true;
2190 });
2191
2192 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2193 // use.
2194 if (!KernelParallelUse)
2195 continue;
2196
2197 // If this ever hits, we should investigate.
2198 // TODO: Checking the number of uses is not a necessary restriction and
2199 // should be lifted.
2200 if (UnknownUse || NumDirectCalls != 1 ||
2201 ToBeReplacedStateMachineUses.size() > 2) {
2202 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2203 return ORA << "Parallel region is used in "
2204 << (UnknownUse ? "unknown" : "unexpected")
2205 << " ways. Will not attempt to rewrite the state machine.";
2206 };
2207 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2208 continue;
2209 }
2210
2211 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2212 // up if the function is not called from a unique kernel.
2213 Kernel K = getUniqueKernelFor(*F);
2214 if (!K) {
2215 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2216 return ORA << "Parallel region is not called from a unique kernel. "
2217 "Will not attempt to rewrite the state machine.";
2218 };
2219 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2220 continue;
2221 }
2222
2223 // We now know F is a parallel body function called only from the kernel K.
2224 // We also identified the state machine uses in which we replace the
2225 // function pointer by a new global symbol for identification purposes. This
2226 // ensures only direct calls to the function are left.
2227
2228 Module &M = *F->getParent();
2229 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2230
2231 auto *ID = new GlobalVariable(
2232 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2233 UndefValue::get(Int8Ty), F->getName() + ".ID");
2234
2235 for (Use *U : ToBeReplacedStateMachineUses)
2236 A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2237 ID, U->get()->getType()));
2238
2239 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2240
2241 Changed = true;
2242 }
2243
2244 return Changed;
2245}
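// Hedged sketch of the rewrite (hypothetical names): both the argument use in
//   call void @__kmpc_parallel_51(..., ptr @__omp_outlined__wrapper, ...)
// and the state machine comparison
//   %match = icmp eq ptr %work_fn, @__omp_outlined__wrapper
// are redirected to the fresh global @__omp_outlined__wrapper.ID, so the
// wrapper is no longer address-taken and only direct calls to it remain.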
2246
2247/// Abstract Attribute for tracking ICV values.
2248 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2249 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2250 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2251
2252 /// Returns true if value is assumed to be tracked.
2253 bool isAssumedTracked() const { return getAssumed(); }
2254
2255 /// Returns true if value is known to be tracked.
2256 bool isKnownTracked() const { return getAssumed(); }
2257
2258 /// Create an abstract attribute view for the position \p IRP.
2259 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2260
2261 /// Return the value with which \p I can be replaced for specific \p ICV.
2262 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2263 const Instruction *I,
2264 Attributor &A) const {
2265 return std::nullopt;
2266 }
2267
2268 /// Return an assumed unique ICV value if a single candidate is found. If
2269 /// there cannot be one, return a nullptr. If it is not clear yet, return
2270 /// std::nullopt.
2271 virtual std::optional<Value *>
2272 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2273
2274 // Currently only nthreads is being tracked.
2275 // This array will only grow over time.
2276 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2277
2278 /// See AbstractAttribute::getName()
2279 const std::string getName() const override { return "AAICVTracker"; }
2280
2281 /// See AbstractAttribute::getIdAddr()
2282 const char *getIdAddr() const override { return &ID; }
2283
2284 /// This function should return true if the type of the \p AA is AAICVTracker
2285 static bool classof(const AbstractAttribute *AA) {
2286 return (AA->getIdAddr() == &ID);
2287 }
2288
2289 static const char ID;
2290};
2291
2292struct AAICVTrackerFunction : public AAICVTracker {
2293 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2294 : AAICVTracker(IRP, A) {}
2295
2296 // FIXME: come up with better string.
2297 const std::string getAsStr(Attributor *) const override {
2298 return "ICVTrackerFunction";
2299 }
2300
2301 // FIXME: come up with some stats.
2302 void trackStatistics() const override {}
2303
2304 /// We don't manifest anything for this AA.
2305 ChangeStatus manifest(Attributor &A) override {
2306 return ChangeStatus::UNCHANGED;
2307 }
2308
2309 // Map of ICVs to their values at specific program points.
2310 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2311 InternalControlVar::ICV___last>
2312 ICVReplacementValuesMap;
2313
2314 ChangeStatus updateImpl(Attributor &A) override {
2315 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2316
2317 Function *F = getAnchorScope();
2318
2319 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2320
2321 for (InternalControlVar ICV : TrackableICVs) {
2322 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2323
2324 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2325 auto TrackValues = [&](Use &U, Function &) {
2326 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2327 if (!CI)
2328 return false;
2329
2330 // FIXME: handle setters with more than one argument.
2331 /// Track new value.
2332 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2333 HasChanged = ChangeStatus::CHANGED;
2334
2335 return false;
2336 };
2337
2338 auto CallCheck = [&](Instruction &I) {
2339 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2340 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2341 HasChanged = ChangeStatus::CHANGED;
2342
2343 return true;
2344 };
2345
2346 // Track all changes of an ICV.
2347 SetterRFI.foreachUse(TrackValues, F);
2348
2349 bool UsedAssumedInformation = false;
2350 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2351 UsedAssumedInformation,
2352 /* CheckBBLivenessOnly */ true);
2353
2354 /// TODO: Figure out a way to avoid adding entry in
2355 /// ICVReplacementValuesMap
2356 Instruction *Entry = &F->getEntryBlock().front();
2357 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2358 ValuesMap.insert(std::make_pair(Entry, nullptr));
2359 }
2360
2361 return HasChanged;
2362 }
2363
2364 /// Helper to check if \p I is a call and get the value for it if it is
2365 /// unique.
2366 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2367 InternalControlVar &ICV) const {
2368
2369 const auto *CB = dyn_cast<CallBase>(&I);
2370 if (!CB || CB->hasFnAttr("no_openmp") ||
2371 CB->hasFnAttr("no_openmp_routines"))
2372 return std::nullopt;
2373
2374 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2375 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2376 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2377 Function *CalledFunction = CB->getCalledFunction();
2378
2379 // Indirect call, assume ICV changes.
2380 if (CalledFunction == nullptr)
2381 return nullptr;
2382 if (CalledFunction == GetterRFI.Declaration)
2383 return std::nullopt;
2384 if (CalledFunction == SetterRFI.Declaration) {
2385 if (ICVReplacementValuesMap[ICV].count(&I))
2386 return ICVReplacementValuesMap[ICV].lookup(&I);
2387
2388 return nullptr;
2389 }
2390
2391 // Since we don't know, assume it changes the ICV.
2392 if (CalledFunction->isDeclaration())
2393 return nullptr;
2394
2395 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2396 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2397
2398 if (ICVTrackingAA->isAssumedTracked()) {
2399 std::optional<Value *> URV =
2400 ICVTrackingAA->getUniqueReplacementValue(ICV);
2401 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2402 OMPInfoCache)))
2403 return URV;
2404 }
2405
2406 // If we don't know, assume it changes.
2407 return nullptr;
2408 }
2409
2410 // We don't check unique value for a function, so return std::nullopt.
2411 std::optional<Value *>
2412 getUniqueReplacementValue(InternalControlVar ICV) const override {
2413 return std::nullopt;
2414 }
2415
2416 /// Return the value with which \p I can be replaced for specific \p ICV.
2417 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2418 const Instruction *I,
2419 Attributor &A) const override {
2420 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2421 if (ValuesMap.count(I))
2422 return ValuesMap.lookup(I);
2423
2424 SmallVector<const Instruction *, 16> Worklist;
2425 SmallPtrSet<const Instruction *, 16> Visited;
2426 Worklist.push_back(I);
2427
2428 std::optional<Value *> ReplVal;
2429
2430 while (!Worklist.empty()) {
2431 const Instruction *CurrInst = Worklist.pop_back_val();
2432 if (!Visited.insert(CurrInst).second)
2433 continue;
2434
2435 const BasicBlock *CurrBB = CurrInst->getParent();
2436
2437 // Go up and look for all potential setters/calls that might change the
2438 // ICV.
2439 while ((CurrInst = CurrInst->getPrevNode())) {
2440 if (ValuesMap.count(CurrInst)) {
2441 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2442 // Unknown value, track new.
2443 if (!ReplVal) {
2444 ReplVal = NewReplVal;
2445 break;
2446 }
2447
2448 // If we found a different new value, we can't know the ICV value anymore.
2449 if (NewReplVal)
2450 if (ReplVal != NewReplVal)
2451 return nullptr;
2452
2453 break;
2454 }
2455
2456 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2457 if (!NewReplVal)
2458 continue;
2459
2460 // Unknown value, track new.
2461 if (!ReplVal) {
2462 ReplVal = NewReplVal;
2463 break;
2464 }
2465
2466 // NewReplVal is known to hold a value here. If it differs from the one
2467 // tracked so far, we can't know the ICV value anymore.
2468 if (ReplVal != NewReplVal)
2469 return nullptr;
2470 }
2471
2472 // If we are in the same BB and we have a value, we are done.
2473 if (CurrBB == I->getParent() && ReplVal)
2474 return ReplVal;
2475
2476 // Go through all predecessors and add terminators for analysis.
2477 for (const BasicBlock *Pred : predecessors(CurrBB))
2478 if (const Instruction *Terminator = Pred->getTerminator())
2479 Worklist.push_back(Terminator);
2480 }
2481
2482 return ReplVal;
2483 }
2484};
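// Hedged example of what this tracking enables (hypothetical user code):
//
//   omp_set_num_threads(4);    // setter: the nthreads ICV becomes 4 here
//   work();                    // known not to change the ICV
//   n = omp_get_max_threads(); // getter: may be folded to 4
//
// The backward walk in getReplacementValue stops at the setter; if two
// different reaching values are found, nullptr ("unknown") is returned.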
2485
2486struct AAICVTrackerFunctionReturned : AAICVTracker {
2487 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2488 : AAICVTracker(IRP, A) {}
2489
2490 // FIXME: come up with better string.
2491 const std::string getAsStr(Attributor *) const override {
2492 return "ICVTrackerFunctionReturned";
2493 }
2494
2495 // FIXME: come up with some stats.
2496 void trackStatistics() const override {}
2497
2498 /// We don't manifest anything for this AA.
2499 ChangeStatus manifest(Attributor &A) override {
2500 return ChangeStatus::UNCHANGED;
2501 }
2502
2503 // Map of ICVs to their values at specific program points.
2504 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2505 InternalControlVar::ICV___last>
2506 ICVReplacementValuesMap;
2507
2508 /// Return the value with which \p I can be replaced for specific \p ICV.
2509 std::optional<Value *>
2510 getUniqueReplacementValue(InternalControlVar ICV) const override {
2511 return ICVReplacementValuesMap[ICV];
2512 }
2513
2514 ChangeStatus updateImpl(Attributor &A) override {
2515 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2516 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2517 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2518
2519 if (!ICVTrackingAA->isAssumedTracked())
2520 return indicatePessimisticFixpoint();
2521
2522 for (InternalControlVar ICV : TrackableICVs) {
2523 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2524 std::optional<Value *> UniqueICVValue;
2525
2526 auto CheckReturnInst = [&](Instruction &I) {
2527 std::optional<Value *> NewReplVal =
2528 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2529
2530 // If we found a second ICV value there is no unique returned value.
2531 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2532 return false;
2533
2534 UniqueICVValue = NewReplVal;
2535
2536 return true;
2537 };
2538
2539 bool UsedAssumedInformation = false;
2540 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2541 UsedAssumedInformation,
2542 /* CheckBBLivenessOnly */ true))
2543 UniqueICVValue = nullptr;
2544
2545 if (UniqueICVValue == ReplVal)
2546 continue;
2547
2548 ReplVal = UniqueICVValue;
2549 Changed = ChangeStatus::CHANGED;
2550 }
2551
2552 return Changed;
2553 }
2554};
2555
2556struct AAICVTrackerCallSite : AAICVTracker {
2557 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2558 : AAICVTracker(IRP, A) {}
2559
2560 void initialize(Attributor &A) override {
2561 assert(getAnchorScope() && "Expected anchor function");
2562
2563 // We only initialize this AA for getters, so we need to know which ICV it
2564 // gets.
2565 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2566 for (InternalControlVar ICV : TrackableICVs) {
2567 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2568 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2569 if (Getter.Declaration == getAssociatedFunction()) {
2570 AssociatedICV = ICVInfo.Kind;
2571 return;
2572 }
2573 }
2574
2575 /// Unknown ICV.
2576 indicatePessimisticFixpoint();
2577 }
2578
2579 ChangeStatus manifest(Attributor &A) override {
2580 if (!ReplVal || !*ReplVal)
2581 return ChangeStatus::UNCHANGED;
2582
2583 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2584 A.deleteAfterManifest(*getCtxI());
2585
2586 return ChangeStatus::CHANGED;
2587 }
2588
2589 // FIXME: come up with better string.
2590 const std::string getAsStr(Attributor *) const override {
2591 return "ICVTrackerCallSite";
2592 }
2593
2594 // FIXME: come up with some stats.
2595 void trackStatistics() const override {}
2596
2597 InternalControlVar AssociatedICV;
2598 std::optional<Value *> ReplVal;
2599
2600 ChangeStatus updateImpl(Attributor &A) override {
2601 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2602 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2603
2604 // We don't have any information, so we assume it changes the ICV.
2605 if (!ICVTrackingAA->isAssumedTracked())
2606 return indicatePessimisticFixpoint();
2607
2608 std::optional<Value *> NewReplVal =
2609 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2610
2611 if (ReplVal == NewReplVal)
2612 return ChangeStatus::UNCHANGED;
2613
2614 ReplVal = NewReplVal;
2615 return ChangeStatus::CHANGED;
2616 }
2617
2618 // Return the value with which associated value can be replaced for specific
2619 // \p ICV.
2620 std::optional<Value *>
2621 getUniqueReplacementValue(InternalControlVar ICV) const override {
2622 return ReplVal;
2623 }
2624};
2625
2626struct AAICVTrackerCallSiteReturned : AAICVTracker {
2627 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2628 : AAICVTracker(IRP, A) {}
2629
2630 // FIXME: come up with better string.
2631 const std::string getAsStr(Attributor *) const override {
2632 return "ICVTrackerCallSiteReturned";
2633 }
2634
2635 // FIXME: come up with some stats.
2636 void trackStatistics() const override {}
2637
2638 /// We don't manifest anything for this AA.
2639 ChangeStatus manifest(Attributor &A) override {
2640 return ChangeStatus::UNCHANGED;
2641 }
2642
2643 // Map of ICVs to their values at specific program points.
2644 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2645 InternalControlVar::ICV___last>
2646 ICVReplacementValuesMap;
2647
2648 /// Return the value with which associated value can be replaced for specific
2649 /// \p ICV.
2650 std::optional<Value *>
2651 getUniqueReplacementValue(InternalControlVar ICV) const override {
2652 return ICVReplacementValuesMap[ICV];
2653 }
2654
2655 ChangeStatus updateImpl(Attributor &A) override {
2656 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2657 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2658 *this, IRPosition::returned(*getAssociatedFunction()),
2659 DepClassTy::REQUIRED);
2660
2661 // We don't have any information, so we assume it changes the ICV.
2662 if (!ICVTrackingAA->isAssumedTracked())
2663 return indicatePessimisticFixpoint();
2664
2665 for (InternalControlVar ICV : TrackableICVs) {
2666 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2667 std::optional<Value *> NewReplVal =
2668 ICVTrackingAA->getUniqueReplacementValue(ICV);
2669
2670 if (ReplVal == NewReplVal)
2671 continue;
2672
2673 ReplVal = NewReplVal;
2674 Changed = ChangeStatus::CHANGED;
2675 }
2676 return Changed;
2677 }
2678};
2679
2680 /// Determines if \p BB exits the function unconditionally itself or reaches a
2681 /// block that does so through only unique successors.
2682static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2683 if (succ_empty(BB))
2684 return true;
2685 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2686 if (!Successor)
2687 return false;
2688 return hasFunctionEndAsUniqueSuccessor(Successor);
2689}
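// Hedged illustration: for a chain BB0 -> BB1 -> BB2 where each edge is the
// unique successor edge and BB2 ends in a return, the predicate holds for all
// three blocks; it fails once a block has two successors, since effects on
// the other path would otherwise go unaccounted for.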
2690
2691struct AAExecutionDomainFunction : public AAExecutionDomain {
2692 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2693 : AAExecutionDomain(IRP, A) {}
2694
2695 ~AAExecutionDomainFunction() { delete RPOT; }
2696
2697 void initialize(Attributor &A) override {
2698 Function *F = getAnchorScope();
2699 assert(F && "Expected anchor function");
2700 RPOT = new ReversePostOrderTraversal<Function *>(F);
2701 }
2702
2703 const std::string getAsStr(Attributor *) const override {
2704 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2705 for (auto &It : BEDMap) {
2706 if (!It.getFirst())
2707 continue;
2708 TotalBlocks++;
2709 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2710 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2711 It.getSecond().IsReachingAlignedBarrierOnly;
2712 }
2713 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2714 std::to_string(AlignedBlocks) + " of " +
2715 std::to_string(TotalBlocks) +
2716 " executed by initial thread / aligned";
2717 }
2718
2719 /// See AbstractAttribute::trackStatistics().
2720 void trackStatistics() const override {}
2721
2722 ChangeStatus manifest(Attributor &A) override {
2723 LLVM_DEBUG({
2724 for (const BasicBlock &BB : *getAnchorScope()) {
2725 if (!isExecutedByInitialThreadOnly(BB))
2726 continue;
2727 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2728 << BB.getName() << " is executed by a single thread.\n";
2729 }
2730 });
2731
2732 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733
2734 if (DisableOpenMPOptBarrierElimination)
2735 return Changed;
2736
2737 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2738 auto HandleAlignedBarrier = [&](CallBase *CB) {
2739 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2740 if (!ED.IsReachedFromAlignedBarrierOnly ||
2741 ED.EncounteredNonLocalSideEffect)
2742 return;
2743 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2744 return;
2745
2746 // We can remove this barrier, if it is one, or aligned barriers reaching
2747 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2748 // end should only be removed if the kernel end is their unique successor;
2749 // otherwise, they may have side-effects that aren't accounted for in the
2750 // kernel end in their other successors. If those barriers have other
2751 // barriers reaching them, those can be transitively removed as well as
2752 // long as the kernel end is also their unique successor.
2753 if (CB) {
2754 DeletedBarriers.insert(CB);
2755 A.deleteAfterManifest(*CB);
2756 ++NumBarriersEliminated;
2757 Changed = ChangeStatus::CHANGED;
2758 } else if (!ED.AlignedBarriers.empty()) {
2759 Changed = ChangeStatus::CHANGED;
2760 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2761 ED.AlignedBarriers.end());
2762 SmallSetVector<CallBase *, 16> Visited;
2763 while (!Worklist.empty()) {
2764 CallBase *LastCB = Worklist.pop_back_val();
2765 if (!Visited.insert(LastCB))
2766 continue;
2767 if (LastCB->getFunction() != getAnchorScope())
2768 continue;
2769 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2770 continue;
2771 if (!DeletedBarriers.count(LastCB)) {
2772 ++NumBarriersEliminated;
2773 A.deleteAfterManifest(*LastCB);
2774 continue;
2775 }
2776 // The final aligned barrier (LastCB) reaching the kernel end was
2777 // removed already. This means we can go one step further and remove
2778 // the barriers encountered last before (LastCB).
2779 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2780 Worklist.append(LastED.AlignedBarriers.begin(),
2781 LastED.AlignedBarriers.end());
2782 }
2783 }
2784
2785 // If we actually eliminated a barrier we need to eliminate the associated
2786 // llvm.assumes as well to avoid creating UB.
2787 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2788 for (auto *AssumeCB : ED.EncounteredAssumes)
2789 A.deleteAfterManifest(*AssumeCB);
2790 };
2791
2792 for (auto *CB : AlignedBarriers)
2793 HandleAlignedBarrier(CB);
2794
2795 // Handle the "kernel end barrier" for kernels too.
2796 if (omp::isOpenMPKernel(*getAnchorScope()))
2797 HandleAlignedBarrier(nullptr);
2798
2799 return Changed;
2800 }
2801
2802 bool isNoOpFence(const FenceInst &FI) const override {
2803 return getState().isValidState() && !NonNoOpFences.count(&FI);
2804 }
2805
2806 /// Merge barrier and assumption information from \p PredED into the successor
2807 /// \p ED.
2808 void
2809 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2810 const ExecutionDomainTy &PredED);
2811
2812 /// Merge all information from \p PredED into the successor \p ED. If
2813 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2814 /// represented by \p ED from this predecessor.
2815 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2816 const ExecutionDomainTy &PredED,
2817 bool InitialEdgeOnly = false);
2818
2819 /// Accumulate information for the entry block in \p EntryBBED.
2820 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2821
2822 /// See AbstractAttribute::updateImpl.
2823 ChangeStatus updateImpl(Attributor &A) override;
2824
2825 /// Query interface, see AAExecutionDomain
2826 ///{
2827 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2828 if (!isValidState())
2829 return false;
2830 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2831 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2832 }
2833
2834 bool isExecutedInAlignedRegion(Attributor &A,
2835 const Instruction &I) const override {
2836 assert(I.getFunction() == getAnchorScope() &&
2837 "Instruction is out of scope!");
2838 if (!isValidState())
2839 return false;
2840
2841 bool ForwardIsOk = true;
2842 const Instruction *CurI;
2843
2844 // Check forward until a call or the block end is reached.
2845 CurI = &I;
2846 do {
2847 auto *CB = dyn_cast<CallBase>(CurI);
2848 if (!CB)
2849 continue;
2850 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2851 return true;
2852 const auto &It = CEDMap.find({CB, PRE});
2853 if (It == CEDMap.end())
2854 continue;
2855 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2856 ForwardIsOk = false;
2857 break;
2858 } while ((CurI = CurI->getNextNonDebugInstruction()));
2859
2860 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2861 ForwardIsOk = false;
2862
2863 // Check backward until a call or the block beginning is reached.
2864 CurI = &I;
2865 do {
2866 auto *CB = dyn_cast<CallBase>(CurI);
2867 if (!CB)
2868 continue;
2869 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2870 return true;
2871 const auto &It = CEDMap.find({CB, POST});
2872 if (It == CEDMap.end())
2873 continue;
2874 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2875 break;
2876 return false;
2877 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2878
2879 // Delayed decision on the forward pass to allow aligned barrier detection
2880 // in the backwards traversal.
2881 if (!ForwardIsOk)
2882 return false;
2883
2884 if (!CurI) {
2885 const BasicBlock *BB = I.getParent();
2886 if (BB == &BB->getParent()->getEntryBlock())
2887 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2888 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2889 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2890 })) {
2891 return false;
2892 }
2893 }
2894
2895 // On neither traversal did we find anything but aligned barriers.
2896 return true;
2897 }
2898
2899 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2900 assert(isValidState() &&
2901 "No request should be made against an invalid state!");
2902 return BEDMap.lookup(&BB);
2903 }
2904 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2905 getExecutionDomain(const CallBase &CB) const override {
2906 assert(isValidState() &&
2907 "No request should be made against an invalid state!");
2908 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2909 }
2910 ExecutionDomainTy getFunctionExecutionDomain() const override {
2911 assert(isValidState() &&
2912 "No request should be made against an invalid state!");
2913 return InterProceduralED;
2914 }
2915 ///}
2916
2917 // Check if the edge into the successor block contains a condition that only
2918 // lets the main thread execute it.
2919 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2920 BasicBlock &SuccessorBB) {
2921 if (!Edge || !Edge->isConditional())
2922 return false;
2923 if (Edge->getSuccessor(0) != &SuccessorBB)
2924 return false;
2925
2926 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2927 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2928 return false;
2929
2930 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2931 if (!C)
2932 return false;
2933
2934 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2935 if (C->isAllOnesValue()) {
2936 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2937 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2938 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2939 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2940 if (!CB)
2941 return false;
2942 ConstantStruct *KernelEnvC =
2943 KernelInfo::getKernelEnvironementFromKernelInitCB(A, *CB);
2944 ConstantInt *ExecModeC =
2945 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2946 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2947 }
2948
2949 if (C->isZero()) {
2950 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2951 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2952 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2953 return true;
2954
2955 // Match: 0 == llvm.amdgcn.workitem.id.x()
2956 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2957 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2958 return true;
2959 }
2960
2961 return false;
2962 };
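// Hedged IR sketch of the thread-id pattern matched above (hypothetical
// labels):
//
//   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
//   %cmp = icmp eq i32 %tid, 0
//   br i1 %cmp, label %master, label %rest
//
// The edge into %master is then executed by the initial thread only.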
2963
2964 /// Mapping containing information about the function for other AAs.
2965 ExecutionDomainTy InterProceduralED;
2966
2967 enum Direction { PRE = 0, POST = 1 };
2968 /// Mapping containing information per block.
2969 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2970 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2971 CEDMap;
2972 SmallSetVector<CallBase *, 16> AlignedBarriers;
2973
2974 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2975
2976 /// Set \p R to \p V and report true if that changed \p R.
2977 static bool setAndRecord(bool &R, bool V) {
2978 bool Eq = (R == V);
2979 R = V;
2980 return !Eq;
2981 }
2982
2983 /// Collection of fences known to be non-no-op. All fences not in this set
2984 /// can be assumed no-op.
2985 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2986 };
2987
2988void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2989 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2990 for (auto *EA : PredED.EncounteredAssumes)
2991 ED.addAssumeInst(A, *EA);
2992
2993 for (auto *AB : PredED.AlignedBarriers)
2994 ED.addAlignedBarrier(A, *AB);
2995}
2996
2997bool AAExecutionDomainFunction::mergeInPredecessor(
2998 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2999 bool InitialEdgeOnly) {
3000
3001 bool Changed = false;
3002 Changed |=
3003 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3004 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3005 ED.IsExecutedByInitialThreadOnly));
3006
3007 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3008 ED.IsReachedFromAlignedBarrierOnly &&
3009 PredED.IsReachedFromAlignedBarrierOnly);
3010 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3011 ED.EncounteredNonLocalSideEffect |
3012 PredED.EncounteredNonLocalSideEffect);
3013 // Do not track assumptions and barriers as part of Changed.
3014 if (ED.IsReachedFromAlignedBarrierOnly)
3015 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3016 else
3017 ED.clearAssumeInstAndAlignedBarriers();
3018 return Changed;
3019}
3020
3021bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3022 ExecutionDomainTy &EntryBBED) {
3023 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3024 auto PredForCallSite = [&](AbstractCallSite ACS) {
3025 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3026 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3027 DepClassTy::OPTIONAL);
3028 if (!EDAA || !EDAA->getState().isValidState())
3029 return false;
3030 CallSiteEDs.emplace_back(
3031 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3032 return true;
3033 };
3034
3035 ExecutionDomainTy ExitED;
3036 bool AllCallSitesKnown;
3037 if (A.checkForAllCallSites(PredForCallSite, *this,
3038 /* RequiresAllCallSites */ true,
3039 AllCallSitesKnown)) {
3040 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3041 mergeInPredecessor(A, EntryBBED, CSInED);
3042 ExitED.IsReachingAlignedBarrierOnly &=
3043 CSOutED.IsReachingAlignedBarrierOnly;
3044 }
3045
3046 } else {
3047 // We could not find all predecessors, so this is either a kernel or a
3048 // function with external linkage (or with some other weird uses).
3049 if (omp::isOpenMPKernel(*getAnchorScope())) {
3050 EntryBBED.IsExecutedByInitialThreadOnly = false;
3051 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3052 EntryBBED.EncounteredNonLocalSideEffect = false;
3053 ExitED.IsReachingAlignedBarrierOnly = false;
3054 } else {
3055 EntryBBED.IsExecutedByInitialThreadOnly = false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3057 EntryBBED.EncounteredNonLocalSideEffect = true;
3058 ExitED.IsReachingAlignedBarrierOnly = false;
3059 }
3060 }
3061
3062 bool Changed = false;
3063 auto &FnED = BEDMap[nullptr];
3064 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3065 FnED.IsReachedFromAlignedBarrierOnly &
3066 EntryBBED.IsReachedFromAlignedBarrierOnly);
3067 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3068 FnED.IsReachingAlignedBarrierOnly &
3069 ExitED.IsReachingAlignedBarrierOnly);
3070 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3071 EntryBBED.IsExecutedByInitialThreadOnly);
3072 return Changed;
3073}
3074
3075ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3076
3077 bool Changed = false;
3078
3079 // Helper to deal with an aligned barrier encountered during the forward
3080 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3081 // it was encountered.
3082 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3083 Changed |= AlignedBarriers.insert(&CB);
3084 // First, update the barrier ED kept in the separate CEDMap.
3085 auto &CallInED = CEDMap[{&CB, PRE}];
3086 Changed |= mergeInPredecessor(A, CallInED, ED);
3087 CallInED.IsReachingAlignedBarrierOnly = true;
3088 // Next adjust the ED we use for the traversal.
3089 ED.EncounteredNonLocalSideEffect = false;
3090 ED.IsReachedFromAlignedBarrierOnly = true;
3091 // Aligned barrier collection has to come last.
3092 ED.clearAssumeInstAndAlignedBarriers();
3093 ED.addAlignedBarrier(A, CB);
3094 auto &CallOutED = CEDMap[{&CB, POST}];
3095 Changed |= mergeInPredecessor(A, CallOutED, ED);
3096 };
3097
3098 auto *LivenessAA =
3099 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3100
3101 Function *F = getAnchorScope();
3102 BasicBlock &EntryBB = F->getEntryBlock();
3103 bool IsKernel = omp::isOpenMPKernel(*F);
3104
3105 SmallVector<Instruction *> SyncInstWorklist;
3106 for (auto &RIt : *RPOT) {
3107 BasicBlock &BB = *RIt;
3108
3109 bool IsEntryBB = &BB == &EntryBB;
3110 // TODO: We use local reasoning since we don't have a divergence analysis
3111 // running as well. We could basically allow uniform branches here.
3112 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3113 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3114 ExecutionDomainTy ED;
3115 // Propagate "incoming edges" into information about this block.
3116 if (IsEntryBB) {
3117 Changed |= handleCallees(A, ED);
3118 } else {
3119 // For live non-entry blocks we only propagate
3120 // information via live edges.
3121 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3122 continue;
3123
3124 for (auto *PredBB : predecessors(&BB)) {
3125 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3126 continue;
3127 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3128 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3129 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3130 }
3131 }
3132
3133 // Now we traverse the block, accumulate effects in ED and attach
3134 // information to calls.
3135 for (Instruction &I : BB) {
3136 bool UsedAssumedInformation;
3137 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3138 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3139 /* CheckForDeadStore */ true))
3140 continue;
3141
3142 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
3143 // first; the former are collected, the latter are ignored.
3144 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3145 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3146 ED.addAssumeInst(A, *AI);
3147 continue;
3148 }
3149 // TODO: Should we also collect and delete lifetime markers?
3150 if (II->isAssumeLikeIntrinsic())
3151 continue;
3152 }
3153
3154 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3155 if (!ED.EncounteredNonLocalSideEffect) {
3156 // An aligned fence without non-local side-effects is a no-op.
3157 if (ED.IsReachedFromAlignedBarrierOnly)
3158 continue;
3159 // A non-aligned fence without non-local side-effects is a no-op
3160 // if the ordering only publishes non-local side-effects (or less).
3161 switch (FI->getOrdering()) {
3162 case AtomicOrdering::NotAtomic:
3163 continue;
3164 case AtomicOrdering::Unordered:
3165 continue;
3166 case AtomicOrdering::Monotonic:
3167 continue;
3168 case AtomicOrdering::Acquire:
3169 break;
3170 case AtomicOrdering::Release:
3171 continue;
3172 case AtomicOrdering::AcquireRelease:
3173 break;
3174 case AtomicOrdering::SequentiallyConsistent:
3175 break;
3176 };
3177 }
3178 NonNoOpFences.insert(FI);
3179 }
3180
3181 auto *CB = dyn_cast<CallBase>(&I);
3182 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3183 bool IsAlignedBarrier =
3184 !IsNoSync && CB &&
3185 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3186
3187 AlignedBarrierLastInBlock &= IsNoSync;
3188 IsExplicitlyAligned &= IsNoSync;
3189
3190 // Next we check for calls. Aligned barriers are handled
3191 // explicitly; everything else is kept for the backward traversal and will
3192 // also affect our state.
3193 if (CB) {
3194 if (IsAlignedBarrier) {
3195 HandleAlignedBarrier(*CB, ED);
3196 AlignedBarrierLastInBlock = true;
3197 IsExplicitlyAligned = true;
3198 continue;
3199 }
3200
3201 // Check the pointer(s) of a memory intrinsic explicitly.
3202 if (isa<MemIntrinsic>(&I)) {
3203 if (!ED.EncounteredNonLocalSideEffect &&
3204 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3205 ED.EncounteredNonLocalSideEffect = true;
3206 if (!IsNoSync) {
3207 ED.IsReachedFromAlignedBarrierOnly = false;
3208 SyncInstWorklist.push_back(&I);
3209 }
3210 continue;
3211 }
3212
3213 // Record how we entered the call, then accumulate the effect of the
3214 // call in ED for potential use by the callee.
3215 auto &CallInED = CEDMap[{CB, PRE}];
3216 Changed |= mergeInPredecessor(A, CallInED, ED);
3217
3218 // If we have a sync-definition we can check if it starts/ends in an
3219 // aligned barrier. If we are unsure we assume any sync breaks
3220 // alignment.
3221 Function *Callee = CB->getCalledFunction();
3222 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3223 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3224 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3225 if (EDAA && EDAA->getState().isValidState()) {
3226 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3227 ED.IsReachedFromAlignedBarrierOnly =
3228 CalleeED.IsReachedFromAlignedBarrierOnly;
3229 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3230 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3231 ED.EncounteredNonLocalSideEffect |=
3232 CalleeED.EncounteredNonLocalSideEffect;
3233 else
3234 ED.EncounteredNonLocalSideEffect =
3235 CalleeED.EncounteredNonLocalSideEffect;
3236 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3237 Changed |=
3238 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3239 SyncInstWorklist.push_back(&I);
3240 }
3241 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3242 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3243 auto &CallOutED = CEDMap[{CB, POST}];
3244 Changed |= mergeInPredecessor(A, CallOutED, ED);
3245 continue;
3246 }
3247 }
3248 if (!IsNoSync) {
3249 ED.IsReachedFromAlignedBarrierOnly = false;
3250 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3251 SyncInstWorklist.push_back(&I);
3252 }
3253 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3254 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3255 auto &CallOutED = CEDMap[{CB, POST}];
3256 Changed |= mergeInPredecessor(A, CallOutED, ED);
3257 }
3258
3259 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3260 continue;
3261
3262 // If we have a callee we try to use fine-grained information to
3263 // determine local side-effects.
3264 if (CB) {
3265 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3266 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3267
3268 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3269 AAMemoryLocation::AccessKind,
3270 AAMemoryLocation::MemoryLocationsKind) {
3271 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3272 };
3273 if (MemAA && MemAA->getState().isValidState() &&
3274 MemAA->checkForAllAccessesToMemoryKind(
3275 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3276 continue;
3277 }
3278
3279 auto &InfoCache = A.getInfoCache();
3280 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3281 continue;
3282
3283 if (auto *LI = dyn_cast<LoadInst>(&I))
3284 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3285 continue;
3286
3287 if (!ED.EncounteredNonLocalSideEffect &&
3288 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3289 ED.EncounteredNonLocalSideEffect = true;
3290 }
3291
3292 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3293 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3294 !BB.getTerminator()->getNumSuccessors()) {
3295
3296 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3297
3298 auto &FnED = BEDMap[nullptr];
3299 if (IsKernel && !IsExplicitlyAligned)
3300 FnED.IsReachingAlignedBarrierOnly = false;
3301 Changed |= mergeInPredecessor(A, FnED, ED);
3302
3303 if (!FnED.IsReachingAlignedBarrierOnly) {
3304 IsEndAndNotReachingAlignedBarriersOnly = true;
3305 SyncInstWorklist.push_back(BB.getTerminator());
3306 auto &BBED = BEDMap[&BB];
3307 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3308 }
3309 }
3310
3311 ExecutionDomainTy &StoredED = BEDMap[&BB];
3312 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3313 !IsEndAndNotReachingAlignedBarriersOnly;
3314
3315 // Check if we computed anything different as part of the forward
3316 // traversal. We do not take assumptions and aligned barriers into account
3317 // as they do not influence the state we iterate. Backward traversal values
3318 // are handled later on.
3319 if (ED.IsExecutedByInitialThreadOnly !=
3320 StoredED.IsExecutedByInitialThreadOnly ||
3321 ED.IsReachedFromAlignedBarrierOnly !=
3322 StoredED.IsReachedFromAlignedBarrierOnly ||
3323 ED.EncounteredNonLocalSideEffect !=
3324 StoredED.EncounteredNonLocalSideEffect)
3325 Changed = true;
3326
3327 // Update the state with the new value.
3328 StoredED = std::move(ED);
3329 }
3330
3331 // Propagate (non-aligned) sync instruction effects backwards until the
3332 // entry is hit or an aligned barrier.
3333 SmallSetVector<BasicBlock *, 16> Visited;
3334 while (!SyncInstWorklist.empty()) {
3335 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3336 Instruction *CurInst = SyncInst;
3337 bool HitAlignedBarrierOrKnownEnd = false;
3338 while ((CurInst = CurInst->getPrevNode())) {
3339 auto *CB = dyn_cast<CallBase>(CurInst);
3340 if (!CB)
3341 continue;
3342 auto &CallOutED = CEDMap[{CB, POST}];
3343 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3344 auto &CallInED = CEDMap[{CB, PRE}];
3345 HitAlignedBarrierOrKnownEnd =
3346 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3347 if (HitAlignedBarrierOrKnownEnd)
3348 break;
3349 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3350 }
3351 if (HitAlignedBarrierOrKnownEnd)
3352 continue;
3353 BasicBlock *SyncBB = SyncInst->getParent();
3354 for (auto *PredBB : predecessors(SyncBB)) {
3355 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3356 continue;
3357 if (!Visited.insert(PredBB))
3358 continue;
3359 auto &PredED = BEDMap[PredBB];
3360 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3361 Changed = true;
3362 SyncInstWorklist.push_back(PredBB->getTerminator());
3363 }
3364 }
3365 if (SyncBB != &EntryBB)
3366 continue;
3367 Changed |=
3368 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3369 }
3370
3371 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3372}
3373
3374/// Try to replace memory allocation calls called by a single thread with a
3375/// static buffer of shared memory.
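///
/// For illustration (a sketch, not taken from the sources): a globalized
/// allocation such as
///   %p = call ptr @__kmpc_alloc_shared(i64 24)
///   ...
///   call void @__kmpc_free_shared(ptr %p, i64 24)
/// that is only executed by a single thread becomes a static buffer in the
/// device's shared address space, e.g.
///   @p_shared = internal addrspace(3) global [24 x i8] poison
/// whose address replaces %p; both runtime calls are then removed.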
3376struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3377 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3378 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3379
3380 /// Create an abstract attribute view for the position \p IRP.
3381 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3382 Attributor &A);
3383
3384 /// Returns true if HeapToShared conversion is assumed to be possible.
3385 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3386
3387 /// Returns true if HeapToShared conversion is assumed and the CB is a
3388 /// callsite to a free operation to be removed.
3389 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3390
3391 /// See AbstractAttribute::getName().
3392 const std::string getName() const override { return "AAHeapToShared"; }
3393
3394 /// See AbstractAttribute::getIdAddr().
3395 const char *getIdAddr() const override { return &ID; }
3396
3397 /// This function should return true if the type of the \p AA is
3398 /// AAHeapToShared.
3399 static bool classof(const AbstractAttribute *AA) {
3400 return (AA->getIdAddr() == &ID);
3401 }
3402
3403 /// Unique ID (due to the unique address)
3404 static const char ID;
3405};
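
/// Illustrative query of this AA (a sketch mirroring the lookups used
/// elsewhere in this file; F and CB are placeholders):
///   const auto *HTS = A.lookupAAFor<AAHeapToShared>(
///       IRPosition::function(F), /* QueryingAA */ nullptr,
///       DepClassTy::OPTIONAL);
///   if (HTS && HTS->isAssumedHeapToShared(CB))
///     ; // The allocation made by CB is expected to move to shared memory.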
3406
3407struct AAHeapToSharedFunction : public AAHeapToShared {
3408 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3409 : AAHeapToShared(IRP, A) {}
3410
3411 const std::string getAsStr(Attributor *) const override {
3412 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3413 " malloc calls eligible.";
3414 }
3415
3416 /// See AbstractAttribute::trackStatistics().
3417 void trackStatistics() const override {}
3418
3419 /// This function finds free calls that will be removed by the
3420 /// HeapToShared transformation.
3421 void findPotentialRemovedFreeCalls(Attributor &A) {
3422 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3423 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3424
3425 PotentialRemovedFreeCalls.clear();
3426 // Update free call users of found malloc calls.
3427 for (CallBase *CB : MallocCalls) {
3428 SmallVector<CallBase *, 4> FreeCalls;
3429 for (auto *U : CB->users()) {
3430 CallBase *C = dyn_cast<CallBase>(U);
3431 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3432 FreeCalls.push_back(C);
3433 }
3434
3435 if (FreeCalls.size() != 1)
3436 continue;
3437
3438 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3439 }
3440 }
3441
3442 void initialize(Attributor &A) override {
3443 if (DisableOpenMPOptDeglobalization) {
3444 indicatePessimisticFixpoint();
3445 return;
3446 }
3447
3448 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3449 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3450 if (!RFI.Declaration)
3451 return;
3452
3453 Attributor::SimplifictionCallbackTy SCB =
3454 [](const IRPosition &, const AbstractAttribute *,
3455 bool &) -> std::optional<Value *> { return nullptr; };
3456
3457 Function *F = getAnchorScope();
3458 for (User *U : RFI.Declaration->users())
3459 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3460 if (CB->getFunction() != F)
3461 continue;
3462 MallocCalls.insert(CB);
3463 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3464 SCB);
3465 }
3466
3467 findPotentialRemovedFreeCalls(A);
3468 }
3469
3470 bool isAssumedHeapToShared(CallBase &CB) const override {
3471 return isValidState() && MallocCalls.count(&CB);
3472 }
3473
3474 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3475 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3476 }
3477
3478 ChangeStatus manifest(Attributor &A) override {
3479 if (MallocCalls.empty())
3480 return ChangeStatus::UNCHANGED;
3481
3482 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3483 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3484
3485 Function *F = getAnchorScope();
3486 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3487 DepClassTy::OPTIONAL);
3488
3489 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3490 for (CallBase *CB : MallocCalls) {
3491 // Skip replacing this if HeapToStack has already claimed it.
3492 if (HS && HS->isAssumedHeapToStack(*CB))
3493 continue;
3494
3495 // Find the unique free call to remove it.
3496 SmallVector<CallBase *, 4> FreeCalls;
3497 for (auto *U : CB->users()) {
3498 CallBase *C = dyn_cast<CallBase>(U);
3499 if (C && C->getCalledFunction() == FreeCall.Declaration)
3500 FreeCalls.push_back(C);
3501 }
3502 if (FreeCalls.size() != 1)
3503 continue;
3504
3505 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3506
3507 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3508 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3509 << " with shared memory."
3510 << " Shared memory usage is limited to "
3511 << SharedMemoryLimit << " bytes\n");
3512 continue;
3513 }
3514
3515 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3516 << " with " << AllocSize->getZExtValue()
3517 << " bytes of shared memory\n");
3518
3519 // Create a new shared memory buffer of the same size as the allocation
3520 // and replace all the uses of the original allocation with it.
3521 Module *M = CB->getModule();
3522 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3523 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3524 auto *SharedMem = new GlobalVariable(
3525 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3526 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3527 GlobalValue::NotThreadLocal,
3528 static_cast<unsigned>(AddressSpace::Shared));
3529 auto *NewBuffer =
3530 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3531
3532 auto Remark = [&](OptimizationRemark OR) {
3533 return OR << "Replaced globalized variable with "
3534 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3535 << (AllocSize->isOne() ? " byte " : " bytes ")
3536 << "of shared memory.";
3537 };
3538 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3539
3540 MaybeAlign Alignment = CB->getRetAlign();
3541 assert(Alignment &&
3542 "HeapToShared on allocation without alignment attribute");
3543 SharedMem->setAlignment(*Alignment);
3544
3545 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3546 A.deleteAfterManifest(*CB);
3547 A.deleteAfterManifest(*FreeCalls.front());
3548
3549 SharedMemoryUsed += AllocSize->getZExtValue();
3550 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3551 Changed = ChangeStatus::CHANGED;
3552 }
3553
3554 return Changed;
3555 }
3556
3557 ChangeStatus updateImpl(Attributor &A) override {
3558 if (MallocCalls.empty())
3559 return indicatePessimisticFixpoint();
3560 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3561 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3562 if (!RFI.Declaration)
3563 return ChangeStatus::UNCHANGED;
3564
3565 Function *F = getAnchorScope();
3566
3567 auto NumMallocCalls = MallocCalls.size();
3568
3569 // Only consider malloc calls executed by a single thread with a constant size.
3570 for (User *U : RFI.Declaration->users()) {
3571 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3572 if (CB->getCaller() != F)
3573 continue;
3574 if (!MallocCalls.count(CB))
3575 continue;
3576 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3577 MallocCalls.remove(CB);
3578 continue;
3579 }
3580 const auto *ED = A.getAAFor<AAExecutionDomain>(
3581 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3582 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3583 MallocCalls.remove(CB);
3584 }
3585 }
3586
3587 findPotentialRemovedFreeCalls(A);
3588
3589 if (NumMallocCalls != MallocCalls.size())
3590 return ChangeStatus::CHANGED;
3591
3592 return ChangeStatus::UNCHANGED;
3593 }
3594
3595 /// Collection of all malloc calls in a function.
3596 SmallSetVector<CallBase *, 4> MallocCalls;
3597 /// Collection of potentially removed free calls in a function.
3598 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3599 /// The total amount of shared memory that has been used for HeapToShared.
3600 unsigned SharedMemoryUsed = 0;
3601};
3602
3603struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3604 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3605 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3606
3607 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3608 /// unknown callees.
3609 static bool requiresCalleeForCallBase() { return false; }
3610
3611 /// Statistics are tracked as part of manifest for now.
3612 void trackStatistics() const override {}
3613
3614 /// See AbstractAttribute::getAsStr()
3615 const std::string getAsStr(Attributor *) const override {
3616 if (!isValidState())
3617 return "<invalid>";
3618 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3619 : "generic") +
3620 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3621 : "") +
3622 std::string(" #PRs: ") +
3623 (ReachedKnownParallelRegions.isValidState()
3624 ? std::to_string(ReachedKnownParallelRegions.size())
3625 : "<invalid>") +
3626 ", #Unknown PRs: " +
3627 (ReachedUnknownParallelRegions.isValidState()
3628 ? std::to_string(ReachedUnknownParallelRegions.size())
3629 : "<invalid>") +
3630 ", #Reaching Kernels: " +
3631 (ReachingKernelEntries.isValidState()
3632 ? std::to_string(ReachingKernelEntries.size())
3633 : "<invalid>") +
3634 ", #ParLevels: " +
3635 (ParallelLevels.isValidState()
3636 ? std::to_string(ParallelLevels.size())
3637 : "<invalid>") +
3638 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3639 }
3640
3641 /// Create an abstract attribute view for the position \p IRP.
3642 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3643
3644 /// See AbstractAttribute::getName()
3645 const std::string getName() const override { return "AAKernelInfo"; }
3646
3647 /// See AbstractAttribute::getIdAddr()
3648 const char *getIdAddr() const override { return &ID; }
3649
3650 /// This function should return true if the type of the \p AA is AAKernelInfo
3651 static bool classof(const AbstractAttribute *AA) {
3652 return (AA->getIdAddr() == &ID);
3653 }
3654
3655 static const char ID;
3656};
3657
3658/// The function kernel info abstract attribute, basically, what can we say
3659/// about a function with regards to the KernelInfoState.
3660struct AAKernelInfoFunction : AAKernelInfo {
3661 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3662 : AAKernelInfo(IRP, A) {}
3663
3664 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3665
3666 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3667 return GuardedInstructions;
3668 }
3669
3670 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3671 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3672 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3673 assert(NewKernelEnvC && "Failed to create new kernel environment");
3674 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3675 }
3676
3677#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3678 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3679 ConstantStruct *ConfigC = \
3680 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3681 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3682 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3683 assert(NewConfigC && "Failed to create new configuration environment"); \
3684 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3685 }
3686
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3694
3695#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
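// For example, KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode) expands to
// setExecModeOfKernelEnvironment(ConstantInt *NewVal), which constant-folds
// NewVal into the ExecMode slot of the configuration struct and re-installs
// the result via setConfigurationOfKernelEnvironment above.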
3696
3697 /// See AbstractAttribute::initialize(...).
3698 void initialize(Attributor &A) override {
3699 // This is a high-level transform that might change the constant arguments
3700 // of the init and deinit calls. We need to tell the Attributor about this
3701 // to avoid other parts using the current constant value for simplification.
3702 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3703
3704 Function *Fn = getAnchorScope();
3705
3706 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3707 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3708 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3709 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3710
3711 // For kernels we perform more initialization work, first we find the init
3712 // and deinit calls.
3713 auto StoreCallBase = [](Use &U,
3714 OMPInformationCache::RuntimeFunctionInfo &RFI,
3715 CallBase *&Storage) {
3716 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3717 assert(CB &&
3718 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3719 assert(!Storage &&
3720 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3721 Storage = CB;
3722 return false;
3723 };
3724 InitRFI.foreachUse(
3725 [&](Use &U, Function &) {
3726 StoreCallBase(U, InitRFI, KernelInitCB);
3727 return false;
3728 },
3729 Fn);
3730 DeinitRFI.foreachUse(
3731 [&](Use &U, Function &) {
3732 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3733 return false;
3734 },
3735 Fn);
3736
3737 // Ignore kernels without initializers such as global constructors.
3738 if (!KernelInitCB || !KernelDeinitCB)
3739 return;
3740
3741 // Add itself to the reaching kernel and set IsKernelEntry.
3742 ReachingKernelEntries.insert(Fn);
3743 IsKernelEntry = true;
3744
3745 KernelEnvC =
3746 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3747 GlobalVariable *KernelEnvGV =
3748 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3749
3750 Attributor::GlobalVariableSimplifictionCallbackTy
3751 KernelConfigurationSimplifyCB =
3752 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3753 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3754 if (!isAtFixpoint()) {
3755 if (!AA)
3756 return nullptr;
3757 UsedAssumedInformation = true;
3758 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3759 }
3760 return KernelEnvC;
3761 };
3762
3763 A.registerGlobalVariableSimplificationCallback(
3764 *KernelEnvGV, KernelConfigurationSimplifyCB);
3765
3766 // Check if we know we are in SPMD-mode already.
3767 ConstantInt *ExecModeC =
3768 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3769 ConstantInt *AssumedExecModeC = ConstantInt::get(
3770 ExecModeC->getIntegerType(),
3771 OMP_TGT_EXEC_MODE_GENERIC | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3772 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3773 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3774 else if (DisableOpenMPOptSPMDization)
3775 // This is a generic region but SPMDization is disabled so stop
3776 // tracking.
3777 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3778 else
3779 setExecModeOfKernelEnvironment(AssumedExecModeC);
3780
3781 const Triple T(Fn->getParent()->getTargetTriple());
3782 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3783 auto [MinThreads, MaxThreads] =
3784 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3785 if (MinThreads)
3786 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3787 if (MaxThreads)
3788 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3789 auto [MinTeams, MaxTeams] =
3790 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3791 if (MinTeams)
3792 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3793 if (MaxTeams)
3794 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3795
3796 ConstantInt *MayUseNestedParallelismC =
3797 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3798 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3799 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3800 setMayUseNestedParallelismOfKernelEnvironment(
3801 AssumedMayUseNestedParallelismC);
3802
3803 if (!DisableOpenMPOptStateMachineRewrite) {
3804 ConstantInt *UseGenericStateMachineC =
3805 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3806 KernelEnvC);
3807 ConstantInt *AssumedUseGenericStateMachineC =
3808 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3809 setUseGenericStateMachineOfKernelEnvironment(
3810 AssumedUseGenericStateMachineC);
3811 }
3812
3813 // Register virtual uses of functions we might need to preserve.
3814 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3815 Attributor::VirtualUseCallbackTy &CB) {
3816 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3817 return;
3818 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3819 };
3820
3821 // Add a dependence to ensure updates if the state changes.
3822 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3823 const AbstractAttribute *QueryingAA) {
3824 if (QueryingAA) {
3825 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3826 }
3827 return true;
3828 };
3829
3830 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3831 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3832 // Whenever we create a custom state machine we will insert calls to
3833 // __kmpc_get_hardware_num_threads_in_block,
3834 // __kmpc_get_warp_size,
3835 // __kmpc_barrier_simple_generic,
3836 // __kmpc_kernel_parallel, and
3837 // __kmpc_kernel_end_parallel.
3838 // Not needed if we are on track for SPMDzation.
3839 if (SPMDCompatibilityTracker.isValidState())
3840 return AddDependence(A, this, QueryingAA);
3841 // Not needed if we can't rewrite due to an invalid state.
3842 if (!ReachedKnownParallelRegions.isValidState())
3843 return AddDependence(A, this, QueryingAA);
3844 return false;
3845 };
3846
3847 // Not needed if we are pre runtime merge (the init function is still a declaration).
3848 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3849 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3850 CustomStateMachineUseCB);
3851 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3852 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3853 CustomStateMachineUseCB);
3854 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3855 CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3857 CustomStateMachineUseCB);
3858 }
3859
3860 // If we do not perform SPMDzation we do not need the virtual uses below.
3861 if (SPMDCompatibilityTracker.isAtFixpoint())
3862 return;
3863
3864 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3865 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3866 // Whenever we perform SPMDzation we will insert
3867 // __kmpc_get_hardware_thread_id_in_block calls.
3868 if (!SPMDCompatibilityTracker.isValidState())
3869 return AddDependence(A, this, QueryingAA);
3870 return false;
3871 };
3872 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3873 HWThreadIdUseCB);
3874
3875 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3876 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3877 // Whenever we perform SPMDzation with guarding we will insert
3878 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3879 // nothing to guard, or there are no parallel regions, we don't need
3880 // the calls.
3881 if (!SPMDCompatibilityTracker.isValidState())
3882 return AddDependence(A, this, QueryingAA);
3883 if (SPMDCompatibilityTracker.empty())
3884 return AddDependence(A, this, QueryingAA);
3885 if (!mayContainParallelRegion())
3886 return AddDependence(A, this, QueryingAA);
3887 return false;
3888 };
3889 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3890 }
3891
3892 /// Sanitize the string \p S such that it is a suitable global symbol name.
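/// Every character outside [a-zA-Z0-9_] is replaced by '.'; for example,
/// "kernel$foo bar" becomes "kernel.foo.bar".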
3893 static std::string sanitizeForGlobalName(std::string S) {
3894 std::replace_if(
3895 S.begin(), S.end(),
3896 [](const char C) {
3897 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3898 (C >= '0' && C <= '9') || C == '_');
3899 },
3900 '.');
3901 return S;
3902 }
3903
3904 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3905 /// finished now.
3906 ChangeStatus manifest(Attributor &A) override {
3907 // If we are not looking at a kernel with __kmpc_target_init and
3908 // __kmpc_target_deinit call we cannot actually manifest the information.
3909 if (!KernelInitCB || !KernelDeinitCB)
3910 return ChangeStatus::UNCHANGED;
3911
3912 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3913
3914 bool HasBuiltStateMachine = true;
3915 if (!changeToSPMDMode(A, Changed)) {
3916 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3917 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3918 else
3919 HasBuiltStateMachine = false;
3920 }
3921
3922 // We need to reset KernelEnvC if specific rewriting is not done.
3923 ConstantStruct *ExistingKernelEnvC =
3924 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3925 ConstantInt *OldUseGenericStateMachineVal =
3926 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3927 ExistingKernelEnvC);
3928 if (!HasBuiltStateMachine)
3929 setUseGenericStateMachineOfKernelEnvironment(
3930 OldUseGenericStateMachineVal);
3931
3932 // Finally, update the kernel environment global with the new KernelEnvC.
3933 GlobalVariable *KernelEnvGV =
3934 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3935 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3936 KernelEnvGV->setInitializer(KernelEnvC);
3937 Changed = ChangeStatus::CHANGED;
3938 }
3939
3940 return Changed;
3941 }
3942
3943 void insertInstructionGuardsHelper(Attributor &A) {
3944 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3945
3946 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3947 Instruction *RegionEndI) {
3948 LoopInfo *LI = nullptr;
3949 DominatorTree *DT = nullptr;
3950 MemorySSAUpdater *MSU = nullptr;
3951 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3952
3953 BasicBlock *ParentBB = RegionStartI->getParent();
3954 Function *Fn = ParentBB->getParent();
3955 Module &M = *Fn->getParent();
3956
3957 // Create all the blocks and logic.
3958 // ParentBB:
3959 // goto RegionCheckTidBB
3960 // RegionCheckTidBB:
3961 // Tid = __kmpc_hardware_thread_id()
3962 // if (Tid != 0)
3963 // goto RegionBarrierBB
3964 // RegionStartBB:
3965 // <execute instructions guarded>
3966 // goto RegionEndBB
3967 // RegionEndBB:
3968 // <store escaping values to shared mem>
3969 // goto RegionBarrierBB
3970 // RegionBarrierBB:
3971 // __kmpc_simple_barrier_spmd()
3972 // // second barrier is omitted if lacking escaping values.
3973 // <load escaping values from shared mem>
3974 // __kmpc_simple_barrier_spmd()
3975 // goto RegionExitBB
3976 // RegionExitBB:
3977 // <execute rest of instructions>
3978
3979 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3980 DT, LI, MSU, "region.guarded.end");
3981 BasicBlock *RegionBarrierBB =
3982 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3983 MSU, "region.barrier");
3984 BasicBlock *RegionExitBB =
3985 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3986 DT, LI, MSU, "region.exit");
3987 BasicBlock *RegionStartBB =
3988 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3989
3990 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3991 "Expected a different CFG");
3992
3993 BasicBlock *RegionCheckTidBB = SplitBlock(
3994 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3995
3996 // Register basic blocks with the Attributor.
3997 A.registerManifestAddedBasicBlock(*RegionEndBB);
3998 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3999 A.registerManifestAddedBasicBlock(*RegionExitBB);
4000 A.registerManifestAddedBasicBlock(*RegionStartBB);
4001 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4002
4003 bool HasBroadcastValues = false;
4004 // Find escaping outputs from the guarded region to outside users and
4005 // broadcast their values to them.
4006 for (Instruction &I : *RegionStartBB) {
4007 SmallVector<Use *, 4> OutsideUses;
4008 for (Use &U : I.uses()) {
4009 Instruction &UsrI = *cast<Instruction>(U.getUser());
4010 if (UsrI.getParent() != RegionStartBB)
4011 OutsideUses.push_back(&U);
4012 }
4013
4014 if (OutsideUses.empty())
4015 continue;
4016
4017 HasBroadcastValues = true;
4018
4019 // Emit a global variable in shared memory to store the broadcasted
4020 // value.
4021 auto *SharedMem = new GlobalVariable(
4022 M, I.getType(), /* IsConstant */ false,
4023 GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4024 sanitizeForGlobalName(
4025 (I.getName() + ".guarded.output.alloc").str()),
4026 nullptr, GlobalValue::NotThreadLocal,
4027 static_cast<unsigned>(AddressSpace::Shared));
4028
4029 // Emit a store instruction to update the value.
4030 new StoreInst(&I, SharedMem,
4031 RegionEndBB->getTerminator()->getIterator());
4032
4033 LoadInst *LoadI = new LoadInst(
4034 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4035 RegionBarrierBB->getTerminator()->getIterator());
4036
4037 // Emit a load instruction and replace uses of the output value.
4038 for (Use *U : OutsideUses)
4039 A.changeUseAfterManifest(*U, *LoadI);
4040 }
4041
4042 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4043
4044 // Go to tid check BB in ParentBB.
4045 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4046 ParentBB->getTerminator()->eraseFromParent();
4047 OpenMPIRBuilder::LocationDescription Loc(
4048 InsertPointTy(ParentBB, ParentBB->end()), DL);
4049 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4050 uint32_t SrcLocStrSize;
4051 auto *SrcLocStr =
4052 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4053 Value *Ident =
4054 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4055 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4056
4057 // Add check for Tid in RegionCheckTidBB
4058 RegionCheckTidBB->getTerminator()->eraseFromParent();
4059 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4060 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4061 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4062 FunctionCallee HardwareTidFn =
4063 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4064 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4065 CallInst *Tid =
4066 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4067 Tid->setDebugLoc(DL);
4068 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4069 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4070 OMPInfoCache.OMPBuilder.Builder
4071 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4072 ->setDebugLoc(DL);
4073
4074 // First barrier for synchronization, ensures main thread has updated
4075 // values.
4076 FunctionCallee BarrierFn =
4077 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4078 M, OMPRTL___kmpc_barrier_simple_spmd);
4079 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4080 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4081 CallInst *Barrier =
4082 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4083 Barrier->setDebugLoc(DL);
4084 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4085
4086 // Second barrier ensures workers have read broadcast values.
4087 if (HasBroadcastValues) {
4088 CallInst *Barrier =
4089 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4090 RegionBarrierBB->getTerminator()->getIterator());
4091 Barrier->setDebugLoc(DL);
4092 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4093 }
4094 };
4095
4096 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4097 SmallPtrSet<BasicBlock *, 8> Visited;
4098 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4099 BasicBlock *BB = GuardedI->getParent();
4100 if (!Visited.insert(BB).second)
4101 continue;
4102
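// Best-effort reordering: sink guarded instructions without users down to
// the next guarded side effect so consecutive guarded instructions end up
// adjacent and can later be covered by a single guarded region.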
4103 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4104 Instruction *LastEffect = nullptr;
4105 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4106 while (++IP != IPEnd) {
4107 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4108 continue;
4109 Instruction *I = &*IP;
4110 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4111 continue;
4112 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4113 LastEffect = nullptr;
4114 continue;
4115 }
4116 if (LastEffect)
4117 Reorders.push_back({I, LastEffect});
4118 LastEffect = &*IP;
4119 }
4120 for (auto &Reorder : Reorders)
4121 Reorder.first->moveBefore(Reorder.second);
4122 }
4123
4124 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4125
4126 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4127 BasicBlock *BB = GuardedI->getParent();
4128 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4129 IRPosition::function(*GuardedI->getFunction()), nullptr,
4130 DepClassTy::NONE);
4131 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4132 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4133 // Continue if instruction is already guarded.
4134 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4135 continue;
4136
4137 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4138 for (Instruction &I : *BB) {
4139 // If instruction I needs to be guarded update the guarded region
4140 // bounds.
4141 if (SPMDCompatibilityTracker.contains(&I)) {
4142 CalleeAAFunction.getGuardedInstructions().insert(&I);
4143 if (GuardedRegionStart)
4144 GuardedRegionEnd = &I;
4145 else
4146 GuardedRegionStart = GuardedRegionEnd = &I;
4147
4148 continue;
4149 }
4150
4151 // Instruction I does not need guarding, store
4152 // any region found and reset bounds.
4153 if (GuardedRegionStart) {
4154 GuardedRegions.push_back(
4155 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4156 GuardedRegionStart = nullptr;
4157 GuardedRegionEnd = nullptr;
4158 }
4159 }
4160 }
4161
4162 for (auto &GR : GuardedRegions)
4163 CreateGuardedRegion(GR.first, GR.second);
4164 }
4165
4166 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4167 // Only allow 1 thread per workgroup to continue executing the user code.
4168 //
4169 // InitCB = __kmpc_target_init(...)
4170 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4171 // if (ThreadIdInBlock != 0) return;
4172 // UserCode:
4173 // // user code
4174 //
4175 auto &Ctx = getAnchorValue().getContext();
4176 Function *Kernel = getAssociatedFunction();
4177 assert(Kernel && "Expected an associated function!");
4178
4179 // Create block for user code to branch to from initial block.
4180 BasicBlock *InitBB = KernelInitCB->getParent();
4181 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4182 KernelInitCB->getNextNode(), "main.thread.user_code");
4183 BasicBlock *ReturnBB =
4184 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4185
4186 // Register blocks with attributor:
4187 A.registerManifestAddedBasicBlock(*InitBB);
4188 A.registerManifestAddedBasicBlock(*UserCodeBB);
4189 A.registerManifestAddedBasicBlock(*ReturnBB);
4190
4191 // Debug location:
4192 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4193 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4194 InitBB->getTerminator()->eraseFromParent();
4195
4196 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4197 Module &M = *Kernel->getParent();
4198 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4199 FunctionCallee ThreadIdInBlockFn =
4200 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4201 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4202
4203 // Get thread ID in block.
4204 CallInst *ThreadIdInBlock =
4205 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4206 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4207 ThreadIdInBlock->setDebugLoc(DLoc);
4208
4209 // Eliminate all threads in the block with ID not equal to 0:
4210 Instruction *IsMainThread =
4211 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4212 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4213 "thread.is_main", InitBB);
4214 IsMainThread->setDebugLoc(DLoc);
4215 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4216 }
4217
4218 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4219 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4220
4221 // We cannot change to SPMD mode if the runtime functions aren't available.
4222 if (!OMPInfoCache.runtimeFnsAvailable(
4223 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
4224 OMPRTL___kmpc_barrier_simple_spmd}))
4225 return false;
4226
4227 if (!SPMDCompatibilityTracker.isAssumed()) {
4228 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4229 if (!NonCompatibleI)
4230 continue;
4231
4232 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4233 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4234 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4235 continue;
4236
4237 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4238 ORA << "Value has potential side effects preventing SPMD-mode "
4239 "execution";
4240 if (isa<CallBase>(NonCompatibleI)) {
4241 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
4242 "the called function to override";
4243 }
4244 return ORA << ".";
4245 };
4246 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4247 Remark);
4248
4249 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4250 << *NonCompatibleI << "\n");
4251 }
4252
4253 return false;
4254 }
4255
4256 // Get the actual kernel; it could be the caller of the anchor scope if
4257 // we have a debug wrapper.
4258 Function *Kernel = getAnchorScope();
4259 if (Kernel->hasLocalLinkage()) {
4260 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4261 auto *CB = cast<CallBase>(Kernel->user_back());
4262 Kernel = CB->getCaller();
4263 }
4264 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4265
4266 // Check if the kernel is already in SPMD mode, if so, return success.
4267 ConstantStruct *ExistingKernelEnvC =
4268 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4269 auto *ExecModeC =
4270 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4271 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4272 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4273 return true;
4274
4275 // We will now unconditionally modify the IR, indicate a change.
4276 Changed = ChangeStatus::CHANGED;
4277
4278 // Do not use instruction guards when no parallel region is present inside
4279 // the target region.
4280 if (mayContainParallelRegion())
4281 insertInstructionGuardsHelper(A);
4282 else
4283 forceSingleThreadPerWorkgroupHelper(A);
4284
4285 // Adjust the global exec mode flag that tells the runtime what mode this
4286 // kernel is executed in.
4287 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4288 "Initially non-SPMD kernel has SPMD exec mode!");
4289 setExecModeOfKernelEnvironment(
4290 ConstantInt::get(ExecModeC->getIntegerType(),
4291 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
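// OMP_TGT_EXEC_MODE_GENERIC_SPMD is the GENERIC | SPMD encoding, so the
// runtime can still tell that this kernel was originally generic code.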
4292
4293 ++NumOpenMPTargetRegionKernelsSPMD;
4294
4295 auto Remark = [&](OptimizationRemark OR) {
4296 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4297 };
4298 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4299 return true;
4300 };
4301
4302 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4303 // If we have disabled state machine rewrites, don't make a custom one
4304 if (DisableOpenMPOptStateMachineRewrite)
4305 return false;
4306
4307 // Don't rewrite the state machine if we are not in a valid state.
4308 if (!ReachedKnownParallelRegions.isValidState())
4309 return false;
4310
4311 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4312 if (!OMPInfoCache.runtimeFnsAvailable(
4313 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4314 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4315 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4316 return false;
4317
4318 ConstantStruct *ExistingKernelEnvC =
4319 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4320
4321 // Check if the current configuration is non-SPMD with a generic state machine.
4322 // If we already have SPMD mode or a custom state machine we do not need to
4323 // go any further. If it is anything but a constant something is weird and
4324 // we give up.
4325 ConstantInt *UseStateMachineC =
4326 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4327 ExistingKernelEnvC);
4328 ConstantInt *ModeC =
4329 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4330
4331 // If we are stuck with generic mode, try to create a custom device (=GPU)
4332 // state machine which is specialized for the parallel regions that are
4333 // reachable by the kernel.
4334 if (UseStateMachineC->isZero() ||
4335 (ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4336 return false;
4337
4338 Changed = ChangeStatus::CHANGED;
4339
4340 // If not SPMD mode, indicate we use a custom state machine now.
4341 setUseGenericStateMachineOfKernelEnvironment(
4342 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4343
4344 // If we don't actually need a state machine we are done here. This can
4345 // happen if there simply are no parallel regions. In the resulting kernel
4346 // all worker threads will simply exit right away, leaving the main thread
4347 // to do the work alone.
4348 if (!mayContainParallelRegion()) {
4349 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4350
4351 auto Remark = [&](OptimizationRemark OR) {
4352 return OR << "Removing unused state machine from generic-mode kernel.";
4353 };
4354 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4355
4356 return true;
4357 }
4358
4359 // Keep track in the statistics of our new shiny custom state machine.
4360 if (ReachedUnknownParallelRegions.empty()) {
4361 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4362
4363 auto Remark = [&](OptimizationRemark OR) {
4364 return OR << "Rewriting generic-mode kernel with a customized state "
4365 "machine.";
4366 };
4367 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4368 } else {
4369 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4370
4371 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4372 return OR << "Generic-mode kernel is executed with a customized state "
4373 "machine that requires a fallback.";
4374 };
4375 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4376
4377 // Tell the user why we ended up with a fallback.
4378 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4379 if (!UnknownParallelRegionCB)
4380 continue;
4381 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4382 return ORA << "Call may contain unknown parallel regions. Use "
4383 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4384 "override.";
4385 };
4386 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4387 "OMP133", Remark);
4388 }
4389 }
4390
4391 // Create all the blocks:
4392 //
4393 // InitCB = __kmpc_target_init(...)
4394 // BlockHwSize =
4395 // __kmpc_get_hardware_num_threads_in_block();
4396 // WarpSize = __kmpc_get_warp_size();
4397 // BlockSize = BlockHwSize - WarpSize;
4398 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4399 // if (IsWorker) {
4400 // if (InitCB >= BlockSize) return;
4401 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4402 // void *WorkFn;
4403 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4404 // if (!WorkFn) return;
4405 // SMIsActiveCheckBB: if (Active) {
4406 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4407 // ParFn0(...);
4408 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4409 // ParFn1(...);
4410 // ...
4411 // SMIfCascadeCurrentBB: else
4412 // ((WorkFnTy*)WorkFn)(...);
4413 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4414 // }
4415 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4416 // goto SMBeginBB;
4417 // }
4418 // UserCodeEntryBB: // user code
4419 // __kmpc_target_deinit(...)
4420 //
4421 auto &Ctx = getAnchorValue().getContext();
4422 Function *Kernel = getAssociatedFunction();
4423 assert(Kernel && "Expected an associated function!");
4424
4425 BasicBlock *InitBB = KernelInitCB->getParent();
4426 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4427 KernelInitCB->getNextNode(), "thread.user_code.check");
4428 BasicBlock *IsWorkerCheckBB =
4429 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4430 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4431 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4432 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4433 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4434 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4435 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4436 BasicBlock *StateMachineIfCascadeCurrentBB =
4437 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4438 Kernel, UserCodeEntryBB);
4439 BasicBlock *StateMachineEndParallelBB =
4440 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4441 Kernel, UserCodeEntryBB);
4442 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4443 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4444 A.registerManifestAddedBasicBlock(*InitBB);
4445 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4446 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4452 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4453
4454 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4455 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4456 InitBB->getTerminator()->eraseFromParent();
4457
4458 Instruction *IsWorker =
4459 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4460 ConstantInt::get(KernelInitCB->getType(), -1),
4461 "thread.is_worker", InitBB);
4462 IsWorker->setDebugLoc(DLoc);
4463 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4464
4465 Module &M = *Kernel->getParent();
4466 FunctionCallee BlockHwSizeFn =
4467 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4468 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4469 FunctionCallee WarpSizeFn =
4470 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4471 M, OMPRTL___kmpc_get_warp_size);
4472 CallInst *BlockHwSize =
4473 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4474 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4475 BlockHwSize->setDebugLoc(DLoc);
4476 CallInst *WarpSize =
4477 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4478 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4479 WarpSize->setDebugLoc(DLoc);
4480 Instruction *BlockSize = BinaryOperator::CreateSub(
4481 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4482 BlockSize->setDebugLoc(DLoc);
4483 Instruction *IsMainOrWorker = ICmpInst::Create(
4484 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4485 "thread.is_main_or_worker", IsWorkerCheckBB);
4486 IsMainOrWorker->setDebugLoc(DLoc);
4487 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4488 IsMainOrWorker, IsWorkerCheckBB);
4489
4490 // Create local storage for the work function pointer.
4491 const DataLayout &DL = M.getDataLayout();
4492 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4493 Instruction *WorkFnAI =
4494 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4495 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4496 WorkFnAI->setDebugLoc(DLoc);
4497
4498 OMPInfoCache.OMPBuilder.updateToLocation(
4499 OpenMPIRBuilder::LocationDescription(
4500 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4501 StateMachineBeginBB->end()),
4502 DLoc));
4503
4504 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4505 Value *GTid = KernelInitCB;
4506
4507 FunctionCallee BarrierFn =
4508 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4509 M, OMPRTL___kmpc_barrier_simple_generic);
4510 CallInst *Barrier =
4511 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4512 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4513 Barrier->setDebugLoc(DLoc);
4514
4515 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4516 (unsigned int)AddressSpace::Generic) {
4517 WorkFnAI = new AddrSpaceCastInst(
4518 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4519 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4520 WorkFnAI->setDebugLoc(DLoc);
4521 }
4522
4523 FunctionCallee KernelParallelFn =
4524 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4525 M, OMPRTL___kmpc_kernel_parallel);
4526 CallInst *IsActiveWorker = CallInst::Create(
4527 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4528 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4529 IsActiveWorker->setDebugLoc(DLoc);
4530 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4531 StateMachineBeginBB);
4532 WorkFn->setDebugLoc(DLoc);
4533
4534 FunctionType *ParallelRegionFnTy = FunctionType::get(
4535 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4536 false);
4537
4538 Instruction *IsDone =
4539 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4540 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4541 StateMachineBeginBB);
4542 IsDone->setDebugLoc(DLoc);
4543 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4544 IsDone, StateMachineBeginBB)
4545 ->setDebugLoc(DLoc);
4546
4547 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4548 StateMachineDoneBarrierBB, IsActiveWorker,
4549 StateMachineIsActiveCheckBB)
4550 ->setDebugLoc(DLoc);
4551
4552 Value *ZeroArg =
4553 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4554
4555 const unsigned int WrapperFunctionArgNo = 6;
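// Operand 6 of __kmpc_parallel_51(ident, gtid, if_expr, num_threads,
// proc_bind, fn, wrapper_fn, args, nargs) is the wrapper function the
// runtime hands back to the workers; it is what WorkFn is compared against
// in the cascade below.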
4556
4557 // Now that we have most of the CFG skeleton it is time for the if-cascade
4558 // that checks the function pointer we got from the runtime against the
4559 // parallel regions we expect, if there are any.
4560 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4561 auto *CB = ReachedKnownParallelRegions[I];
4562 auto *ParallelRegion = dyn_cast<Function>(
4563 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4564 BasicBlock *PRExecuteBB = BasicBlock::Create(
4565 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4566 StateMachineEndParallelBB);
4567 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4568 ->setDebugLoc(DLoc);
4569 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4570 ->setDebugLoc(DLoc);
4571
4572 BasicBlock *PRNextBB =
4573 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4574 Kernel, StateMachineEndParallelBB);
4575 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4576 A.registerManifestAddedBasicBlock(*PRNextBB);
4577
4578 // Check if we need to compare the pointer at all or if we can just
4579 // call the parallel region function.
4580 Value *IsPR;
4581 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4582 Instruction *CmpI = ICmpInst::Create(
4583 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4584 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4585 CmpI->setDebugLoc(DLoc);
4586 IsPR = CmpI;
4587 } else {
4588 IsPR = ConstantInt::getTrue(Ctx);
4589 }
4590
4591 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4592 StateMachineIfCascadeCurrentBB)
4593 ->setDebugLoc(DLoc);
4594 StateMachineIfCascadeCurrentBB = PRNextBB;
4595 }
4596
4597 // At the end of the if-cascade we place the indirect function pointer call
4598 // in case we might need it, that is if there can be parallel regions we
4599 // have not handled in the if-cascade above.
4600 if (!ReachedUnknownParallelRegions.empty()) {
4601 StateMachineIfCascadeCurrentBB->setName(
4602 "worker_state_machine.parallel_region.fallback.execute");
4603 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4604 StateMachineIfCascadeCurrentBB)
4605 ->setDebugLoc(DLoc);
4606 }
4607 BranchInst::Create(StateMachineEndParallelBB,
4608 StateMachineIfCascadeCurrentBB)
4609 ->setDebugLoc(DLoc);
4610
4611 FunctionCallee EndParallelFn =
4612 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4613 M, OMPRTL___kmpc_kernel_end_parallel);
4614 CallInst *EndParallel =
4615 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4616 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4617 EndParallel->setDebugLoc(DLoc);
4618 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4619 ->setDebugLoc(DLoc);
4620
4621 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4622 ->setDebugLoc(DLoc);
4623 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4624 ->setDebugLoc(DLoc);
4625
4626 return true;
4627 }
4628
4629 /// Fixpoint iteration update function. Will be called every time a dependence
4630 /// changed its state (and in the beginning).
4631 ChangeStatus updateImpl(Attributor &A) override {
4632 KernelInfoState StateBefore = getState();
4633
4634 // When we leave this function this RAII will make sure the member
4635 // KernelEnvC is updated properly depending on the state. That member is
4636 // used for simplification of values and needs to be up to date at all
4637 // times.
4638 struct UpdateKernelEnvCRAII {
4639 AAKernelInfoFunction &AA;
4640
4641 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4642
4643 ~UpdateKernelEnvCRAII() {
4644 if (!AA.KernelEnvC)
4645 return;
4646
4647 ConstantStruct *ExistingKernelEnvC =
4648 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4649
4650 if (!AA.isValidState()) {
4651 AA.KernelEnvC = ExistingKernelEnvC;
4652 return;
4653 }
4654
4655 if (!AA.ReachedKnownParallelRegions.isValidState())
4656 AA.setUseGenericStateMachineOfKernelEnvironment(
4657 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4658 ExistingKernelEnvC));
4659
4660 if (!AA.SPMDCompatibilityTracker.isValidState())
4661 AA.setExecModeOfKernelEnvironment(
4662 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4663
4664 ConstantInt *MayUseNestedParallelismC =
4665 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4666 AA.KernelEnvC);
4667 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4668 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4669 AA.setMayUseNestedParallelismOfKernelEnvironment(
4670 NewMayUseNestedParallelismC);
4671 }
4672 } RAII(*this);
4673
4674 // Callback to check a read/write instruction.
4675 auto CheckRWInst = [&](Instruction &I) {
4676 // We handle calls later.
4677 if (isa<CallBase>(I))
4678 return true;
4679 // We only care about write effects.
4680 if (!I.mayWriteToMemory())
4681 return true;
4682 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4683 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4684 *this, IRPosition::value(*SI->getPointerOperand()),
4685 DepClassTy::OPTIONAL);
4686 auto *HS = A.getAAFor<AAHeapToStack>(
4687 *this, IRPosition::function(*I.getFunction()),
4688 DepClassTy::OPTIONAL);
4689 if (UnderlyingObjsAA &&
4690 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4691 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4692 return true;
4693 // Check for AAHeapToStack moved objects which must not be
4694 // guarded.
4695 auto *CB = dyn_cast<CallBase>(&Obj);
4696 return CB && HS && HS->isAssumedHeapToStack(*CB);
4697 }))
4698 return true;
4699 }
4700
4701 // Insert instruction that needs guarding.
4702 SPMDCompatibilityTracker.insert(&I);
4703 return true;
4704 };
4705
4706 bool UsedAssumedInformationInCheckRWInst = false;
4707 if (!SPMDCompatibilityTracker.isAtFixpoint())
4708 if (!A.checkForAllReadWriteInstructions(
4709 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4710 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4711
4712 bool UsedAssumedInformationFromReachingKernels = false;
4713 if (!IsKernelEntry) {
4714 updateParallelLevels(A);
4715
4716 bool AllReachingKernelsKnown = true;
4717 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4718 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4719
4720 if (!SPMDCompatibilityTracker.empty()) {
4721 if (!ParallelLevels.isValidState())
4722 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4723 else if (!ReachingKernelEntries.isValidState())
4724 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4725 else {
4726 // Check if all reaching kernels agree on the mode as we can otherwise
4727 // not guard instructions. We might not be sure about the mode, so we
4728 // cannot fix the internal SPMD-ization state either.
4729 int SPMD = 0, Generic = 0;
4730 for (auto *Kernel : ReachingKernelEntries) {
4731 auto *CBAA = A.getAAFor<AAKernelInfo>(
4732 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4733 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4734 CBAA->SPMDCompatibilityTracker.isAssumed())
4735 ++SPMD;
4736 else
4737 ++Generic;
4738 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4739 UsedAssumedInformationFromReachingKernels = true;
4740 }
4741 if (SPMD != 0 && Generic != 0)
4742 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4743 }
4744 }
4745 }
4746
4747 // Callback to check a call instruction.
4748 bool AllParallelRegionStatesWereFixed = true;
4749 bool AllSPMDStatesWereFixed = true;
4750 auto CheckCallInst = [&](Instruction &I) {
4751 auto &CB = cast<CallBase>(I);
4752 auto *CBAA = A.getAAFor<AAKernelInfo>(
4753 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4754 if (!CBAA)
4755 return false;
4756 getState() ^= CBAA->getState();
4757 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4758 AllParallelRegionStatesWereFixed &=
4759 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4760 AllParallelRegionStatesWereFixed &=
4761 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4762 return true;
4763 };
4764
4765 bool UsedAssumedInformationInCheckCallInst = false;
4766 if (!A.checkForAllCallLikeInstructions(
4767 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4768 LLVM_DEBUG(dbgs() << TAG
4769 << "Failed to visit all call-like instructions!\n";);
4770 return indicatePessimisticFixpoint();
4771 }
4772
4773 // If we haven't used any assumed information for the reached parallel
4774 // region states we can fix it.
4775 if (!UsedAssumedInformationInCheckCallInst &&
4776 AllParallelRegionStatesWereFixed) {
4777 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4778 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4779 }
4780
4781 // If we haven't used any assumed information for the SPMD state we can fix
4782 // it.
4783 if (!UsedAssumedInformationInCheckRWInst &&
4784 !UsedAssumedInformationInCheckCallInst &&
4785 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4786 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4787
4788 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4789 : ChangeStatus::CHANGED;
4790 }
4791
4792private:
4793 /// Update info regarding reaching kernels.
4794 void updateReachingKernelEntries(Attributor &A,
4795 bool &AllReachingKernelsKnown) {
4796 auto PredCallSite = [&](AbstractCallSite ACS) {
4797 Function *Caller = ACS.getInstruction()->getFunction();
4798
4799 assert(Caller && "Caller is nullptr");
4800
4801 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4802 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4803 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4804 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4805 return true;
4806 }
4807
4808 // We lost track of the caller of the associated function; any kernel
4809 // could reach it now.
4810 ReachingKernelEntries.indicatePessimisticFixpoint();
4811
4812 return true;
4813 };
4814
4815 if (!A.checkForAllCallSites(PredCallSite, *this,
4816 true /* RequireAllCallSites */,
4817 AllReachingKernelsKnown))
4818 ReachingKernelEntries.indicatePessimisticFixpoint();
4819 }
4820
4821 /// Update info regarding parallel levels.
4822 void updateParallelLevels(Attributor &A) {
4823 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4824 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4825 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4826
4827 auto PredCallSite = [&](AbstractCallSite ACS) {
4828 Function *Caller = ACS.getInstruction()->getFunction();
4829
4830 assert(Caller && "Caller is nullptr");
4831
4832 auto *CAA =
4833 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4834 if (CAA && CAA->ParallelLevels.isValidState()) {
4835 // Any function that is called by `__kmpc_parallel_51` will not be
4836 // folded as the parallel level in the function is updated. Getting this
4837 // right would make the analysis depend on the runtime implementation,
4838 // and any future change to that implementation could silently invalidate
4839 // the analysis. As a consequence, we are just conservative here.
4840 if (Caller == Parallel51RFI.Declaration) {
4841 ParallelLevels.indicatePessimisticFixpoint();
4842 return true;
4843 }
4844
4845 ParallelLevels ^= CAA->ParallelLevels;
4846
4847 return true;
4848 }
4849
4850 // We lost track of the caller of the associated function; any kernel
4851 // could reach it now.
4852 ParallelLevels.indicatePessimisticFixpoint();
4853
4854 return true;
4855 };
4856
4857 bool AllCallSitesKnown = true;
4858 if (!A.checkForAllCallSites(PredCallSite, *this,
4859 true /* RequireAllCallSites */,
4860 AllCallSitesKnown))
4861 ParallelLevels.indicatePessimisticFixpoint();
4862 }
4863};
4864
4865/// The call site kernel info abstract attribute, basically, what can we say
4866 /// about a call site with regard to the KernelInfoState. For now this simply
4867/// forwards the information from the callee.
4868struct AAKernelInfoCallSite : AAKernelInfo {
4869 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4870 : AAKernelInfo(IRP, A) {}
4871
4872 /// See AbstractAttribute::initialize(...).
4873 void initialize(Attributor &A) override {
4874 AAKernelInfo::initialize(A);
4875
4876 CallBase &CB = cast<CallBase>(getAssociatedValue());
4877 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4878 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4879
4880 // Check for SPMD-mode assumptions.
4881 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4882 indicateOptimisticFixpoint();
4883 return;
4884 }
4885
4886 // First weed out calls we do not care about, that is, readonly/readnone
4887 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4888 // parallel region or anything else we are looking for.
4889 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4890 indicateOptimisticFixpoint();
4891 return;
4892 }
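// Illustrative note (hypothetical example): a readnone helper or an
// intrinsic such as llvm.assume is weeded out above, since a call that
// does not write memory can neither spawn a parallel region nor break
// SPMD compatibility.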
4893
4894 // Next we check if we know the callee. If it is a known OpenMP function
4895 // we will handle them explicitly in the switch below. If it is not, we
4896 // will use an AAKernelInfo object on the callee to gather information and
4897 // merge that into the current state. The latter happens in the updateImpl.
4898 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4899 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4900 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4901 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4902 // Unknown callees or declarations are not analyzable, we give up.
4903 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4904
4905 // Unknown callees might contain parallel regions, except if they have
4906 // an appropriate assumption attached.
4907 if (!AssumptionAA ||
4908 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4909 AssumptionAA->hasAssumption("omp_no_parallelism")))
4910 ReachedUnknownParallelRegions.insert(&CB);
4911
4912 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4913 // idea we can run something unknown in SPMD-mode.
4914 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4915 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4916 SPMDCompatibilityTracker.insert(&CB);
4917 }
4918
4919 // We have updated the state for this unknown call properly, there
4920 // won't be any change so we indicate a fixpoint.
4921 indicateOptimisticFixpoint();
4922 }
4923 // If the callee is known and can be used in IPO, we will update the
4924 // state based on the callee state in updateImpl.
4925 return;
4926 }
4927 if (NumCallees > 1) {
4928 indicatePessimisticFixpoint();
4929 return;
4930 }
4931
4932 RuntimeFunction RF = It->getSecond();
4933 switch (RF) {
4934 // All the functions we know are compatible with SPMD mode.
4935 case OMPRTL___kmpc_is_spmd_exec_mode:
4936 case OMPRTL___kmpc_distribute_static_fini:
4937 case OMPRTL___kmpc_for_static_fini:
4938 case OMPRTL___kmpc_global_thread_num:
4939 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4940 case OMPRTL___kmpc_get_hardware_num_blocks:
4941 case OMPRTL___kmpc_single:
4942 case OMPRTL___kmpc_end_single:
4943 case OMPRTL___kmpc_master:
4944 case OMPRTL___kmpc_end_master:
4945 case OMPRTL___kmpc_barrier:
4946 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4947 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4948 case OMPRTL___kmpc_error:
4949 case OMPRTL___kmpc_flush:
4950 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4951 case OMPRTL___kmpc_get_warp_size:
4952 case OMPRTL_omp_get_thread_num:
4953 case OMPRTL_omp_get_num_threads:
4954 case OMPRTL_omp_get_max_threads:
4955 case OMPRTL_omp_in_parallel:
4956 case OMPRTL_omp_get_dynamic:
4957 case OMPRTL_omp_get_cancellation:
4958 case OMPRTL_omp_get_nested:
4959 case OMPRTL_omp_get_schedule:
4960 case OMPRTL_omp_get_thread_limit:
4961 case OMPRTL_omp_get_supported_active_levels:
4962 case OMPRTL_omp_get_max_active_levels:
4963 case OMPRTL_omp_get_level:
4964 case OMPRTL_omp_get_ancestor_thread_num:
4965 case OMPRTL_omp_get_team_size:
4966 case OMPRTL_omp_get_active_level:
4967 case OMPRTL_omp_in_final:
4968 case OMPRTL_omp_get_proc_bind:
4969 case OMPRTL_omp_get_num_places:
4970 case OMPRTL_omp_get_num_procs:
4971 case OMPRTL_omp_get_place_proc_ids:
4972 case OMPRTL_omp_get_place_num:
4973 case OMPRTL_omp_get_partition_num_places:
4974 case OMPRTL_omp_get_partition_place_nums:
4975 case OMPRTL_omp_get_wtime:
4976 break;
4977 case OMPRTL___kmpc_distribute_static_init_4:
4978 case OMPRTL___kmpc_distribute_static_init_4u:
4979 case OMPRTL___kmpc_distribute_static_init_8:
4980 case OMPRTL___kmpc_distribute_static_init_8u:
4981 case OMPRTL___kmpc_for_static_init_4:
4982 case OMPRTL___kmpc_for_static_init_4u:
4983 case OMPRTL___kmpc_for_static_init_8:
4984 case OMPRTL___kmpc_for_static_init_8u: {
4985 // Check the schedule and allow static schedule in SPMD mode.
4986 unsigned ScheduleArgOpNo = 2;
4987 auto *ScheduleTypeCI =
4988 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4989 unsigned ScheduleTypeVal =
4990 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4991 switch (OMPScheduleType(ScheduleTypeVal)) {
4992 case OMPScheduleType::UnorderedStatic:
4993 case OMPScheduleType::UnorderedStaticChunked:
4994 case OMPScheduleType::OrderedDistribute:
4995 case OMPScheduleType::OrderedDistributeChunked:
4996 break;
4997 default:
4998 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4999 SPMDCompatibilityTracker.insert(&CB);
5000 break;
5001 };
5002 } break;
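// Illustrative note (assuming the usual clang lowering): a loop such as
//   #pragma omp for schedule(static)
// reaches the schedule check above via __kmpc_for_static_init_4 with a
// schedule type of OMPScheduleType::UnorderedStatic and remains
// SPMD-amenable; any other schedule type takes the default case above and
// marks the call site SPMD-incompatible.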
5003 case OMPRTL___kmpc_target_init:
5004 KernelInitCB = &CB;
5005 break;
5006 case OMPRTL___kmpc_target_deinit:
5007 KernelDeinitCB = &CB;
5008 break;
5009 case OMPRTL___kmpc_parallel_51:
5010 if (!handleParallel51(A, CB))
5011 indicatePessimisticFixpoint();
5012 return;
5013 case OMPRTL___kmpc_omp_task:
5014 // We do not look into tasks right now, just give up.
5015 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5016 SPMDCompatibilityTracker.insert(&CB);
5017 ReachedUnknownParallelRegions.insert(&CB);
5018 break;
5019 case OMPRTL___kmpc_alloc_shared:
5020 case OMPRTL___kmpc_free_shared:
5021 // Return without setting a fixpoint, to be resolved in updateImpl.
5022 return;
5023 default:
5024 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5025 // generally. However, they do not hide parallel regions.
5026 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5027 SPMDCompatibilityTracker.insert(&CB);
5028 break;
5029 }
5030 // All other OpenMP runtime calls will not reach parallel regions so they
5031 // can be safely ignored for now. Since it is a known OpenMP runtime call
5032 // we have now modeled all effects and there is no need for any update.
5033 indicateOptimisticFixpoint();
5034 };
5035
5036 const auto *AACE =
5037 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5038 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5039 CheckCallee(getAssociatedFunction(), 1);
5040 return;
5041 }
5042 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5043 for (auto *Callee : OptimisticEdges) {
5044 CheckCallee(Callee, OptimisticEdges.size());
5045 if (isAtFixpoint())
5046 break;
5047 }
5048 }
5049
5050 ChangeStatus updateImpl(Attributor &A) override {
5051 // TODO: Once we have call site specific value information we can provide
5052 // call site specific liveness information and then it makes
5053 // sense to specialize attributes for call site arguments instead of
5054 // redirecting requests to the callee argument.
5055 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5056 KernelInfoState StateBefore = getState();
5057
5058 auto CheckCallee = [&](Function *F, int NumCallees) {
5059 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5060
5061 // If F is not a runtime function, propagate the AAKernelInfo of the
5062 // callee.
5063 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5064 const IRPosition &FnPos = IRPosition::function(*F);
5065 auto *FnAA =
5066 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5067 if (!FnAA)
5068 return indicatePessimisticFixpoint();
5069 if (getState() == FnAA->getState())
5070 return ChangeStatus::UNCHANGED;
5071 getState() = FnAA->getState();
5072 return ChangeStatus::CHANGED;
5073 }
5074 if (NumCallees > 1)
5075 return indicatePessimisticFixpoint();
5076
5077 CallBase &CB = cast<CallBase>(getAssociatedValue());
5078 if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
5079 if (!handleParallel51(A, CB))
5080 return indicatePessimisticFixpoint();
5081 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5082 : ChangeStatus::CHANGED;
5083 }
5084
5085 // F is a runtime function that allocates or frees memory, check
5086 // AAHeapToStack and AAHeapToShared.
5087 assert(
5088 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5089 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5090 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5091
5092 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5093 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5094 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5095 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5096
5097 RuntimeFunction RF = It->getSecond();
5098
5099 switch (RF) {
5100 // If neither HeapToStack nor HeapToShared assume the call is removed,
5101 // assume SPMD incompatibility.
5102 case OMPRTL___kmpc_alloc_shared:
5103 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5104 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5105 SPMDCompatibilityTracker.insert(&CB);
5106 break;
5107 case OMPRTL___kmpc_free_shared:
5108 if ((!HeapToStackAA ||
5109 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5110 (!HeapToSharedAA ||
5111 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5112 SPMDCompatibilityTracker.insert(&CB);
5113 break;
5114 default:
5115 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5116 SPMDCompatibilityTracker.insert(&CB);
5117 }
5118 return ChangeStatus::CHANGED;
5119 };
5120
5121 const auto *AACE =
5122 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5123 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5124 if (Function *F = getAssociatedFunction())
5125 CheckCallee(F, /*NumCallees=*/1);
5126 } else {
5127 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5128 for (auto *Callee : OptimisticEdges) {
5129 CheckCallee(Callee, OptimisticEdges.size());
5130 if (isAtFixpoint())
5131 break;
5132 }
5133 }
5134
5135 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5136 : ChangeStatus::CHANGED;
5137 }
5138
5139 /// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call
5140 /// was handled; if a problem occurred, false is returned.
5141 bool handleParallel51(Attributor &A, CallBase &CB) {
5142 const unsigned int NonWrapperFunctionArgNo = 5;
5143 const unsigned int WrapperFunctionArgNo = 6;
5144 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5145 ? NonWrapperFunctionArgNo
5146 : WrapperFunctionArgNo;
5147
5148 auto *ParallelRegion = dyn_cast<Function>(
5149 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5150 if (!ParallelRegion)
5151 return false;
5152
5153 ReachedKnownParallelRegions.insert(&CB);
5154 /// Check nested parallelism
5155 auto *FnAA = A.getAAFor<AAKernelInfo>(
5156 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5157 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5158 !FnAA->ReachedKnownParallelRegions.empty() ||
5159 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5161 !FnAA->ReachedUnknownParallelRegions.empty();
5162 return true;
5163 }
5164};
5165
5166struct AAFoldRuntimeCall
5167 : public StateWrapper<BooleanState, AbstractAttribute> {
5168 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5169
5170 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5171
5172 /// Statistics are tracked as part of manifest for now.
5173 void trackStatistics() const override {}
5174
5175 /// Create an abstract attribute view for the position \p IRP.
5176 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5177 Attributor &A);
5178
5179 /// See AbstractAttribute::getName()
5180 const std::string getName() const override { return "AAFoldRuntimeCall"; }
5181
5182 /// See AbstractAttribute::getIdAddr()
5183 const char *getIdAddr() const override { return &ID; }
5184
5185 /// This function should return true if the type of the \p AA is
5186 /// AAFoldRuntimeCall
5187 static bool classof(const AbstractAttribute *AA) {
5188 return (AA->getIdAddr() == &ID);
5189 }
5190
5191 static const char ID;
5192};
5193
5194struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5195 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5196 : AAFoldRuntimeCall(IRP, A) {}
5197
5198 /// See AbstractAttribute::getAsStr()
5199 const std::string getAsStr(Attributor *) const override {
5200 if (!isValidState())
5201 return "<invalid>";
5202
5203 std::string Str("simplified value: ");
5204
5205 if (!SimplifiedValue)
5206 return Str + std::string("none");
5207
5208 if (!*SimplifiedValue)
5209 return Str + std::string("nullptr");
5210
5211 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5212 return Str + std::to_string(CI->getSExtValue());
5213
5214 return Str + std::string("unknown");
5215 }
5216
5217 void initialize(Attributor &A) override {
5218 if (DisableOpenMPOptFolding)
5219 indicatePessimisticFixpoint();
5220
5221 Function *Callee = getAssociatedFunction();
5222
5223 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5224 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5225 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5226 "Expected a known OpenMP runtime function");
5227
5228 RFKind = It->getSecond();
5229
5230 CallBase &CB = cast<CallBase>(getAssociatedValue());
5231 A.registerSimplificationCallback(
5232 IRPosition::callsite_returned(CB),
5233 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5234 bool &UsedAssumedInformation) -> std::optional<Value *> {
5235 assert((isValidState() ||
5236 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
5237 "Unexpected invalid state!");
5238
5239 if (!isAtFixpoint()) {
5240 UsedAssumedInformation = true;
5241 if (AA)
5242 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5243 }
5244 return SimplifiedValue;
5245 });
5246 }
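// Illustrative note (not from the original source): SimplifiedValue acts as
// a tri-state that the callback above reports to the Attributor:
//   std::nullopt -> nothing known yet, users must treat it as assumed;
//   nullptr      -> folding failed, the call has to stay;
//   a Value*     -> the constant the call will be replaced with in manifest.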
5247
5248 ChangeStatus updateImpl(Attributor &A) override {
5249 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5250 switch (RFKind) {
5251 case OMPRTL___kmpc_is_spmd_exec_mode:
5252 Changed |= foldIsSPMDExecMode(A);
5253 break;
5254 case OMPRTL___kmpc_parallel_level:
5255 Changed |= foldParallelLevel(A);
5256 break;
5257 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5258 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
5259 break;
5260 case OMPRTL___kmpc_get_hardware_num_blocks:
5261 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
5262 break;
5263 default:
5264 llvm_unreachable("Unhandled OpenMP runtime function!");
5265 }
5266
5267 return Changed;
5268 }
5269
5270 ChangeStatus manifest(Attributor &A) override {
5271 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5272
5273 if (SimplifiedValue && *SimplifiedValue) {
5274 Instruction &I = *getCtxI();
5275 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5276 A.deleteAfterManifest(I);
5277
5278 CallBase *CB = dyn_cast<CallBase>(&I);
5279 auto Remark = [&](OptimizationRemark OR) {
5280 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5281 return OR << "Replacing OpenMP runtime call "
5282 << CB->getCalledFunction()->getName() << " with "
5283 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5284 return OR << "Replacing OpenMP runtime call "
5285 << CB->getCalledFunction()->getName() << ".";
5286 };
5287
5288 if (CB && EnableVerboseRemarks)
5289 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5290
5291 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5292 << **SimplifiedValue << "\n");
5293
5294 Changed = ChangeStatus::CHANGED;
5295 }
5296
5297 return Changed;
5298 }
5299
5300 ChangeStatus indicatePessimisticFixpoint() override {
5301 SimplifiedValue = nullptr;
5302 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5303 }
5304
5305private:
5306 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
5307 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5308 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5309
5310 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5311 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5312 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5313 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5314
5315 if (!CallerKernelInfoAA ||
5316 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5317 return indicatePessimisticFixpoint();
5318
5319 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5320 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5321 DepClassTy::REQUIRED);
5322
5323 if (!AA || !AA->isValidState()) {
5324 SimplifiedValue = nullptr;
5325 return indicatePessimisticFixpoint();
5326 }
5327
5328 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5329 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5330 ++KnownSPMDCount;
5331 else
5332 ++AssumedSPMDCount;
5333 } else {
5334 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5335 ++KnownNonSPMDCount;
5336 else
5337 ++AssumedNonSPMDCount;
5338 }
5339 }
5340
5341 if ((AssumedSPMDCount + KnownSPMDCount) &&
5342 (AssumedNonSPMDCount + KnownNonSPMDCount))
5343 return indicatePessimisticFixpoint();
5344
5345 auto &Ctx = getAnchorValue().getContext();
5346 if (KnownSPMDCount || AssumedSPMDCount) {
5347 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5348 "Expected only SPMD kernels!");
5349 // All reaching kernels are in SPMD mode. Update all function calls to
5350 // __kmpc_is_spmd_exec_mode to 1.
5351 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5352 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5353 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5354 "Expected only non-SPMD kernels!");
5355 // All reaching kernels are in non-SPMD mode. Update all function
5356 // calls to __kmpc_is_spmd_exec_mode to 0.
5357 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5358 } else {
5359 // There are no reaching kernels, so we cannot tell whether the
5360 // associated call site can be folded. At this point, SimplifiedValue
5361 // must still be none.
5362 assert(!SimplifiedValue && "SimplifiedValue should be none");
5363 }
5364
5365 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5366 : ChangeStatus::CHANGED;
5367 }
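// Illustrative note (hypothetical module): if kernels @K1 and @K2 both
// reach a __kmpc_is_spmd_exec_mode call and both are assumed or known
// SPMD, the call folds to (i8 1); if both are non-SPMD it folds to (i8 0);
// one kernel of each kind forces the pessimistic fixpoint and the call
// survives.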
5368
5369 /// Fold __kmpc_parallel_level into a constant if possible.
5370 ChangeStatus foldParallelLevel(Attributor &A) {
5371 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5372
5373 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5374 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5375
5376 if (!CallerKernelInfoAA ||
5377 !CallerKernelInfoAA->ParallelLevels.isValidState())
5378 return indicatePessimisticFixpoint();
5379
5380 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5381 return indicatePessimisticFixpoint();
5382
5383 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5384 assert(!SimplifiedValue &&
5385 "SimplifiedValue should keep none at this point");
5386 return ChangeStatus::UNCHANGED;
5387 }
5388
5389 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5390 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5391 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5392 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5393 DepClassTy::REQUIRED);
5394 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5395 return indicatePessimisticFixpoint();
5396
5397 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5398 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5399 ++KnownSPMDCount;
5400 else
5401 ++AssumedSPMDCount;
5402 } else {
5403 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5404 ++KnownNonSPMDCount;
5405 else
5406 ++AssumedNonSPMDCount;
5407 }
5408 }
5409
5410 if ((AssumedSPMDCount + KnownSPMDCount) &&
5411 (AssumedNonSPMDCount + KnownNonSPMDCount))
5412 return indicatePessimisticFixpoint();
5413
5414 auto &Ctx = getAnchorValue().getContext();
5415 // If the caller can only be reached by SPMD kernel entries, the parallel
5416 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5417 // kernel entries, it is 0.
5418 if (AssumedSPMDCount || KnownSPMDCount) {
5419 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5420 "Expected only SPMD kernels!");
5421 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5422 } else {
5423 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5424 "Expected only non-SPMD kernels!");
5425 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5426 }
5427 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5428 : ChangeStatus::CHANGED;
5429 }
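// Illustrative note (not from the original source): in SPMD mode every
// thread already executes inside the implicit parallel region, so
// __kmpc_parallel_level folds to 1 when only SPMD kernels reach the
// caller; with only generic-mode kernels the main thread runs outside any
// parallel region and the call folds to 0.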
5430
5431 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5432 // Specialize only if all the calls agree with the attribute constant value
5433 int32_t CurrentAttrValue = -1;
5434 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5435
5436 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5437 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5438
5439 if (!CallerKernelInfoAA ||
5440 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5441 return indicatePessimisticFixpoint();
5442
5443 // Iterate over the kernels that reach this function
5444 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5445 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5446
5447 if (NextAttrVal == -1 ||
5448 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5449 return indicatePessimisticFixpoint();
5450 CurrentAttrValue = NextAttrVal;
5451 }
5452
5453 if (CurrentAttrValue != -1) {
5454 auto &Ctx = getAnchorValue().getContext();
5455 SimplifiedValue =
5456 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5457 }
5458 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5459 : ChangeStatus::CHANGED;
5460 }
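// Illustrative note (hypothetical attribute values): assuming two reaching
// kernels that both carry "omp_target_thread_limit"="128", a call to
// __kmpc_get_hardware_num_threads_in_block folds to (i32 128); kernels
// that disagree on the value, or lack the attribute, abort the fold above.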
5461
5462 /// An optional value the associated value is assumed to fold to. That is, we
5463 /// assume the associated value (which is a call) can be replaced by this
5464 /// simplified value.
5465 std::optional<Value *> SimplifiedValue;
5466
5467 /// The runtime function kind of the callee of the associated call site.
5468 RuntimeFunction RFKind;
5469};
5470
5471} // namespace
5472
5473 /// Register folding callsites for the runtime function \p RF.
5474void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5475 auto &RFI = OMPInfoCache.RFIs[RF];
5476 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5477 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5478 if (!CI)
5479 return false;
5480 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5481 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5482 DepClassTy::NONE, /* ForceUpdate */ false,
5483 /* UpdateAfterInit */ false);
5484 return false;
5485 });
5486}
5487
5488void OpenMPOpt::registerAAs(bool IsModulePass) {
5489 if (SCC.empty())
5490 return;
5491
5492 if (IsModulePass) {
5493 // Ensure we create the AAKernelInfo AAs first and without triggering an
5494 // update. This will make sure we register all value simplification
5495 // callbacks before any other AA has the chance to create an AAValueSimplify
5496 // or similar.
5497 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5498 A.getOrCreateAAFor<AAKernelInfo>(
5499 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5500 DepClassTy::NONE, /* ForceUpdate */ false,
5501 /* UpdateAfterInit */ false);
5502 return false;
5503 };
5504 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5505 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5506 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5507
5508 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5510 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5511 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5512 }
5513
5514 // Create CallSite AA for all Getters.
5515 if (DeduceICVValues) {
5516 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5517 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5518
5519 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5520
5521 auto CreateAA = [&](Use &U, Function &Caller) {
5522 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5523 if (!CI)
5524 return false;
5525
5526 auto &CB = cast<CallBase>(*CI);
5527
5528 IRPosition CBPos = IRPosition::callsite_function(CB);
5529 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5530 return false;
5531 };
5532
5533 GetterRFI.foreachUse(SCC, CreateAA);
5534 }
5535 }
5536
5537 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5538 // every function if there is a device kernel.
5539 if (!isOpenMPDevice(M))
5540 return;
5541
5542 for (auto *F : SCC) {
5543 if (F->isDeclaration())
5544 continue;
5545
5546 // We look at internal functions only on demand, but if any use is not a
5547 // direct call or lies outside the current set of analyzed functions, we
5548 // have to do it eagerly.
5549 if (F->hasLocalLinkage()) {
5550 if (llvm::all_of(F->uses(), [this](const Use &U) {
5551 const auto *CB = dyn_cast<CallBase>(U.getUser());
5552 return CB && CB->isCallee(&U) &&
5553 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5554 }))
5555 continue;
5556 }
5557 registerAAsForFunction(A, *F);
5558 }
5559}
5560
5561void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5562 if (!DisableOpenMPOptDeglobalization)
5563 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5564 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5565 if (!DisableOpenMPOptDeglobalization)
5566 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5567 if (F.hasFnAttribute(Attribute::Convergent))
5568 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5569
5570 for (auto &I : instructions(F)) {
5571 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5572 bool UsedAssumedInformation = false;
5573 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5574 UsedAssumedInformation, AA::Interprocedural);
5575 continue;
5576 }
5577 if (auto *CI = dyn_cast<CallBase>(&I)) {
5578 if (CI->isIndirectCall())
5579 A.getOrCreateAAFor<AAIndirectCallInfo>(
5580 IRPosition::callsite_function(*CI));
5581 }
5582 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5583 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5584 continue;
5585 }
5586 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5587 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5588 continue;
5589 }
5590 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5591 if (II->getIntrinsicID() == Intrinsic::assume) {
5592 A.getOrCreateAAFor<AAPotentialValues>(
5593 IRPosition::value(*II->getArgOperand(0)));
5594 continue;
5595 }
5596 }
5597 }
5598}
5599
5600const char AAICVTracker::ID = 0;
5601const char AAKernelInfo::ID = 0;
5602const char AAExecutionDomain::ID = 0;
5603const char AAHeapToShared::ID = 0;
5604const char AAFoldRuntimeCall::ID = 0;
5605
5606AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5607 Attributor &A) {
5608 AAICVTracker *AA = nullptr;
5609 switch (IRP.getPositionKind()) {
5610 case IRPosition::IRP_INVALID:
5611 case IRPosition::IRP_FLOAT:
5612 case IRPosition::IRP_ARGUMENT:
5613 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5614 llvm_unreachable("ICVTracker can only be created for function position!");
5615 case IRPosition::IRP_RETURNED:
5616 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5617 break;
5618 case IRPosition::IRP_CALL_SITE_RETURNED:
5619 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5620 break;
5621 case IRPosition::IRP_CALL_SITE:
5622 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5623 break;
5624 case IRPosition::IRP_FUNCTION:
5625 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5626 break;
5627 }
5628
5629 return *AA;
5630}
5631
5632 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5633 Attributor &A) {
5634 AAExecutionDomainFunction *AA = nullptr;
5635 switch (IRP.getPositionKind()) {
5636 case IRPosition::IRP_INVALID:
5637 case IRPosition::IRP_FLOAT:
5638 case IRPosition::IRP_ARGUMENT:
5639 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5640 case IRPosition::IRP_RETURNED:
5641 case IRPosition::IRP_CALL_SITE_RETURNED:
5642 case IRPosition::IRP_CALL_SITE:
5643 llvm_unreachable(
5644 "AAExecutionDomain can only be created for function position!");
5645 case IRPosition::IRP_FUNCTION:
5646 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5647 break;
5648 }
5649
5650 return *AA;
5651}
5652
5653AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5654 Attributor &A) {
5655 AAHeapToSharedFunction *AA = nullptr;
5656 switch (IRP.getPositionKind()) {
5657 case IRPosition::IRP_INVALID:
5658 case IRPosition::IRP_FLOAT:
5659 case IRPosition::IRP_ARGUMENT:
5660 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5661 case IRPosition::IRP_RETURNED:
5662 case IRPosition::IRP_CALL_SITE_RETURNED:
5663 case IRPosition::IRP_CALL_SITE:
5664 llvm_unreachable(
5665 "AAHeapToShared can only be created for function position!");
5666 case IRPosition::IRP_FUNCTION:
5667 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5668 break;
5669 }
5670
5671 return *AA;
5672}
5673
5674AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5675 Attributor &A) {
5676 AAKernelInfo *AA = nullptr;
5677 switch (IRP.getPositionKind()) {
5678 case IRPosition::IRP_INVALID:
5679 case IRPosition::IRP_FLOAT:
5680 case IRPosition::IRP_ARGUMENT:
5681 case IRPosition::IRP_RETURNED:
5682 case IRPosition::IRP_CALL_SITE_RETURNED:
5683 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5684 llvm_unreachable("KernelInfo can only be created for function position!");
5685 case IRPosition::IRP_CALL_SITE:
5686 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5687 break;
5688 case IRPosition::IRP_FUNCTION:
5689 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5690 break;
5691 }
5692
5693 return *AA;
5694}
5695
5696AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5697 Attributor &A) {
5698 AAFoldRuntimeCall *AA = nullptr;
5699 switch (IRP.getPositionKind()) {
5700 case IRPosition::IRP_INVALID:
5701 case IRPosition::IRP_FLOAT:
5702 case IRPosition::IRP_ARGUMENT:
5703 case IRPosition::IRP_RETURNED:
5704 case IRPosition::IRP_FUNCTION:
5705 case IRPosition::IRP_CALL_SITE:
5706 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5707 llvm_unreachable("KernelInfo can only be created for call site position!");
5708 case IRPosition::IRP_CALL_SITE_RETURNED:
5709 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5710 break;
5711 }
5712
5713 return *AA;
5714}
5715
5716 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5717 if (!containsOpenMP(M))
5718 return PreservedAnalyses::all();
5719 if (DisableOpenMPOptimizations)
5720 return PreservedAnalyses::all();
5721
5722 FunctionAnalysisManager &FAM =
5723 AM.getResult<FunctionAnalysisManagerProxy>(M).getManager();
5724 KernelSet Kernels = getDeviceKernels(M);
5725
5726 if (PrintModuleBeforeOptimizations)
5727 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5728
5729 auto IsCalled = [&](Function &F) {
5730 if (Kernels.contains(&F))
5731 return true;
5732 for (const User *U : F.users())
5733 if (!isa<BlockAddress>(U))
5734 return true;
5735 return false;
5736 };
5737
5738 auto EmitRemark = [&](Function &F) {
5739 auto &ORE = OREGetter(&F);
5740 ORE.emit([&]() {
5741 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5742 return ORA << "Could not internalize function. "
5743 << "Some optimizations may not be possible. [OMP140]";
5744 });
5745 };
5746
5747 bool Changed = false;
5748
5749 // Create internal copies of each function if this is a kernel Module. This
5750 // allows interprocedural passes to see every call edge.
5751 DenseMap<Function *, Function *> InternalizedMap;
5752 if (isOpenMPDevice(M)) {
5753 SmallPtrSet<Function *, 16> InternalizeFns;
5754 for (Function &F : M)
5755 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5756 !DisableInternalization) {
5757 if (Attributor::isInternalizable(F)) {
5758 InternalizeFns.insert(&F);
5759 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5760 EmitRemark(F);
5761 }
5762 }
5763
5764 Changed |=
5765 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5766 }
5767
5768 // Look at every function in the Module unless it was internalized.
5769 SetVector<Function *> Functions;
5770 SmallVector<Function *, 16> SCC;
5771 for (Function &F : M)
5772 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5773 SCC.push_back(&F);
5774 Functions.insert(&F);
5775 }
5776
5777 if (SCC.empty())
5778 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
5779
5780 AnalysisGetter AG(FAM);
5781
5782 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5783 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5784 };
5785
5786 BumpPtrAllocator Allocator;
5787 CallGraphUpdater CGUpdater;
5788
5789 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5790 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5791 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5792
5793 unsigned MaxFixpointIterations =
5794 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5795
5796 AttributorConfig AC(CGUpdater);
5797 AC.DefaultInitializeLiveInternals = false;
5798 AC.IsModulePass = true;
5799 AC.RewriteSignatures = false;
5800 AC.MaxFixpointIterations = MaxFixpointIterations;
5801 AC.OREGetter = OREGetter;
5802 AC.PassName = DEBUG_TYPE;
5803 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5804 AC.IPOAmendableCB = [](const Function &F) {
5805 return F.hasFnAttribute("kernel");
5806 };
5807
5808 Attributor A(Functions, InfoCache, AC);
5809
5810 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5811 Changed |= OMPOpt.run(true);
5812
5813 // Optionally inline device functions for potentially better performance.
5814 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5815 for (Function &F : M)
5816 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5817 !F.hasFnAttribute(Attribute::NoInline))
5818 F.addFnAttr(Attribute::AlwaysInline);
5819
5820 if (PrintModuleAfterOptimizations)
5821 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5822
5823 if (Changed)
5824 return PreservedAnalyses::none();
5825
5826 return PreservedAnalyses::all();
5827}
5828
5829 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5830 CGSCCAnalysisManager &AM,
5831 LazyCallGraph &CG,
5832 CGSCCUpdateResult &UR) {
5833 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5834 return PreservedAnalyses::all();
5835 if (DisableOpenMPOptimizations)
5836 return PreservedAnalyses::all();
5837
5838 SmallVector<Function *, 16> SCC;
5839 // If there are kernels in the module, we have to run on all SCC's.
5840 for (LazyCallGraph::Node &N : C) {
5841 Function *Fn = &N.getFunction();
5842 SCC.push_back(Fn);
5843 }
5844
5845 if (SCC.empty())
5846 return PreservedAnalyses::all();
5847
5848 Module &M = *C.begin()->getFunction().getParent();
5849
5850 if (PrintModuleBeforeOptimizations)
5851 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5852
5853 KernelSet Kernels = getDeviceKernels(M);
5854
5855 FunctionAnalysisManager &FAM =
5856 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5857
5858 AnalysisGetter AG(FAM);
5859
5860 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5861 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5862 };
5863
5864 BumpPtrAllocator Allocator;
5865 CallGraphUpdater CGUpdater;
5866 CGUpdater.initialize(CG, C, AM, UR);
5867
5868 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5869 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5870 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5871 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5872 /*CGSCC*/ &Functions, PostLink);
5873
5874 unsigned MaxFixpointIterations =
5875 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5876
5877 AttributorConfig AC(CGUpdater);
5878 AC.DefaultInitializeLiveInternals = false;
5879 AC.IsModulePass = false;
5880 AC.RewriteSignatures = false;
5881 AC.MaxFixpointIterations = MaxFixpointIterations;
5882 AC.OREGetter = OREGetter;
5883 AC.PassName = DEBUG_TYPE;
5884 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5885
5886 Attributor A(Functions, InfoCache, AC);
5887
5888 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5889 bool Changed = OMPOpt.run(false);
5890
5891 if (PrintModuleAfterOptimizations)
5892 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5893
5894 if (Changed)
5895 return PreservedAnalyses::none();
5896
5897 return PreservedAnalyses::all();
5898}
5899
5900 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5901 return Fn.hasFnAttribute("kernel");
5902}
5903
5904 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5905 // TODO: Create a more cross-platform way of determining device kernels.
5906 NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
5907 KernelSet Kernels;
5908
5909 if (!MD)
5910 return Kernels;
5911
5912 for (auto *Op : MD->operands()) {
5913 if (Op->getNumOperands() < 2)
5914 continue;
5915 MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5916 if (!KindID || KindID->getString() != "kernel")
5917 continue;
5918
5919 Function *KernelFn =
5920 mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5921 if (!KernelFn)
5922 continue;
5923
5924 // We are only interested in OpenMP target regions. Others, such as kernels
5925 // generated by CUDA but linked together, are not interesting to this pass.
5926 if (isOpenMPKernel(*KernelFn)) {
5927 ++NumOpenMPTargetRegionKernels;
5928 Kernels.insert(KernelFn);
5929 } else
5930 ++NumNonOpenMPTargetRegionKernels;
5931 }
5932
5933 return Kernels;
5934}
5935
5936 bool llvm::omp::containsOpenMP(Module &M) {
5937 Metadata *MD = M.getModuleFlag("openmp");
5938 if (!MD)
5939 return false;
5940
5941 return true;
5942}
5943
5944 bool llvm::omp::isOpenMPDevice(Module &M) {
5945 Metadata *MD = M.getModuleFlag("openmp-device");
5946 if (!MD)
5947 return false;
5948
5949 return true;
5950}
@ Generic
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
Returns the sub type a function will return at a given Idx Should correspond to the result type of an ExtractValue instruction executed with just that one unsigned Idx
#define LLVM_DEBUG(X)
Definition: Debug.h:101
bool End
Definition: ELF_riscv.cpp:480
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static bool lookup(const GsymReader &GR, DataExtractor &Data, uint64_t &Offset, uint64_t BaseAddr, uint64_t Addr, SourceLocations &SrcLocs, llvm::Error &Err)
A Lookup helper functions.
Definition: InlineInfo.cpp:109
#define F(x, y, z)
Definition: MD5.cpp:55
#define I(x, y, z)
Definition: MD5.cpp:58
This file provides utility analysis objects describing memory locations.
IntegerType * Int32Ty
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
Definition: OpenMPOpt.cpp:185
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicible functions on the device."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
Definition: OpenMPOpt.cpp:241
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:218
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
Definition: OpenMPOpt.cpp:3677
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition: OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
Definition: OpenMPOpt.cpp:210
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
Definition: OpenMPOpt.cpp:231
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
#define P(N)
if(VerifyEach)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
assert(ImpDefSCC.getReg()==AMDGPU::SCC &&ImpDefSCC.isDef())
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition: Statistic.h:167
This file contains some functions that are useful when dealing with strings.
static const int BlockSize
Definition: TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, ArrayRef< StringLiteral > StandardNames)
Initialize the set of available library functions based on the specified target triple.
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition: Value.cpp:469
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
AbstractCallSite.
This class represents a conversion between pointers from one address space to another.
an instruction to allocate memory on the stack
Definition: Instructions.h:59
A container for analyses that lazily runs them and caches their results.
Definition: PassManager.h:348
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
Definition: PassManager.h:500
This class represents an incoming formal argument to a Function.
Definition: Argument.h:31
ArrayRef - Represent a constant reference to an array (0 or more elements consecutively in memory),...
Definition: ArrayRef.h:41
size_t size() const
size - Get the array size.
Definition: ArrayRef.h:165
AttributeSet getParamAttrs(unsigned ArgNo) const
The attributes for the argument or parameter at the given index are returned.
LLVM Basic Block Representation.
Definition: BasicBlock.h:60
iterator end()
Definition: BasicBlock.h:442
iterator begin()
Instruction iterator methods.
Definition: BasicBlock.h:429
const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
Definition: BasicBlock.cpp:398
reverse_iterator rbegin()
Definition: BasicBlock.h:445
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition: BasicBlock.h:198
BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
Definition: BasicBlock.cpp:559
const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
Definition: BasicBlock.cpp:479
InstListType::reverse_iterator reverse_iterator
Definition: BasicBlock.h:166
const Function * getParent() const
Return the enclosing method, or null if none.
Definition: BasicBlock.h:205
reverse_iterator rend()
Definition: BasicBlock.h:447
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition: BasicBlock.h:220
Conditional or Unconditional Branch instruction.
static BranchInst * Create(BasicBlock *IfTrue, BasicBlock::iterator InsertBefore)
bool isConditional() const
BasicBlock * getSuccessor(unsigned i) const
Value * getCondition() const
Allocate memory in an ever growing pool, as if by bump-pointer.
Definition: Allocator.h:66
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
Definition: InstrTypes.h:1461
void setCallingConv(CallingConv::ID CC)
Definition: InstrTypes.h:1771
bool arg_empty() const
Definition: InstrTypes.h:1651
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
Definition: InstrTypes.h:1709
bool doesNotAccessMemory(unsigned OpNo) const
Definition: InstrTypes.h:2043
bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Definition: InstrTypes.h:1720
Value * getArgOperand(unsigned i) const
Definition: InstrTypes.h:1654
void setArgOperand(unsigned i, Value *v)
Definition: InstrTypes.h:1659
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
Definition: InstrTypes.h:1645
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
Definition: InstrTypes.h:1685
unsigned arg_size() const
Definition: InstrTypes.h:1652
AttributeList getAttributes() const
Return the parameter attributes for this call.
Definition: InstrTypes.h:1786
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
Definition: InstrTypes.h:1838
bool isArgOperand(const Use *U) const
Definition: InstrTypes.h:1674
bool hasOperandBundles() const
Return true if this User has any operand bundles.
Definition: InstrTypes.h:2285
Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(CallGraph &CG, CallGraphSCC &SCC)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
This class represents a function call, abstracting a target machine's calling convention.
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr, BasicBlock::iterator InsertBefore)
@ ICMP_SLT
signed less than
Definition: InstrTypes.h:989
@ ICMP_EQ
equal
Definition: InstrTypes.h:981
@ ICMP_NE
not equal
Definition: InstrTypes.h:982
static Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
Definition: Constants.cpp:2072
static Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
Definition: Constants.cpp:2087
This is the shared class of boolean and integer constants.
Definition: Constants.h:80
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition: Constants.h:184
static ConstantInt * getTrue(LLVMContext &Context)
Definition: Constants.cpp:849
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition: Constants.h:205
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition: Constants.h:160
This is an important base class in LLVM.
Definition: Constant.h:41
static Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
Definition: Constants.cpp:370
This class represents an Operation in the Expression.
A parsed version of the target data layout string in and methods for querying it.
Definition: DataLayout.h:110
A debug info location.
Definition: DebugLoc.h:33
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition: DenseMap.h:202
iterator begin()
Definition: DenseMap.h:75
iterator end()
Definition: DenseMap.h:84
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition: DenseMap.h:220
Implements a dense probed hash-table based set.
Definition: DenseSet.h:271
Analysis pass which computes a DominatorTree.
Definition: Dominators.h:279
Concrete subclass of DominatorTreeBase that is used to compute a normal dominator tree.
Definition: Dominators.h:162
Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
Definition: Dominators.cpp:344
An instruction for ordering other memory operations.
Definition: Instructions.h:460
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
Definition: Instructions.h:487
A proxy from a FunctionAnalysisManager to an SCC.
A handy container for a FunctionType+Callee-pointer pair, which can be passed around as a single enti...
Definition: DerivedTypes.h:168
const BasicBlock & getEntryBlock() const
Definition: Function.h:782
const BasicBlock & front() const
Definition: Function.h:805
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition: Function.cpp:350
Argument * getArg(unsigned i) const
Definition: Function.h:831
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition: Function.cpp:677
bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition: Globals.cpp:274
bool hasLocalLinkage() const
Definition: GlobalValue.h:527
Module * getParent()
Get the module that this global value is contained inside of...
Definition: GlobalValue.h:655
PointerType * getType() const
Global values are always pointers.
Definition: GlobalValue.h:294
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition: GlobalValue.h:60
@ InternalLinkage
Rename collisions when linking (static functions).
Definition: GlobalValue.h:59
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition: Globals.cpp:459
void eraseFromParent()
eraseFromParent - This method unlinks 'this' from the containing module and deletes it.
Definition: Globals.cpp:455
InsertPoint - A saved insertion point.
Definition: IRBuilder.h:251
AllocaInst * CreateAlloca(Type *Ty, unsigned AddrSpace, Value *ArraySize=nullptr, const Twine &Name="")
Definition: IRBuilder.h:1767
void SetInsertPoint(BasicBlock *TheBB)
This specifies that created instructions should be appended to the end of the specified block.
Definition: IRBuilder.h:180
Value * CreateAddrSpaceCast(Value *V, Type *DestTy, const Twine &Name="")
Definition: IRBuilder.h:2110
This provides a uniform API for creating instructions and inserting them into a basic block: either a...
Definition: IRBuilder.h:2644
An analysis over an "outer" IR unit that provides access to an analysis manager over an "inner" IR un...
Definition: PassManager.h:658
bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
Definition: Instruction.h:454
const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
Definition: Instruction.cpp:80
const Instruction * getPrevNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the previous non-debug instruction in the same basic block as 'this',...
const BasicBlock * getParent() const
Definition: Instruction.h:152
InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
const Function * getFunction() const
Return the function this instruction belongs to.
Definition: Instruction.cpp:84
bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
const Instruction * getNextNonDebugInstruction(bool SkipPseudoOp=false) const
Return a pointer to the next non-debug instruction in the same basic block as 'this',...
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
Definition: Instruction.h:451
void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
void moveBefore(Instruction *MovePos)
Unlink this instruction from its current basic block and insert it into the basic block that MovePos ...
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
An instruction for reading from memory.
Definition: Instructions.h:184
A single uniqued string.
Definition: Metadata.h:720
StringRef getString() const
Definition: Metadata.cpp:610
void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
Root of the metadata hierarchy.
Definition: Metadata.h:62
A Module instance is used to store all the information related to an LLVM module.
Definition: Module.h:65
const std::string & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition: Module.h:291
MutableArrayRef - Represent a mutable reference to an array (0 or more elements consecutively in memo...
Definition: ArrayRef.h:307
A tuple of MDNodes.
Definition: Metadata.h:1729
iterator_range< op_iterator > operands()
Definition: Metadata.h:1825
An interface to create LLVM-IR for OpenMP directives.
Definition: OMPIRBuilder.h:451
static std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
void addAttributes(omp::RuntimeFunction FnID, Function &Fn)
Add attributes known for FnID to Fn.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
Definition: OMPIRBuilder.h:477
static std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
Definition: OpenMPOpt.cpp:5829
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Definition: OpenMPOpt.cpp:5716
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
Diagnostic information for missed-optimization remarks.
Diagnostic information for applied optimization remarks.
static PoisonValue * get(Type *T)
Static factory methods - Return a 'poison' object of the specified type.
Definition: Constants.cpp:1827
A set of analyses that are preserved following a run of a transformation pass.
Definition: Analysis.h:109
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition: Analysis.h:112
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition: Analysis.h:115
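none() and all() bracket what the two run() entry points above can report. A minimal sketch of the standard idiom, with Changed standing in for whatever the pass body computed:

#include "llvm/IR/PassManager.h"
using namespace llvm;

static PreservedAnalyses reportResult(bool Changed) {
  // Invalidate everything if the IR was mutated; otherwise keep all
  // cached analyses valid.
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}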
static ReturnInst * Create(LLVMContext &C, Value *retVal, BasicBlock::iterator InsertBefore)
A vector that has set insertion semantics.
Definition: SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition: SetVector.h:98
const value_type & back() const
Return the last element of the SetVector.
Definition: SetVector.h:149
iterator end()
Get an iterator to the end of the SetVector.
Definition: SetVector.h:113
size_type count(const key_type &key) const
Count the number of elements of a given key in the SetVector.
Definition: SetVector.h:264
iterator begin()
Get an iterator to the beginning of the SetVector.
Definition: SetVector.h:103
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition: SetVector.h:162
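The insert()/size()/indexing combination makes SetVector a deterministic worklist: iteration by index stays valid while inserts grow the vector, and duplicates are silently ignored. A sketch under that usage (collectReachable is a hypothetical helper):

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

// Transitive closure over direct callees, in insertion order.
static void collectReachable(Function &Root, SetVector<Function *> &Reached) {
  Reached.insert(&Root);
  for (unsigned I = 0; I < Reached.size(); ++I)   // Inserts may grow Reached.
    for (Instruction &Inst : instructions(*Reached[I]))
      if (auto *CB = dyn_cast<CallBase>(&Inst))
        if (Function *Callee = CB->getCalledFunction())
          Reached.insert(Callee);                 // No-op for duplicates.
}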
size_type size() const
Definition: SmallPtrSet.h:94
A templated base class for SmallPtrSet which provides the typesafe interface that is common across al...
Definition: SmallPtrSet.h:321
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
Definition: SmallPtrSet.h:360
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
Definition: SmallPtrSet.h:342
iterator begin() const
Definition: SmallPtrSet.h:380
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
Definition: SmallPtrSet.h:427
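insert() returning an {iterator, inserted} pair gives the usual visit-once guard. A minimal sketch:

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Returns true exactly once per instruction; no heap allocation while
// the set holds at most 8 pointers.
static bool markVisited(SmallPtrSet<const Instruction *, 8> &Visited,
                        const Instruction *I) {
  return Visited.insert(I).second;
}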
A SetVector that performs no allocations if smaller than a certain size.
Definition: SetVector.h:370
bool empty() const
Definition: SmallVector.h:94
size_t size() const
Definition: SmallVector.h:91
This class consists of common code factored out of the SmallVector class to reduce code duplication b...
Definition: SmallVector.h:586
void assign(size_type NumElts, ValueParamT Elt)
Definition: SmallVector.h:717
reference emplace_back(ArgTypes &&... Args)
Definition: SmallVector.h:950
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
Definition: SmallVector.h:696
void push_back(const T &Elt)
Definition: SmallVector.h:426
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
Definition: SmallVector.h:1209
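A short sketch of the small-size optimization these members rely on:

#include "llvm/ADT/SmallVector.h"
using namespace llvm;

static SmallVector<unsigned, 8> evenNumbers(unsigned Count) {
  SmallVector<unsigned, 8> V;
  for (unsigned I = 0; I != Count; ++I)
    V.push_back(2 * I); // In-object storage until size() exceeds 8.
  return V;
}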
An instruction for storing to memory.
Definition: Instructions.h:317
StringRef - Represent a constant reference to a string, i.e.
Definition: StringRef.h:50
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition: StringRef.h:257
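starts_with() gives a cheap prefix filter before any expensive analysis. A sketch; the predicate name is made up, but the prefixes are the real OpenMP runtime conventions (__kmpc_* for compiler-emitted entry points, omp_* for the user API):

#include "llvm/ADT/StringRef.h"
using namespace llvm;

static bool looksLikeOpenMPRuntimeName(StringRef Name) {
  return Name.starts_with("__kmpc_") || Name.starts_with("omp_");
}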
Triple - Helper class for working with autoconf configuration names.
Definition: Triple.h:44
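Module::getTargetTriple() feeds straight into Triple. A sketch of the kind of GPU-target check device-side logic relies on (the helper name is hypothetical):

#include "llvm/IR/Module.h"
#include "llvm/TargetParser/Triple.h"
using namespace llvm;

static bool isGPUTargetModule(const Module &M) {
  Triple T(M.getTargetTriple());
  return T.isAMDGPU() || T.isNVPTX();
}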
The instances of the Type class are immutable: once they are created, they are never changed.
Definition: Type.h:45
PointerType * getPointerTo(unsigned AddrSpace=0) const
Return a pointer to the current type.
unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static Type * getVoidTy(LLVMContext &C)
static IntegerType * getInt16Ty(LLVMContext &C)
static IntegerType * getInt8Ty(LLVMContext &C)
static IntegerType * getInt32Ty(LLVMContext &C)
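These factories are the only way to name primitive types; they are built from the context, never constructed directly. A sketch building a void(i32, i32) function type (the signature is illustrative, not a real runtime entry point):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
using namespace llvm;

static FunctionType *makeVoidI32I32(LLVMContext &Ctx) {
  Type *I32 = Type::getInt32Ty(Ctx);
  return FunctionType::get(Type::getVoidTy(Ctx), {I32, I32},
                           /*isVarArg=*/false);
}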
static UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
Definition: Constants.cpp:1808
A Use represents the edge between a Value definition and its users.
Definition: Use.h:43
bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition: User.cpp:21
LLVM Value Representation.
Definition: Value.h:74
Type * getType() const
All values are typed, get the type of this value.
Definition: Value.h:255
void setName(const Twine &Name)
Change the name of the value.
Definition: Value.cpp:377
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition: Value.h:434
void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition: Value.cpp:534
iterator_range< user_iterator > users()
Definition: Value.h:421
User * user_back()
Definition: Value.h:407
const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition: Value.cpp:693
unsigned getNumUses() const
This method computes the number of uses of this Value.
Definition: Value.cpp:255
StringRef getName() const
Return a constant reference to the value's name.
Definition: Value.cpp:309
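replaceAllUsesWith() is the canonical rewrite step when folding or deduplicating values. A minimal sketch, with the type check that RAUW requires:

#include "llvm/IR/Value.h"
using namespace llvm;

// Redirect every user of Old to New; valid only when the types match.
static bool tryReplace(Value &Old, Value &New) {
  if (Old.getType() != New.getType())
    return false;
  Old.replaceAllUsesWith(&New);
  return Old.getNumUses() == 0; // RAUW leaves Old without users.
}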
An efficient, type-erasing, non-owning reference to a callable.
self_iterator getIterator()
Definition: ilist_node.h:109
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition: ilist_node.h:316
A raw_ostream that writes to an std::string.
Definition: raw_ostream.h:660
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:260
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Definition: OpenMPOpt.cpp:267
bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is, a constant,...
Definition: Attributor.cpp:291
bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
Definition: Attributor.cpp:890
@ Interprocedural
Definition: Attributor.h:181
bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
Definition: Attributor.cpp:206
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
Definition: BitmaskEnum.h:181
@ C
The default llvm calling convention, compatible with C.
Definition: CallingConv.h:34
initializer< Ty > init(const Ty &Val)
Definition: CommandLine.h:450
std::optional< const char * > toString(const std::optional< DWARFFormValue > &V)
Take an optional DWARFFormValue and try to extract a string value from it.
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
Definition: OpenMPOpt.cpp:5944
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
Definition: OpenMPOpt.cpp:5936
InternalControlVar
IDs for all Internal Control Variables (ICVs).
Definition: OMPConstants.h:26
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
Definition: OMPConstants.h:45
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
Definition: OpenMPOpt.cpp:5904
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition: OpenMPOpt.h:21
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
Definition: OpenMPOpt.cpp:5900
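The module-level helpers compose into the guard a device pass typically runs first. A sketch (shouldRunDeviceOptimizations is hypothetical; the three queries are the real ones listed above):

#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;

static bool shouldRunDeviceOptimizations(Module &M) {
  // Cheap syntactic checks before any expensive analysis is built.
  return omp::containsOpenMP(M) && omp::isOpenMPDevice(M) &&
         !omp::getDeviceKernels(M).empty();
}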
DiagnosticInfoOptimizationBase::Argument NV
const_iterator begin(StringRef path, Style style=Style::native)
Get begin iterator over path.
Definition: Path.cpp:227
const_iterator end(StringRef path)
Get end iterator over path.
Definition: Path.cpp:236
This is an optimization pass for GlobalISel generic memory operations.
Definition: AddressRanges.h:18
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition: STLExtras.h:329
@ Offset
Definition: DWP.cpp:456
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition: STLExtras.h:1731
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition: STLExtras.h:1689
bool succ_empty(const Instruction *I)
Definition: CFG.h:255
std::string to_string(const T &Value)
Definition: ScopedPrinter.h:85
bool operator!=(uint64_t V1, const APInt &V2)
Definition: APInt.h:2043
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=6)
This method strips off any GEP address adjustments and pointer casts from the specified value,...
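These two utilities are the usual pairing when reasoning about pointers into globalized device memory: strip casts and constant GEPs, then look through to the allocation. A hedged sketch:

#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Value.h"
using namespace llvm;

// Resolve Ptr to an underlying object plus constant byte offset, if any.
static const Value *baseObjectAndOffset(Value *Ptr, const DataLayout &DL,
                                        int64_t &Offset) {
  Offset = 0;
  Value *Base = GetPointerBaseWithConstantOffset(Ptr, Offset, DL);
  return getUnderlyingObject(Base);
}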
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
@ CGSCC
Definition: Attributor.h:6428
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition: Debug.cpp:163
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition: STLExtras.h:1923
BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Simple enum used to report whether an update changed anything (CHANGED/UNCHANGED).
Definition: Attributor.h:483
Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
ConstantFoldInsertValueInstruction - Attempt to constant fold an insertvalue instruction with the spe...
Implement std::hash so that hash_code can be used in STL containers.
Definition: BitVector.h:858
#define N
An abstract attribute for getting assumption information.
Definition: Attributor.h:6163
An abstract state for querying live call edges.
Definition: Attributor.h:5487
virtual const SetVector< Function * > & getOptimisticEdges() const =0
Get the optimistic edges.
bool isExecutedByInitialThreadOnly(const Instruction &I) const
Check if an instruction is executed only by the initial thread.
Definition: Attributor.h:5647
static AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
Definition: OpenMPOpt.cpp:5632
virtual ExecutionDomainTy getFunctionExecutionDomain() const =0
virtual ExecutionDomainTy getExecutionDomain(const BasicBlock &) const =0
virtual bool isExecutedInAlignedRegion(Attributor &A, const Instruction &I) const =0
Check if the instruction I is executed in an aligned region, that is, the synchronizing effects befor...
virtual bool isNoOpFence(const FenceInst &FI) const =0
Helper function to determine if FI is a no-op given the information about its execution from ExecDoma...
static const char ID
Unique ID (due to the unique address)
Definition: Attributor.h:5678
An abstract interface for indirect call information interference.
Definition: Attributor.h:6355
An abstract interface for liveness abstract attribute.
Definition: Attributor.h:3974
An abstract interface for all memory location attributes (readnone/argmemonly/inaccessiblememonly/ina...
Definition: Attributor.h:4703
AccessKind
Simple enum to distinguish read/write/read-write accesses.
Definition: Attributor.h:4839
StateType::base_t MemoryLocationsKind
Definition: Attributor.h:4704
static bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
An abstract Attribute for determining the necessity of the convergent attribute.
Definition: Attributor.h:5723
An abstract attribute for getting all assumption underlying objects.
Definition: Attributor.h:6195
Base struct for all "concrete attribute" deductions.
Definition: Attributor.h:3282
virtual ChangeStatus manifest(Attributor &A)
Hook for the Attributor to trigger the manifestation of the information represented by the abstract a...
Definition: Attributor.h:3397
virtual void initialize(Attributor &A)
Initialize the state with the information in the Attributor A.
Definition: Attributor.h:3346
virtual const std::string getAsStr(Attributor *A) const =0
This function should return the "summarized" assumed state as string.
virtual ChangeStatus updateImpl(Attributor &A)=0
The actual update/transfer function which has to be implemented by the derived classes.
virtual void trackStatistics() const =0
Hook to enable custom statistic tracking, called after manifest that resulted in a change if statisti...
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Definition: Attributor.h:2602
virtual ChangeStatus indicatePessimisticFixpoint()=0
Indicate that the abstract state should converge to the pessimistic state.
virtual bool isAtFixpoint() const =0
Return if this abstract state is fixed, thus does not need to be updated if information changes as it...
virtual bool isValidState() const =0
Return if this abstract state is in a valid state.
virtual ChangeStatus indicateOptimisticFixpoint()=0
Indicate that the abstract state should converge to the optimistic state.
Wrapper for FunctionAnalysisManager.
Definition: Attributor.h:1121
Configuration for the Attributor.
Definition: Attributor.h:1413
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
Definition: Attributor.h:1444
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
Definition: Attributor.h:1460
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
Definition: Attributor.h:1430
const char * PassName
Definition: Attributor.h:1470
OptimizationRemarkGetter OREGetter
Definition: Attributor.h:1466
IPOAmendableCBTy IPOAmendableCB
Definition: Attributor.h:1473
bool IsModulePass
Is the user of the Attributor a module pass or not.
Definition: Attributor.h:1424
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
Definition: Attributor.h:1434
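A sketch of how these knobs are typically set before constructing the Attributor; the specific values are illustrative, not this pass's actual configuration:

#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;

static void tuneConfig(AttributorConfig &AC) {
  AC.IsModulePass = true;                    // Whole-module reasoning.
  AC.RewriteSignatures = false;              // Keep kernel signatures intact.
  AC.DefaultInitializeLiveInternals = false; // Seed AAs manually instead.
  AC.MaxFixpointIterations = 32;             // Bound the fixpoint loop.
}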
The fixpoint analysis framework that orchestrates the attribute deduction.
Definition: Attributor.h:1507
static bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Constant * >(const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2031
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
Definition: Attributor.h:2061
static bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
std::function< std::optional< Value * >(const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
Definition: Attributor.h:2015
Simple wrapper for a single bit (boolean) state.
Definition: Attributor.h:2886
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition: Attributor.h:580
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition: Attributor.h:648
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition: Attributor.h:630
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition: Attributor.h:604
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition: Attributor.h:616
@ IRP_ARGUMENT
An attribute for a function argument.
Definition: Attributor.h:594
@ IRP_RETURNED
An attribute for the function return value.
Definition: Attributor.h:590
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition: Attributor.h:593
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition: Attributor.h:591
@ IRP_FUNCTION
An attribute for a function (scope).
Definition: Attributor.h:592
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition: Attributor.h:588
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition: Attributor.h:595
@ IRP_INVALID
An invalid position.
Definition: Attributor.h:587
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition: Attributor.h:623
Kind getPositionKind() const
Return the associated position kind.
Definition: Attributor.h:876
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition: Attributor.h:643
Function * getAnchorScope() const
Return the Function surrounding the anchor value.
Definition: Attributor.h:752
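The factories map one-to-one onto the position kinds listed above. A brief sketch of creating a few of them (samplePositions is a hypothetical helper):

#include "llvm/Transforms/IPO/Attributor.h"
#include <cassert>
using namespace llvm;

static void samplePositions(Function &F, CallBase &CB) {
  const IRPosition FnPos = IRPosition::function(F);           // IRP_FUNCTION
  const IRPosition Ret = IRPosition::returned(F);             // IRP_RETURNED
  const IRPosition CSRet = IRPosition::callsite_returned(CB); // IRP_CALL_SITE_RETURNED
  assert(FnPos.getPositionKind() == IRPosition::IRP_FUNCTION);
  (void)FnPos; (void)Ret; (void)CSRet;
}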
Data structure to hold cached (LLVM-IR) information.
Definition: Attributor.h:1197
bool isValidState() const override
See AbstractState::isValidState() NOTE: For now we simply pretend that the worst possible state is in...
Definition: Attributor.h:2661
ChangeStatus indicatePessimisticFixpoint() override
See AbstractState::indicatePessimisticFixpoint(...)
Definition: Attributor.h:2673
bool operator==(const IntegerStateBase< base_t, BestState, WorstState > &R) const
Equality for IntegerStateBase.
Definition: Attributor.h:2686
Direction
An enum for the direction of the loop.
Definition: LoopInfo.h:220
This struct is a compact representation of a valid (power of two) or undefined (0) alignment.
Definition: Alignment.h:117
Description of an LLVM-IR insertion point (IP) and a debug/source location (filename,...
Definition: OMPIRBuilder.h:593
Helper to tie an abstract state implementation to an abstract attribute.
Definition: Attributor.h:3171
StateType & getState() override
See AbstractAttribute::getState(...).
Definition: Attributor.h:3179
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...
Definition: OMPGridValues.h:57