1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
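// As a small illustration of the first item (the IR names below are purely
// illustrative): two calls in one function,
//
//   %a = call i32 @omp_get_thread_num()
//   ...
//   %b = call i32 @omp_get_thread_num()
//
// can be deduplicated by replacing the second call with the value of the
// first, assuming nothing in between can change the result.
//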
18//===----------------------------------------------------------------------===//
19
21
22#include "llvm/ADT/DenseSet.h"
25#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/Statistic.h"
30#include "llvm/ADT/StringRef.h"
38#include "llvm/IR/Assumptions.h"
39#include "llvm/IR/BasicBlock.h"
40#include "llvm/IR/Constants.h"
42#include "llvm/IR/Dominators.h"
43#include "llvm/IR/Function.h"
44#include "llvm/IR/GlobalValue.h"
46#include "llvm/IR/InstrTypes.h"
47#include "llvm/IR/Instruction.h"
50#include "llvm/IR/IntrinsicsAMDGPU.h"
51#include "llvm/IR/IntrinsicsNVPTX.h"
52#include "llvm/IR/LLVMContext.h"
55#include "llvm/Support/Debug.h"
59
60#include <algorithm>
61#include <optional>
62#include <string>
63
64using namespace llvm;
65using namespace omp;
66
67#define DEBUG_TYPE "openmp-opt"
68
70 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
71 cl::Hidden, cl::init(false));
72
74 "openmp-opt-enable-merging",
75 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
76 cl::init(false));
77
78static cl::opt<bool>
79 DisableInternalization("openmp-opt-disable-internalization",
80 cl::desc("Disable function internalization."),
81 cl::Hidden, cl::init(false));
82
83static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
84 cl::init(false), cl::Hidden);
85static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
86 cl::Hidden);
87static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
88 cl::init(false), cl::Hidden);
89
91 "openmp-hide-memory-transfer-latency",
92 cl::desc("[WIP] Tries to hide the latency of host to device memory"
93 " transfers"),
94 cl::Hidden, cl::init(false));
95
97 "openmp-opt-disable-deglobalization",
98 cl::desc("Disable OpenMP optimizations involving deglobalization."),
99 cl::Hidden, cl::init(false));
100
102 "openmp-opt-disable-spmdization",
103 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
104 cl::Hidden, cl::init(false));
105
107 "openmp-opt-disable-folding",
108 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
109 cl::init(false));
110
112 "openmp-opt-disable-state-machine-rewrite",
113 cl::desc("Disable OpenMP optimizations that replace the state machine."),
114 cl::Hidden, cl::init(false));
115
117 "openmp-opt-disable-barrier-elimination",
118 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
119 cl::Hidden, cl::init(false));
120
122 "openmp-opt-print-module-after",
123 cl::desc("Print the current module after OpenMP optimizations."),
124 cl::Hidden, cl::init(false));
125
127 "openmp-opt-print-module-before",
128 cl::desc("Print the current module before OpenMP optimizations."),
129 cl::Hidden, cl::init(false));
130
132 "openmp-opt-inline-device",
133 cl::desc("Inline all applicable functions on the device."), cl::Hidden,
134 cl::init(false));
135
136static cl::opt<bool>
137 EnableVerboseRemarks("openmp-opt-verbose-remarks",
138 cl::desc("Enables more verbose remarks."), cl::Hidden,
139 cl::init(false));
140
141static cl::opt<unsigned>
142 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
143 cl::desc("Maximal number of attributor iterations."),
144 cl::init(256));
145
146static cl::opt<unsigned>
147 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
148 cl::desc("Maximum amount of shared memory to use."),
149 cl::init(std::numeric_limits<unsigned>::max()));
150
151STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
152 "Number of OpenMP runtime calls deduplicated");
153STATISTIC(NumOpenMPParallelRegionsDeleted,
154 "Number of OpenMP parallel regions deleted");
155STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
156 "Number of OpenMP runtime functions identified");
157STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
158 "Number of OpenMP runtime function uses identified");
159STATISTIC(NumOpenMPTargetRegionKernels,
160 "Number of OpenMP target region entry points (=kernels) identified");
161STATISTIC(NumNonOpenMPTargetRegionKernels,
162 "Number of non-OpenMP target region kernels identified");
163STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
164 "Number of OpenMP target region entry points (=kernels) executed in "
165 "SPMD-mode instead of generic-mode");
166STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
167 "Number of OpenMP target region entry points (=kernels) executed in "
168 "generic-mode without a state machines");
169STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
170 "Number of OpenMP target region entry points (=kernels) executed in "
171 "generic-mode with customized state machines with fallback");
172STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
173 "Number of OpenMP target region entry points (=kernels) executed in "
174 "generic-mode with customized state machines without fallback");
175STATISTIC(
176 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
177 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
178STATISTIC(NumOpenMPParallelRegionsMerged,
179 "Number of OpenMP parallel regions merged");
180STATISTIC(NumBytesMovedToSharedMemory,
181 "Amount of memory pushed to shared memory");
182STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
183
184#if !defined(NDEBUG)
185static constexpr auto TAG = "[" DEBUG_TYPE "]";
186#endif
187
188namespace KernelInfo {
189
190// struct ConfigurationEnvironmentTy {
191// uint8_t UseGenericStateMachine;
192// uint8_t MayUseNestedParallelism;
193// llvm::omp::OMPTgtExecModeFlags ExecMode;
194// int32_t MinThreads;
195// int32_t MaxThreads;
196// int32_t MinTeams;
197// int32_t MaxTeams;
198// };
199
200// struct DynamicEnvironmentTy {
201// uint16_t DebugIndentionLevel;
202// };
203
204// struct KernelEnvironmentTy {
205// ConfigurationEnvironmentTy Configuration;
206// IdentTy *Ident;
207// DynamicEnvironmentTy *DynamicEnv;
208// };
209
210#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \
211 constexpr unsigned MEMBER##Idx = IDX;
212
213KERNEL_ENVIRONMENT_IDX(Configuration, 0)
214KERNEL_ENVIRONMENT_IDX(Ident, 1)
215
216#undef KERNEL_ENVIRONMENT_IDX
217
218#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \
219 constexpr unsigned MEMBER##Idx = IDX;
220
221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)
222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)
223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)
224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)
225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)
226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)
227KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)
228
229#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX
230
231#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \
232 RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \
233 return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \
234 }
235
236KERNEL_ENVIRONMENT_GETTER(Ident, Constant)
237KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)
238
239#undef KERNEL_ENVIRONMENT_GETTER
240
241#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \
242 ConstantInt *get##MEMBER##FromKernelEnvironment( \
243 ConstantStruct *KernelEnvC) { \
244 ConstantStruct *ConfigC = \
245 getConfigurationFromKernelEnvironment(KernelEnvC); \
246 return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \
247 }
248
249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)
251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)
252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)
253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)
254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)
255KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)
256
257#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER
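// For reference, KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)
// above expands (modulo formatting) to:
//
//   ConstantInt *getUseGenericStateMachineFromKernelEnvironment(
//       ConstantStruct *KernelEnvC) {
//     ConstantStruct *ConfigC =
//         getConfigurationFromKernelEnvironment(KernelEnvC);
//     return dyn_cast<ConstantInt>(
//         ConfigC->getAggregateElement(UseGenericStateMachineIdx));
//   }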
258
259GlobalVariable *
260getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {
261 constexpr int InitKernelEnvironmentArgNo = 0;
262 return cast<GlobalVariable>(
263 KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)
264 ->stripPointerCasts());
265}
266
267ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {
268 GlobalVariable *KernelEnvGV =
269 getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
270 return cast<ConstantStruct>(KernelEnvGV->getInitializer());
271}
272} // namespace KernelInfo
273
274namespace {
275
276struct AAHeapToShared;
277
278struct AAICVTracker;
279
280/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
281/// Attributor runs.
282struct OMPInformationCache : public InformationCache {
283 OMPInformationCache(Module &M, AnalysisGetter &AG,
284 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
285 bool OpenMPPostLink)
286 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
287 OpenMPPostLink(OpenMPPostLink) {
288
289 OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);
290 const Triple T(OMPBuilder.M.getTargetTriple());
291 switch (T.getArch()) {
292 case Triple::amdgcn:
293 case Triple::nvptx:
294 case Triple::nvptx64:
295 assert(OMPBuilder.Config.IsTargetDevice &&
296 "OpenMP AMDGPU/NVPTX is only prepared to deal with device code.");
297 OMPBuilder.Config.IsGPU = true;
298 break;
299 default:
300 OMPBuilder.Config.IsGPU = false;
301 break;
302 }
303 OMPBuilder.initialize();
304 initializeRuntimeFunctions(M);
305 initializeInternalControlVars();
306 }
307
308 /// Generic information that describes an internal control variable.
309 struct InternalControlVarInfo {
310 /// The kind, as described by InternalControlVar enum.
311 InternalControlVar Kind;
312
313 /// The name of the ICV.
314 StringRef Name;
315
316 /// Environment variable associated with this ICV.
317 StringRef EnvVarName;
318
319 /// Initial value kind.
320 ICVInitValue InitKind;
321
322 /// Initial value.
323 ConstantInt *InitValue;
324
325 /// Setter RTL function associated with this ICV.
326 RuntimeFunction Setter;
327
328 /// Getter RTL function associated with this ICV.
329 RuntimeFunction Getter;
330
331 /// RTL Function corresponding to the override clause of this ICV
332 RuntimeFunction Clause;
333 };
334
335 /// Generic information that describes a runtime function
336 struct RuntimeFunctionInfo {
337
338 /// The kind, as described by the RuntimeFunction enum.
339 RuntimeFunction Kind;
340
341 /// The name of the function.
342 StringRef Name;
343
344 /// Flag to indicate a variadic function.
345 bool IsVarArg;
346
347 /// The return type of the function.
348 Type *ReturnType;
349
350 /// The argument types of the function.
351 SmallVector<Type *, 8> ArgumentTypes;
352
353 /// The declaration if available.
354 Function *Declaration = nullptr;
355
356 /// Uses of this runtime function per function containing the use.
357 using UseVector = SmallVector<Use *, 16>;
358
359 /// Clear UsesMap for runtime function.
360 void clearUsesMap() { UsesMap.clear(); }
361
362 /// Boolean conversion that is true if the runtime function was found.
363 operator bool() const { return Declaration; }
364
365 /// Return the vector of uses in function \p F.
366 UseVector &getOrCreateUseVector(Function *F) {
367 std::shared_ptr<UseVector> &UV = UsesMap[F];
368 if (!UV)
369 UV = std::make_shared<UseVector>();
370 return *UV;
371 }
372
373 /// Return the vector of uses in function \p F or `nullptr` if there are
374 /// none.
375 const UseVector *getUseVector(Function &F) const {
376 auto I = UsesMap.find(&F);
377 if (I != UsesMap.end())
378 return I->second.get();
379 return nullptr;
380 }
381
382 /// Return how many functions contain uses of this runtime function.
383 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
384
385 /// Return the number of arguments (or the minimal number for variadic
386 /// functions).
387 size_t getNumArgs() const { return ArgumentTypes.size(); }
388
389 /// Run the callback \p CB on each use and forget the use if the result is
390 /// true. The callback will be fed the function in which the use was
391 /// encountered as second argument.
392 void foreachUse(SmallVectorImpl<Function *> &SCC,
393 function_ref<bool(Use &, Function &)> CB) {
394 for (Function *F : SCC)
395 foreachUse(CB, F);
396 }
397
398 /// Run the callback \p CB on each use within the function \p F and forget
399 /// the use if the result is true.
400 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
401 SmallVector<unsigned, 8> ToBeDeleted;
402 ToBeDeleted.clear();
403
404 unsigned Idx = 0;
405 UseVector &UV = getOrCreateUseVector(F);
406
407 for (Use *U : UV) {
408 if (CB(*U, *F))
409 ToBeDeleted.push_back(Idx);
410 ++Idx;
411 }
412
413 // Remove the to-be-deleted indices in reverse order as prior
414 // modifications will not modify the smaller indices.
415 while (!ToBeDeleted.empty()) {
416 unsigned Idx = ToBeDeleted.pop_back_val();
417 UV[Idx] = UV.back();
418 UV.pop_back();
419 }
420 }
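 // A usage sketch (the callback and counter are illustrative, not taken from
 // this file): count callee uses in the SCC and keep every use in the map by
 // returning false.
 //
 //   unsigned NumCalleeUses = 0;
 //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
 //     if (auto *CI = dyn_cast<CallInst>(U.getUser()))
 //       if (CI->isCallee(&U))
 //         ++NumCalleeUses;
 //     return false; // Keep the use.
 //   });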
421
422 private:
423 /// Map from functions to all uses of this runtime function contained in
424 /// them.
425 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
426
427 public:
428 /// Iterators for the uses of this runtime function.
429 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
430 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
431 };
432
433 /// An OpenMP-IR-Builder instance
434 OpenMPIRBuilder OMPBuilder;
435
436 /// Map from runtime function kind to the runtime function description.
437 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
438 RuntimeFunction::OMPRTL___last>
439 RFIs;
440
441 /// Map from function declarations/definitions to their runtime enum type.
442 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
443
444 /// Map from ICV kind to the ICV description.
445 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
446 InternalControlVar::ICV___last>
447 ICVs;
448
449 /// Helper to initialize all internal control variable information for those
450 /// defined in OMPKinds.def.
451 void initializeInternalControlVars() {
452#define ICV_RT_SET(_Name, RTL) \
453 { \
454 auto &ICV = ICVs[_Name]; \
455 ICV.Setter = RTL; \
456 }
457#define ICV_RT_GET(Name, RTL) \
458 { \
459 auto &ICV = ICVs[Name]; \
460 ICV.Getter = RTL; \
461 }
462#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
463 { \
464 auto &ICV = ICVs[Enum]; \
465 ICV.Name = _Name; \
466 ICV.Kind = Enum; \
467 ICV.InitKind = Init; \
468 ICV.EnvVarName = _EnvVarName; \
469 switch (ICV.InitKind) { \
470 case ICV_IMPLEMENTATION_DEFINED: \
471 ICV.InitValue = nullptr; \
472 break; \
473 case ICV_ZERO: \
474 ICV.InitValue = ConstantInt::get( \
475 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
476 break; \
477 case ICV_FALSE: \
478 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
479 break; \
480 case ICV_LAST: \
481 break; \
482 } \
483 }
484#include "llvm/Frontend/OpenMP/OMPKinds.def"
485 }
486
487 /// Returns true if the function declaration \p F matches the runtime
488 /// function types, that is, return type \p RTFRetType, and argument types
489 /// \p RTFArgTypes.
490 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
491 SmallVector<Type *, 8> &RTFArgTypes) {
492 // TODO: We should output information to the user (under debug output
493 // and via remarks).
494
495 if (!F)
496 return false;
497 if (F->getReturnType() != RTFRetType)
498 return false;
499 if (F->arg_size() != RTFArgTypes.size())
500 return false;
501
502 auto *RTFTyIt = RTFArgTypes.begin();
503 for (Argument &Arg : F->args()) {
504 if (Arg.getType() != *RTFTyIt)
505 return false;
506
507 ++RTFTyIt;
508 }
509
510 return true;
511 }
512
513 // Helper to collect all uses of the declaration in the UsesMap.
514 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
515 unsigned NumUses = 0;
516 if (!RFI.Declaration)
517 return NumUses;
518 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
519
520 if (CollectStats) {
521 NumOpenMPRuntimeFunctionsIdentified += 1;
522 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
523 }
524
525 // TODO: We directly convert uses into proper calls and unknown uses.
526 for (Use &U : RFI.Declaration->uses()) {
527 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
528 if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {
529 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
530 ++NumUses;
531 }
532 } else {
533 RFI.getOrCreateUseVector(nullptr).push_back(&U);
534 ++NumUses;
535 }
536 }
537 return NumUses;
538 }
539
540 // Helper function to recollect uses of a runtime function.
541 void recollectUsesForFunction(RuntimeFunction RTF) {
542 auto &RFI = RFIs[RTF];
543 RFI.clearUsesMap();
544 collectUses(RFI, /*CollectStats*/ false);
545 }
546
547 // Helper function to recollect uses of all runtime functions.
548 void recollectUses() {
549 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
550 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
551 }
552
553 // Helper function to inherit the calling convention of the function callee.
554 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
555 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
556 CI->setCallingConv(Fn->getCallingConv());
557 }
558
559 // Helper function to determine if it's legal to create a call to the runtime
560 // functions.
561 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
562 // We can always emit calls if we haven't yet linked in the runtime.
563 if (!OpenMPPostLink)
564 return true;
565
566 // Once the runtime has already been linked in we cannot emit calls to
567 // any undefined functions.
568 for (RuntimeFunction Fn : Fns) {
569 RuntimeFunctionInfo &RFI = RFIs[Fn];
570
571 if (!RFI.Declaration || RFI.Declaration->isDeclaration())
572 return false;
573 }
574 return true;
575 }
576
577 /// Helper to initialize all runtime function information for those defined
578 /// in OpenMPKinds.def.
579 void initializeRuntimeFunctions(Module &M) {
580
581 // Helper macros for handling __VA_ARGS__ in OMP_RTL
582#define OMP_TYPE(VarName, ...) \
583 Type *VarName = OMPBuilder.VarName; \
584 (void)VarName;
585
586#define OMP_ARRAY_TYPE(VarName, ...) \
587 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
588 (void)VarName##Ty; \
589 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
590 (void)VarName##PtrTy;
591
592#define OMP_FUNCTION_TYPE(VarName, ...) \
593 FunctionType *VarName = OMPBuilder.VarName; \
594 (void)VarName; \
595 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
596 (void)VarName##Ptr;
597
598#define OMP_STRUCT_TYPE(VarName, ...) \
599 StructType *VarName = OMPBuilder.VarName; \
600 (void)VarName; \
601 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
602 (void)VarName##Ptr;
603
604#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
605 { \
606 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
607 Function *F = M.getFunction(_Name); \
608 RTLFunctions.insert(F); \
609 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
610 RuntimeFunctionIDMap[F] = _Enum; \
611 auto &RFI = RFIs[_Enum]; \
612 RFI.Kind = _Enum; \
613 RFI.Name = _Name; \
614 RFI.IsVarArg = _IsVarArg; \
615 RFI.ReturnType = OMPBuilder._ReturnType; \
616 RFI.ArgumentTypes = std::move(ArgsTypes); \
617 RFI.Declaration = F; \
618 unsigned NumUses = collectUses(RFI); \
619 (void)NumUses; \
620 LLVM_DEBUG({ \
621 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
622 << " found\n"; \
623 if (RFI.Declaration) \
624 dbgs() << TAG << "-> got " << NumUses << " uses in " \
625 << RFI.getNumFunctionsWithUses() \
626 << " different functions.\n"; \
627 }); \
628 } \
629 }
630#include "llvm/Frontend/OpenMP/OMPKinds.def"
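 // For example, assuming OMPKinds.def contains an entry along the lines of
 // OMP_RTL(__kmpc_barrier, false, Void, IdentPtr, Int32), the block above
 // records __kmpc_barrier with return type void and argument types
 // (ident_t *, i32), maps its declaration (if present in M) to the
 // corresponding RuntimeFunction enum value, and collects all of its uses.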
631
632 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
633 // functions, except if `optnone` is present.
634 if (isOpenMPDevice(M)) {
635 for (Function &F : M) {
636 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
637 if (F.hasFnAttribute(Attribute::NoInline) &&
638 F.getName().starts_with(Prefix) &&
639 !F.hasFnAttribute(Attribute::OptimizeNone))
640 F.removeFnAttr(Attribute::NoInline);
641 }
642 }
643
644 // TODO: We should attach the attributes defined in OMPKinds.def.
645 }
646
647 /// Collection of known OpenMP runtime functions.
648 DenseSet<const Function *> RTLFunctions;
649
650 /// Indicates if we have already linked in the OpenMP device library.
651 bool OpenMPPostLink = false;
652};
653
654template <typename Ty, bool InsertInvalidates = true>
655struct BooleanStateWithSetVector : public BooleanState {
656 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
657 bool insert(const Ty &Elem) {
658 if (InsertInvalidates)
659 BooleanState::indicatePessimisticFixpoint();
660 return Set.insert(Elem);
661 }
662
663 const Ty &operator[](int Idx) const { return Set[Idx]; }
664 bool operator==(const BooleanStateWithSetVector &RHS) const {
665 return BooleanState::operator==(RHS) && Set == RHS.Set;
666 }
667 bool operator!=(const BooleanStateWithSetVector &RHS) const {
668 return !(*this == RHS);
669 }
670
671 bool empty() const { return Set.empty(); }
672 size_t size() const { return Set.size(); }
673
674 /// "Clamp" this state with \p RHS.
675 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
676 BooleanState::operator^=(RHS);
677 Set.insert_range(RHS.Set);
678 return *this;
679 }
680
681private:
682 /// A set to keep track of elements.
683 SetVector<Ty> Set;
684
685public:
686 typename decltype(Set)::iterator begin() { return Set.begin(); }
687 typename decltype(Set)::iterator end() { return Set.end(); }
688 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
689 typename decltype(Set)::const_iterator end() const { return Set.end(); }
690};
691
692template <typename Ty, bool InsertInvalidates = true>
693using BooleanStateWithPtrSetVector =
694 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
695
696struct KernelInfoState : AbstractState {
697 /// Flag to track if we reached a fixpoint.
698 bool IsAtFixpoint = false;
699
700 /// The parallel regions (identified by the outlined parallel functions) that
701 /// can be reached from the associated function.
702 BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>
703 ReachedKnownParallelRegions;
704
705 /// State to track what parallel region we might reach.
706 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
707
708 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
709 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
710 /// false.
711 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
712
713 /// The __kmpc_target_init call in this kernel, if any. If we find more than
714 /// one we abort as the kernel is malformed.
715 CallBase *KernelInitCB = nullptr;
716
717 /// The constant kernel environment as taken from and passed to
718 /// __kmpc_target_init.
719 ConstantStruct *KernelEnvC = nullptr;
720
721 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
722 /// one we abort as the kernel is malformed.
723 CallBase *KernelDeinitCB = nullptr;
724
725 /// Flag to indicate if the associated function is a kernel entry.
726 bool IsKernelEntry = false;
727
728 /// State to track what kernel entries can reach the associated function.
729 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
730
731 /// State to indicate if we can track parallel level of the associated
732 /// function. We will give up tracking if we encounter an unknown caller or the
733 /// caller is __kmpc_parallel_60.
734 BooleanStateWithSetVector<uint8_t> ParallelLevels;
735
736 /// Flag that indicates if the kernel has nested parallelism.
737 bool NestedParallelism = false;
738
739 /// Abstract State interface
740 ///{
741
742 KernelInfoState() = default;
743 KernelInfoState(bool BestState) {
744 if (!BestState)
745 indicatePessimisticFixpoint();
746 }
747
748 /// See AbstractState::isValidState(...)
749 bool isValidState() const override { return true; }
750
751 /// See AbstractState::isAtFixpoint(...)
752 bool isAtFixpoint() const override { return IsAtFixpoint; }
753
754 /// See AbstractState::indicatePessimisticFixpoint(...)
755 ChangeStatus indicatePessimisticFixpoint() override {
756 IsAtFixpoint = true;
757 ParallelLevels.indicatePessimisticFixpoint();
758 ReachingKernelEntries.indicatePessimisticFixpoint();
759 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
760 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
761 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
762 NestedParallelism = true;
763 return ChangeStatus::CHANGED;
764 }
765
766 /// See AbstractState::indicateOptimisticFixpoint(...)
767 ChangeStatus indicateOptimisticFixpoint() override {
768 IsAtFixpoint = true;
769 ParallelLevels.indicateOptimisticFixpoint();
770 ReachingKernelEntries.indicateOptimisticFixpoint();
771 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
772 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
773 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
774 return ChangeStatus::UNCHANGED;
775 }
776
777 /// Return the assumed state
778 KernelInfoState &getAssumed() { return *this; }
779 const KernelInfoState &getAssumed() const { return *this; }
780
781 bool operator==(const KernelInfoState &RHS) const {
782 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
783 return false;
784 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
785 return false;
786 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
787 return false;
788 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
789 return false;
790 if (ParallelLevels != RHS.ParallelLevels)
791 return false;
792 if (NestedParallelism != RHS.NestedParallelism)
793 return false;
794 return true;
795 }
796
797 /// Returns true if this kernel contains any OpenMP parallel regions.
798 bool mayContainParallelRegion() {
799 return !ReachedKnownParallelRegions.empty() ||
800 !ReachedUnknownParallelRegions.empty();
801 }
802
803 /// Return empty set as the best state of potential values.
804 static KernelInfoState getBestState() { return KernelInfoState(true); }
805
806 static KernelInfoState getBestState(KernelInfoState &KIS) {
807 return getBestState();
808 }
809
810 /// Return full set as the worst state of potential values.
811 static KernelInfoState getWorstState() { return KernelInfoState(false); }
812
813 /// "Clamp" this state with \p KIS.
814 KernelInfoState operator^=(const KernelInfoState &KIS) {
815 // Do not merge two different _init and _deinit call sites.
816 if (KIS.KernelInitCB) {
817 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
818 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
819 "assumptions.");
820 KernelInitCB = KIS.KernelInitCB;
821 }
822 if (KIS.KernelDeinitCB) {
823 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
824 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
825 "assumptions.");
826 KernelDeinitCB = KIS.KernelDeinitCB;
827 }
828 if (KIS.KernelEnvC) {
829 if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)
830 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
831 "assumptions.");
832 KernelEnvC = KIS.KernelEnvC;
833 }
834 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
835 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
836 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
837 NestedParallelism |= KIS.NestedParallelism;
838 return *this;
839 }
840
841 KernelInfoState operator&=(const KernelInfoState &KIS) {
842 return (*this ^= KIS);
843 }
844
845 ///}
846};
847
848/// Used to map the values physically (in the IR) stored in an offload
849/// array, to a vector in memory.
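///
/// For example (illustrative IR, not taken from a test), given
///
///   %offload_baseptrs = alloca [2 x ptr]
///   store ptr %a, ptr %offload_baseptrs
///   %gep = getelementptr [2 x ptr], ptr %offload_baseptrs, i32 0, i32 1
///   store ptr %b, ptr %gep
///
/// a successful initialize() leaves StoredValues holding the underlying
/// objects of %a and %b, and LastAccesses holding the two stores.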
850struct OffloadArray {
851 /// Physical array (in the IR).
852 AllocaInst *Array = nullptr;
853 /// Mapped values.
854 SmallVector<Value *, 8> StoredValues;
855 /// Last stores made in the offload array.
856 SmallVector<StoreInst *, 8> LastAccesses;
857
858 OffloadArray() = default;
859
860 /// Initializes the OffloadArray with the values stored in \p Array before
861 /// instruction \p Before is reached. Returns false if the initialization
862 /// fails.
863 /// This MUST be used immediately after the construction of the object.
864 bool initialize(AllocaInst &Array, Instruction &Before) {
865 if (!getValues(Array, Before))
866 return false;
867
868 this->Array = &Array;
869 return true;
870 }
871
872 static const unsigned DeviceIDArgNum = 1;
873 static const unsigned BasePtrsArgNum = 3;
874 static const unsigned PtrsArgNum = 4;
875 static const unsigned SizesArgNum = 5;
876
877private:
878 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
879 /// \p Array, leaving StoredValues with the values stored before the
880 /// instruction \p Before is reached.
881 bool getValues(AllocaInst &Array, Instruction &Before) {
882 // Initialize containers.
883 const DataLayout &DL = Array.getDataLayout();
884 std::optional<TypeSize> ArraySize = Array.getAllocationSize(DL);
885 if (!ArraySize || !ArraySize->isFixed())
886 return false;
887 const unsigned int PointerSize = DL.getPointerSize();
888 const uint64_t NumValues = ArraySize->getFixedValue() / PointerSize;
889 StoredValues.assign(NumValues, nullptr);
890 LastAccesses.assign(NumValues, nullptr);
891
892 // TODO: This assumes the instruction \p Before is in the same
893 // BasicBlock as Array. Make it general, for any control flow graph.
894 BasicBlock *BB = Array.getParent();
895 if (BB != Before.getParent())
896 return false;
897
898 for (Instruction &I : *BB) {
899 if (&I == &Before)
900 break;
901
902 if (!isa<StoreInst>(&I))
903 continue;
904
905 auto *S = cast<StoreInst>(&I);
906 int64_t Offset = -1;
907 auto *Dst =
908 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
909 if (Dst == &Array) {
910 int64_t Idx = Offset / PointerSize;
911 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
912 LastAccesses[Idx] = S;
913 }
914 }
915
916 return isFilled();
917 }
918
919 /// Returns true if all values in StoredValues and
920 /// LastAccesses are not nullptrs.
921 bool isFilled() {
922 const unsigned NumValues = StoredValues.size();
923 for (unsigned I = 0; I < NumValues; ++I) {
924 if (!StoredValues[I] || !LastAccesses[I])
925 return false;
926 }
927
928 return true;
929 }
930};
931
932struct OpenMPOpt {
933
934 using OptimizationRemarkGetter =
935 function_ref<OptimizationRemarkEmitter &(Function *)>;
936
937 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
938 OptimizationRemarkGetter OREGetter,
939 OMPInformationCache &OMPInfoCache, Attributor &A)
940 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
941 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
942
943 /// Check if any remarks are enabled for openmp-opt
944 bool remarksEnabled() {
945 auto &Ctx = M.getContext();
946 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
947 }
948
949 /// Run all OpenMP optimizations on the underlying SCC.
950 bool run(bool IsModulePass) {
951 if (SCC.empty())
952 return false;
953
954 bool Changed = false;
955
956 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
957 << " functions\n");
958
959 if (IsModulePass) {
960 Changed |= runAttributor(IsModulePass);
961
962 // Recollect uses, in case Attributor deleted any.
963 OMPInfoCache.recollectUses();
964
965 // TODO: This should be folded into buildCustomStateMachine.
966 Changed |= rewriteDeviceCodeStateMachine();
967
968 if (remarksEnabled())
969 analysisGlobalization();
970 } else {
971 if (PrintICVValues)
972 printICVs();
973 if (PrintOpenMPKernels)
974 printKernels();
975
976 Changed |= runAttributor(IsModulePass);
977
978 // Recollect uses, in case Attributor deleted any.
979 OMPInfoCache.recollectUses();
980
981 Changed |= deleteParallelRegions();
982
982
983 if (HideMemoryTransferLatency)
984 Changed |= hideMemTransfersLatency();
985 Changed |= deduplicateRuntimeCalls();
986 if (EnableParallelRegionMerging) {
987 if (mergeParallelRegions()) {
988 deduplicateRuntimeCalls();
989 Changed = true;
990 }
991 }
992 }
993
994 if (OMPInfoCache.OpenMPPostLink)
995 Changed |= removeRuntimeSymbols();
996
997 return Changed;
998 }
999
1000 /// Print initial ICV values for testing.
1001 /// FIXME: This should be done from the Attributor once it is added.
1002 void printICVs() const {
1003 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
1004 ICV_proc_bind};
1005
1006 for (Function *F : SCC) {
1007 for (auto ICV : ICVs) {
1008 auto ICVInfo = OMPInfoCache.ICVs[ICV];
1009 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1010 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
1011 << " Value: "
1012 << (ICVInfo.InitValue
1013 ? toString(ICVInfo.InitValue->getValue(), 10, true)
1014 : "IMPLEMENTATION_DEFINED");
1015 };
1016
1017 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
1018 }
1019 }
1020 }
1021
1022 /// Print OpenMP GPU kernels for testing.
1023 void printKernels() const {
1024 for (Function *F : SCC) {
1025 if (!omp::isOpenMPKernel(*F))
1026 continue;
1027
1028 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1029 return ORA << "OpenMP GPU kernel "
1030 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
1031 };
1032
1033 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPUKernel", Remark);
1034 }
1035 }
1036
1037 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
1038 /// given it has to be the callee or a nullptr is returned.
1039 static CallInst *getCallIfRegularCall(
1040 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1041 CallInst *CI = dyn_cast<CallInst>(U.getUser());
1042 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
1043 (!RFI ||
1044 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1045 return CI;
1046 return nullptr;
1047 }
1048
1049 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
1050 /// the callee or a nullptr is returned.
1051 static CallInst *getCallIfRegularCall(
1052 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
1053 CallInst *CI = dyn_cast<CallInst>(&V);
1054 if (CI && !CI->hasOperandBundles() &&
1055 (!RFI ||
1056 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
1057 return CI;
1058 return nullptr;
1059 }
1060
1061private:
1062 /// Merge parallel regions when it is safe.
1063 bool mergeParallelRegions() {
1064 const unsigned CallbackCalleeOperand = 2;
1065 const unsigned CallbackFirstArgOperand = 3;
1066 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
1067
1068 // Check if there are any __kmpc_fork_call calls to merge.
1069 OMPInformationCache::RuntimeFunctionInfo &RFI =
1070 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1071
1072 if (!RFI.Declaration)
1073 return false;
1074
1075 // Unmergable calls that prevent merging a parallel region.
1076 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
1077 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
1078 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
1079 };
1080
1081 bool Changed = false;
1082 LoopInfo *LI = nullptr;
1083 DominatorTree *DT = nullptr;
1084
1085 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
1086
1087 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
1088 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1089 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1090 BasicBlock *CGEndBB =
1091 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1092 assert(StartBB != nullptr && "StartBB should not be null");
1093 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
1094 assert(EndBB != nullptr && "EndBB should not be null");
1095 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
1096 return Error::success();
1097 };
1098
1099 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
1100 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
1101 ReplacementValue = &Inner;
1102 return CodeGenIP;
1103 };
1104
1105 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1106
1107 /// Create a sequential execution region within a merged parallel region,
1108 /// encapsulated in a master construct with a barrier for synchronization.
1109 auto CreateSequentialRegion = [&](Function *OuterFn,
1110 BasicBlock *OuterPredBB,
1111 Instruction *SeqStartI,
1112 Instruction *SeqEndI) {
1113 // Isolate the instructions of the sequential region to a separate
1114 // block.
1115 BasicBlock *ParentBB = SeqStartI->getParent();
1116 BasicBlock *SeqEndBB =
1117 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
1118 BasicBlock *SeqAfterBB =
1119 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1120 BasicBlock *SeqStartBB =
1121 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1122
1123 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1124 "Expected a different CFG");
1125 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1126 ParentBB->getTerminator()->eraseFromParent();
1127
1128 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1129 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1130 BasicBlock *CGEndBB =
1131 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1132 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1133 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1134 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1135 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1136 return Error::success();
1137 };
1138 auto FiniCB = [&](InsertPointTy CodeGenIP) { return Error::success(); };
1139
1140 // Find outputs from the sequential region to outside users and
1141 // broadcast their values to them.
1142 for (Instruction &I : *SeqStartBB) {
1143 SmallPtrSet<Instruction *, 4> OutsideUsers;
1144 for (User *Usr : I.users()) {
1145 Instruction &UsrI = *cast<Instruction>(Usr);
1146 // Ignore outputs to LT intrinsics, code extraction for the merged
1147 // parallel region will fix them.
1148 if (UsrI.isLifetimeStartOrEnd())
1149 continue;
1150
1151 if (UsrI.getParent() != SeqStartBB)
1152 OutsideUsers.insert(&UsrI);
1153 }
1154
1155 if (OutsideUsers.empty())
1156 continue;
1157
1158 // Emit an alloca in the outer region to store the broadcasted
1159 // value.
1160 const DataLayout &DL = M.getDataLayout();
1161 AllocaInst *AllocaI = new AllocaInst(
1162 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1163 I.getName() + ".seq.output.alloc", OuterFn->front().begin());
1164
1165 // Emit a store instruction in the sequential BB to update the
1166 // value.
1167 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());
1168
1169 // Emit a load instruction and replace the use of the output value
1170 // with it.
1171 for (Instruction *UsrI : OutsideUsers) {
1172 LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,
1173 I.getName() + ".seq.output.load",
1174 UsrI->getIterator());
1175 UsrI->replaceUsesOfWith(&I, LoadI);
1176 }
1177 }
1178
1179 OpenMPIRBuilder::LocationDescription Loc(
1180 InsertPointTy(ParentBB, ParentBB->end()), DL);
1181 OpenMPIRBuilder::InsertPointTy SeqAfterIP = cantFail(
1182 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB));
1183 cantFail(
1184 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel));
1185
1186 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1187
1188 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1189 << "\n");
1190 };
1191
1192 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1193 // contained in BB and only separated by instructions that can be
1194 // redundantly executed in parallel. The block BB is split before the first
1195 // call (in MergableCIs) and after the last so the entire region we merge
1196 // into a single parallel region is contained in a single basic block
1197 // without any other instructions. We use the OpenMPIRBuilder to outline
1198 // that block and call the resulting function via __kmpc_fork_call.
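    // Conceptually, at the OpenMP source level (illustrative only):
    //
    //   #pragma omp parallel
    //   { A(); }
    //   sequential_code();
    //   #pragma omp parallel
    //   { B(); }
    //
    // becomes one parallel region in which sequential_code() is guarded by a
    // master construct plus a barrier (see CreateSequentialRegion above) and
    // A() and B() are separated by an explicit barrier.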
1199 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1200 BasicBlock *BB) {
1201 // TODO: Change the interface to allow single CIs to be expanded, e.g., to
1202 // include an outer loop.
1203 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1204
1205 auto Remark = [&](OptimizationRemark OR) {
1206 OR << "Parallel region merged with parallel region"
1207 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1208 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1209 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1210 if (CI != MergableCIs.back())
1211 OR << ", ";
1212 }
1213 return OR << ".";
1214 };
1215
1216 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1217
1218 Function *OriginalFn = BB->getParent();
1219 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1220 << " parallel regions in " << OriginalFn->getName()
1221 << "\n");
1222
1223 // Isolate the calls to merge in a separate block.
1224 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1225 BasicBlock *AfterBB =
1226 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1227 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1228 "omp.par.merged");
1229
1230 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1231 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1232 BB->getTerminator()->eraseFromParent();
1233
1234 // Create sequential regions for sequential instructions that are
1235 // in-between mergable parallel regions.
1236 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1237 It != End; ++It) {
1238 Instruction *ForkCI = *It;
1239 Instruction *NextForkCI = *(It + 1);
1240
1241 // Continue if there are not in-between instructions.
1242 if (ForkCI->getNextNode() == NextForkCI)
1243 continue;
1244
1245 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1246 NextForkCI->getPrevNode());
1247 }
1248
1249 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1250 DL);
1251 IRBuilder<>::InsertPoint AllocaIP(
1252 &OriginalFn->getEntryBlock(),
1253 OriginalFn->getEntryBlock().getFirstInsertionPt());
1254 // Create the merged parallel region with default proc binding, to
1255 // avoid overriding binding settings, and without explicit cancellation.
1256 OpenMPIRBuilder::InsertPointTy AfterIP =
1257 cantFail(OMPInfoCache.OMPBuilder.createParallel(
1258 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1259 OMP_PROC_BIND_default, /* IsCancellable */ false));
1260 BranchInst::Create(AfterBB, AfterIP.getBlock());
1261
1262 // Perform the actual outlining.
1263 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1264
1265 Function *OutlinedFn = MergableCIs.front()->getCaller();
1266
1267 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1268 // callbacks.
1269 SmallVector<Value *, 8> Args;
1270 for (auto *CI : MergableCIs) {
1271 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1272 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1273 Args.clear();
1274 Args.push_back(OutlinedFn->getArg(0));
1275 Args.push_back(OutlinedFn->getArg(1));
1276 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1277 ++U)
1278 Args.push_back(CI->getArgOperand(U));
1279
1280 CallInst *NewCI =
1281 CallInst::Create(FT, Callee, Args, "", CI->getIterator());
1282 if (CI->getDebugLoc())
1283 NewCI->setDebugLoc(CI->getDebugLoc());
1284
1285 // Forward parameter attributes from the callback to the callee.
1286 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1287 ++U)
1288 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1289 NewCI->addParamAttr(
1290 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1291
1292 // Emit an explicit barrier to replace the implicit fork-join barrier.
1293 if (CI != MergableCIs.back()) {
1294 // TODO: Remove barrier if the merged parallel region includes the
1295 // 'nowait' clause.
1296 cantFail(OMPInfoCache.OMPBuilder.createBarrier(
1297 InsertPointTy(NewCI->getParent(),
1298 NewCI->getNextNode()->getIterator()),
1299 OMPD_parallel));
1300 }
1301
1302 CI->eraseFromParent();
1303 }
1304
1305 assert(OutlinedFn != OriginalFn && "Outlining failed");
1306 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1307 CGUpdater.reanalyzeFunction(*OriginalFn);
1308
1309 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1310
1311 return true;
1312 };
1313
1314 // Helper function that identifies sequences of
1315 // __kmpc_fork_call uses in a basic block.
1316 auto DetectPRsCB = [&](Use &U, Function &F) {
1317 CallInst *CI = getCallIfRegularCall(U, &RFI);
1318 BB2PRMap[CI->getParent()].insert(CI);
1319
1320 return false;
1321 };
1322
1323 BB2PRMap.clear();
1324 RFI.foreachUse(SCC, DetectPRsCB);
1325 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1326 // Find mergable parallel regions within a basic block that are
1327 // safe to merge, that is any in-between instructions can safely
1328 // execute in parallel after merging.
1329 // TODO: support merging across basic-blocks.
1330 for (auto &It : BB2PRMap) {
1331 auto &CIs = It.getSecond();
1332 if (CIs.size() < 2)
1333 continue;
1334
1335 BasicBlock *BB = It.getFirst();
1336 SmallVector<CallInst *, 4> MergableCIs;
1337
1338 /// Returns true if the instruction is mergable, false otherwise.
1339 /// A terminator instruction is unmergable by definition since merging
1340 /// works within a BB. Instructions before the mergable region are
1341 /// mergable if they are not calls to OpenMP runtime functions that may
1342 /// set different execution parameters for subsequent parallel regions.
1343 /// Instructions in-between parallel regions are mergable if they are not
1344 /// calls to any non-intrinsic function since that may call a non-mergable
1345 /// OpenMP runtime function.
1346 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1347 // We do not merge across BBs, hence return false (unmergable) if the
1348 // instruction is a terminator.
1349 if (I.isTerminator())
1350 return false;
1351
1352 if (!isa<CallInst>(&I))
1353 return true;
1354
1355 CallInst *CI = cast<CallInst>(&I);
1356 if (IsBeforeMergableRegion) {
1357 Function *CalledFunction = CI->getCalledFunction();
1358 if (!CalledFunction)
1359 return false;
1360 // Return false (unmergable) if the call before the parallel
1361 // region calls an explicit affinity (proc_bind) or number of
1362 // threads (num_threads) compiler-generated function. Those settings
1363 // may be incompatible with following parallel regions.
1364 // TODO: ICV tracking to detect compatibility.
1365 for (const auto &RFI : UnmergableCallsInfo) {
1366 if (CalledFunction == RFI.Declaration)
1367 return false;
1368 }
1369 } else {
1370 // Return false (unmergable) if there is a call instruction
1371 // in-between parallel regions when it is not an intrinsic. It
1372 // may call an unmergable OpenMP runtime function in its callpath.
1373 // TODO: Keep track of possible OpenMP calls in the callpath.
1374 if (!isa<IntrinsicInst>(CI))
1375 return false;
1376 }
1377
1378 return true;
1379 };
1380 // Find maximal number of parallel region CIs that are safe to merge.
1381 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1382 Instruction &I = *It;
1383 ++It;
1384
1385 if (CIs.count(&I)) {
1386 MergableCIs.push_back(cast<CallInst>(&I));
1387 continue;
1388 }
1389
1390 // Continue expanding if the instruction is mergable.
1391 if (IsMergable(I, MergableCIs.empty()))
1392 continue;
1393
1394 // Forward the instruction iterator to skip the next parallel region
1395 // since there is an unmergable instruction which can affect it.
1396 for (; It != End; ++It) {
1397 Instruction &SkipI = *It;
1398 if (CIs.count(&SkipI)) {
1399 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1400 << " due to " << I << "\n");
1401 ++It;
1402 break;
1403 }
1404 }
1405
1406 // Store mergable regions found.
1407 if (MergableCIs.size() > 1) {
1408 MergableCIsVector.push_back(MergableCIs);
1409 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1410 << " parallel regions in block " << BB->getName()
1411 << " of function " << BB->getParent()->getName()
1412 << "\n";);
1413 }
1414
1415 MergableCIs.clear();
1416 }
1417
1418 if (!MergableCIsVector.empty()) {
1419 Changed = true;
1420
1421 for (auto &MergableCIs : MergableCIsVector)
1422 Merge(MergableCIs, BB);
1423 MergableCIsVector.clear();
1424 }
1425 }
1426
1427 if (Changed) {
1428 /// Re-collect use for fork calls, emitted barrier calls, and
1429 /// any emitted master/end_master calls.
1430 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1431 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1432 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1433 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1434 }
1435
1436 return Changed;
1437 }
1438
1439 /// Try to delete parallel regions if possible.
1440 bool deleteParallelRegions() {
1441 const unsigned CallbackCalleeOperand = 2;
1442
1443 OMPInformationCache::RuntimeFunctionInfo &RFI =
1444 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1445
1446 if (!RFI.Declaration)
1447 return false;
1448
1449 bool Changed = false;
1450 auto DeleteCallCB = [&](Use &U, Function &) {
1451 CallInst *CI = getCallIfRegularCall(U);
1452 if (!CI)
1453 return false;
1454 auto *Fn = dyn_cast<Function>(
1455 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1456 if (!Fn)
1457 return false;
1458 if (!Fn->onlyReadsMemory())
1459 return false;
1460 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1461 return false;
1462
1463 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1464 << CI->getCaller()->getName() << "\n");
1465
1466 auto Remark = [&](OptimizationRemark OR) {
1467 return OR << "Removing parallel region with no side-effects.";
1468 };
1469 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1470
1471 CI->eraseFromParent();
1472 Changed = true;
1473 ++NumOpenMPParallelRegionsDeleted;
1474 return true;
1475 };
1476
1477 RFI.foreachUse(SCC, DeleteCallCB);
1478
1479 return Changed;
1480 }
1481
1482 /// Try to eliminate runtime calls by reusing existing ones.
1483 bool deduplicateRuntimeCalls() {
1484 bool Changed = false;
1485
1486 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1487 OMPRTL_omp_get_num_threads,
1488 OMPRTL_omp_in_parallel,
1489 OMPRTL_omp_get_cancellation,
1490 OMPRTL_omp_get_supported_active_levels,
1491 OMPRTL_omp_get_level,
1492 OMPRTL_omp_get_ancestor_thread_num,
1493 OMPRTL_omp_get_team_size,
1494 OMPRTL_omp_get_active_level,
1495 OMPRTL_omp_in_final,
1496 OMPRTL_omp_get_proc_bind,
1497 OMPRTL_omp_get_num_places,
1498 OMPRTL_omp_get_num_procs,
1499 OMPRTL_omp_get_place_num,
1500 OMPRTL_omp_get_partition_num_places,
1501 OMPRTL_omp_get_partition_place_nums};
1502
1503 // Global-tid is handled separately.
1504 SmallSetVector<Value *, 16> GTIdArgs;
1505 collectGlobalThreadIdArguments(GTIdArgs);
1506 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1507 << " global thread ID arguments\n");
1508
1509 for (Function *F : SCC) {
1510 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1511 Changed |= deduplicateRuntimeCalls(
1512 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1513
1514 // __kmpc_global_thread_num is special as we can replace it with an
1515 // argument in enough cases to make it worth trying.
1516 Value *GTIdArg = nullptr;
1517 for (Argument &Arg : F->args())
1518 if (GTIdArgs.count(&Arg)) {
1519 GTIdArg = &Arg;
1520 break;
1521 }
1522 Changed |= deduplicateRuntimeCalls(
1523 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1524 }
1525
1526 return Changed;
1527 }
1528
1529 /// Tries to remove known runtime symbols that are optional from the module.
1530 bool removeRuntimeSymbols() {
1531 // The RPC client symbol is defined in `libc` and indicates that something
1532 // required an RPC server. If its users were all optimized out then we can
1533 // safely remove it.
1534 // TODO: This should be somewhere more common in the future.
1535 if (GlobalVariable *GV = M.getNamedGlobal("__llvm_rpc_client")) {
1536 if (GV->hasNUsesOrMore(1))
1537 return false;
1538
1539 GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));
1540 GV->eraseFromParent();
1541 return true;
1542 }
1543 return false;
1544 }
1545
1546 /// Tries to hide the latency of runtime calls that involve host to
1547 /// device memory transfers by splitting them into their "issue" and "wait"
1548 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1549 /// moved downwards as much as possible. The "issue" issues the memory transfer
1550 /// asynchronously, returning a handle. The "wait" waits on the returned
1551 /// handle for the memory transfer to finish.
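 ///
 /// A before/after sketch (IR names illustrative):
 ///
 ///   call void @__tgt_target_data_begin_mapper(<args>)
 ///   ; ... code that does not touch the mapped memory ...
 ///
 /// is rewritten (see splitTargetDataBeginRTC) into
 ///
 ///   %handle = alloca %struct.__tgt_async_info
 ///   call void @__tgt_target_data_begin_mapper_issue(<args>, %handle)
 ///   ; ... code that does not touch the mapped memory ...
 ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)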
1552 bool hideMemTransfersLatency() {
1553 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1554 bool Changed = false;
1555 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1556 auto *RTCall = getCallIfRegularCall(U, &RFI);
1557 if (!RTCall)
1558 return false;
1559
1560 OffloadArray OffloadArrays[3];
1561 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1562 return false;
1563
1564 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1565
1566 // TODO: Check if can be moved upwards.
1567 bool WasSplit = false;
1568 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1569 if (WaitMovementPoint)
1570 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1571
1572 Changed |= WasSplit;
1573 return WasSplit;
1574 };
1575 if (OMPInfoCache.runtimeFnsAvailable(
1576 {OMPRTL___tgt_target_data_begin_mapper_issue,
1577 OMPRTL___tgt_target_data_begin_mapper_wait}))
1578 RFI.foreachUse(SCC, SplitMemTransfers);
1579
1580 return Changed;
1581 }
1582
1583 void analysisGlobalization() {
1584 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1585
1586 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1587 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1588 auto Remark = [&](OptimizationRemarkMissed ORM) {
1589 return ORM
1590 << "Found thread data sharing on the GPU. "
1591 << "Expect degraded performance due to data globalization.";
1592 };
1593 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1594 }
1595
1596 return false;
1597 };
1598
1599 RFI.foreachUse(SCC, CheckGlobalization);
1600 }
1601
1602 /// Maps the values stored in the offload arrays passed as arguments to
1603 /// \p RuntimeCall into the offload arrays in \p OAs.
1604 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1605 MutableArrayRef<OffloadArray> OAs) {
1606 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1607
1608 // A runtime call that involves memory offloading looks something like:
1609 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1610 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1611 // ...)
1612 // So, the idea is to access the allocas that allocate space for these
1613 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1614 // Therefore:
1615 // i8** %offload_baseptrs.
1616 Value *BasePtrsArg =
1617 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1618 // i8** %offload_ptrs.
1619 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1620 // i8** %offload_sizes.
1621 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1622
1623 // Get values stored in **offload_baseptrs.
1624 auto *V = getUnderlyingObject(BasePtrsArg);
1625 if (!isa<AllocaInst>(V))
1626 return false;
1627 auto *BasePtrsArray = cast<AllocaInst>(V);
1628 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1629 return false;
1630
1631 // Get values stored in **offload_ptrs.
1632 V = getUnderlyingObject(PtrsArg);
1633 if (!isa<AllocaInst>(V))
1634 return false;
1635 auto *PtrsArray = cast<AllocaInst>(V);
1636 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1637 return false;
1638
1639 // Get values stored in **offload_sizes.
1640 V = getUnderlyingObject(SizesArg);
1641 // If it's a [constant] global array don't analyze it.
1642 if (isa<GlobalValue>(V))
1643 return isa<Constant>(V);
1644 if (!isa<AllocaInst>(V))
1645 return false;
1646
1647 auto *SizesArray = cast<AllocaInst>(V);
1648 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1649 return false;
1650
1651 return true;
1652 }
1653
1654 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1655 /// For now this is a way to test that the function getValuesInOffloadArrays
1656 /// is working properly.
1657 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1658 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1659 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1660
1661 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1662 std::string ValuesStr;
1663 raw_string_ostream Printer(ValuesStr);
1664 std::string Separator = " --- ";
1665
1666 for (auto *BP : OAs[0].StoredValues) {
1667 BP->print(Printer);
1668 Printer << Separator;
1669 }
1670 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");
1671 ValuesStr.clear();
1672
1673 for (auto *P : OAs[1].StoredValues) {
1674 P->print(Printer);
1675 Printer << Separator;
1676 }
1677 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");
1678 ValuesStr.clear();
1679
1680 for (auto *S : OAs[2].StoredValues) {
1681 S->print(Printer);
1682 Printer << Separator;
1683 }
1684 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");
1685 }
1686
1687 /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
1688 /// can be moved. Returns nullptr if the movement is not possible, or not worth it.
1689 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1690 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1691 // Make it traverse the CFG.
1692
1693 Instruction *CurrentI = &RuntimeCall;
1694 bool IsWorthIt = false;
1695 while ((CurrentI = CurrentI->getNextNode())) {
1696
1697 // TODO: Once we detect the regions to be offloaded we should use the
1698 // alias analysis manager to check if CurrentI may modify one of
1699 // the offloaded regions.
1700 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1701 if (IsWorthIt)
1702 return CurrentI;
1703
1704 return nullptr;
1705 }
1706
1707 // FIXME: For now, moving the call over anything without side effects
1708 // is considered worth it.
1709 IsWorthIt = true;
1710 }
1711
1712 // Return end of BasicBlock.
1713 return RuntimeCall.getParent()->getTerminator();
1714 }
1715
1716 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1717 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1718 Instruction &WaitMovementPoint) {
1719 // Create a stack-allocated handle (__tgt_async_info) at the beginning of the
1720 // function. It stores information about the async transfer, allowing us to
1721 // wait on it later.
1722 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1723 Function *F = RuntimeCall.getCaller();
1724 BasicBlock &Entry = F->getEntryBlock();
1725 IRBuilder.Builder.SetInsertPoint(&Entry,
1726 Entry.getFirstNonPHIOrDbgOrAlloca());
1727 Value *Handle = IRBuilder.Builder.CreateAlloca(
1728 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1729 Handle =
1730 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1731
1732 // Add the "issue" runtime call declaration:
1733 // __tgt_target_data_begin_mapper_issue takes the same arguments as the
1734 // mapper call plus the async handle created above.
1735 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1736 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1737
1738 // Replace the RuntimeCall call site with its asynchronous version.
1739 SmallVector<Value *, 16> Args;
1740 for (auto &Arg : RuntimeCall.args())
1741 Args.push_back(Arg.get());
1742 Args.push_back(Handle);
1743
1744 CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
1745 RuntimeCall.getIterator());
1746 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1747 RuntimeCall.eraseFromParent();
1748
1749 // Add the "wait" runtime call declaration:
1750 // __tgt_target_data_begin_mapper_wait takes the device id and the async handle.
1751 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1752 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1753
1754 Value *WaitParams[2] = {
1755 IssueCallsite->getArgOperand(
1756 OffloadArray::DeviceIDArgNum), // device_id.
1757 Handle // handle to wait on.
1758 };
1759 CallInst *WaitCallsite = CallInst::Create(
1760 WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
1761 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1762
1763 return true;
1764 }
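 // Illustrative sketch (hypothetical, simplified IR): the split turns
 //   call void @__tgt_target_data_begin_mapper(...)
 //   ; code independent of the mapped data
 //   ; first use of the mapped data
 // into
 //   %handle = alloca %struct.__tgt_async_info
 //   call void @__tgt_target_data_begin_mapper_issue(..., ptr %handle)
 //   ; code independent of the mapped data
 //   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, ptr %handle)
 // so the memory transfer can overlap with the independent code.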
1765
1766 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1767 bool GlobalOnly, bool &SingleChoice) {
1768 if (CurrentIdent == NextIdent)
1769 return CurrentIdent;
1770
1771 // TODO: Figure out how to actually combine multiple debug locations. For
1772 // now we just keep an existing one if there is a single choice.
1773 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1774 SingleChoice = !CurrentIdent;
1775 return NextIdent;
1776 }
1777 return nullptr;
1778 }
1779
1780 /// Return a `struct ident_t*` value that represents the ones used in the
1781 /// calls of \p RFI inside \p F. If \p GlobalOnly is true, we will not
1782 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1783 /// return value we create one from scratch. We also do not yet combine
1784 /// information, e.g., the source locations, see combinedIdentStruct.
1785 Value *
1786 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1787 Function &F, bool GlobalOnly) {
1788 bool SingleChoice = true;
1789 Value *Ident = nullptr;
1790 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1791 CallInst *CI = getCallIfRegularCall(U, &RFI);
1792 if (!CI || &F != &Caller)
1793 return false;
1794 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1795 /* GlobalOnly */ true, SingleChoice);
1796 return false;
1797 };
1798 RFI.foreachUse(SCC, CombineIdentStruct);
1799
1800 if (!Ident || !SingleChoice) {
1801 // The IRBuilder uses the insertion block to get to the module; this is
1802 // unfortunate, but we work around it for now.
1803 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1804 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1805 &F.getEntryBlock(), F.getEntryBlock().begin()));
1806 // Create a fallback location if none was found.
1807 // TODO: Use the debug locations of the calls instead.
1808 uint32_t SrcLocStrSize;
1809 Constant *Loc =
1810 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1811 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1812 }
1813 return Ident;
1814 }
1815
1816 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1817 /// \p ReplVal if given.
1818 bool deduplicateRuntimeCalls(Function &F,
1819 OMPInformationCache::RuntimeFunctionInfo &RFI,
1820 Value *ReplVal = nullptr) {
1821 auto *UV = RFI.getUseVector(F);
1822 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1823 return false;
1824
1825 LLVM_DEBUG(
1826 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1827 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1828
1829 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1830 cast<Argument>(ReplVal)->getParent() == &F)) &&
1831 "Unexpected replacement value!");
1832
1833 // TODO: Use dominance to find a good position instead.
1834 auto CanBeMoved = [this](CallBase &CB) {
1835 unsigned NumArgs = CB.arg_size();
1836 if (NumArgs == 0)
1837 return true;
1838 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1839 return false;
1840 for (unsigned U = 1; U < NumArgs; ++U)
1841 if (isa<Instruction>(CB.getArgOperand(U)))
1842 return false;
1843 return true;
1844 };
1845
1846 if (!ReplVal) {
1847 auto *DT =
1848 OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
1849 if (!DT)
1850 return false;
1851 Instruction *IP = nullptr;
1852 for (Use *U : *UV) {
1853 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1854 if (IP)
1855 IP = DT->findNearestCommonDominator(IP, CI);
1856 else
1857 IP = CI;
1858 if (!CanBeMoved(*CI))
1859 continue;
1860 if (!ReplVal)
1861 ReplVal = CI;
1862 }
1863 }
1864 if (!ReplVal)
1865 return false;
1866 assert(IP && "Expected insertion point!");
1867 cast<Instruction>(ReplVal)->moveBefore(IP->getIterator());
1868 }
1869
1870 // If we use a call as a replacement value we need to make sure the ident is
1871 // valid at the new location. For now we just pick a global one, either
1872 // existing and used by one of the calls, or created from scratch.
1873 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1874 if (!CI->arg_empty() &&
1875 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1876 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1877 /* GlobalOnly */ true);
1878 CI->setArgOperand(0, Ident);
1879 }
1880 }
1881
1882 bool Changed = false;
1883 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1884 CallInst *CI = getCallIfRegularCall(U, &RFI);
1885 if (!CI || CI == ReplVal || &F != &Caller)
1886 return false;
1887 assert(CI->getCaller() == &F && "Unexpected call!");
1888
1889 auto Remark = [&](OptimizationRemark OR) {
1890 return OR << "OpenMP runtime call "
1891 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1892 };
1893 if (CI->getDebugLoc())
1894 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1895 else
1896 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1897
1898 CI->replaceAllUsesWith(ReplVal);
1899 CI->eraseFromParent();
1900 ++NumOpenMPRuntimeCallsDeduplicated;
1901 Changed = true;
1902 return true;
1903 };
1904 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1905
1906 return Changed;
1907 }
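 // Illustrative sketch (hypothetical IR): given two calls in one function,
 //   %a = call i32 @omp_get_thread_num()
 //   ...
 //   %b = call i32 @omp_get_thread_num()
 // the second call is erased and its uses are replaced by %a, after the
 // surviving call has been moved to a point dominating all deduplicated
 // call sites.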
1908
1909 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1910 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1911 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1912 // initialization. We could define an AbstractAttribute instead and
1913 // run the Attributor here once it can be run as an SCC pass.
1914
1915 // Helper to check the argument \p ArgNo at all call sites of \p F for
1916 // a GTId.
1917 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1918 if (!F.hasLocalLinkage())
1919 return false;
1920 for (Use &U : F.uses()) {
1921 if (CallInst *CI = getCallIfRegularCall(U)) {
1922 Value *ArgOp = CI->getArgOperand(ArgNo);
1923 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1924 getCallIfRegularCall(
1925 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1926 continue;
1927 }
1928 return false;
1929 }
1930 return true;
1931 };
1932
1933 // Helper to identify uses of a GTId as GTId arguments.
1934 auto AddUserArgs = [&](Value &GTId) {
1935 for (Use &U : GTId.uses())
1936 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1937 if (CI->isArgOperand(&U))
1938 if (Function *Callee = CI->getCalledFunction())
1939 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1940 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1941 };
1942
1943 // The argument users of __kmpc_global_thread_num calls are GTIds.
1944 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1945 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1946
1947 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1948 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1949 AddUserArgs(*CI);
1950 return false;
1951 });
1952
1953 // Transitively search for more arguments by looking at the users of the
1954 // ones we know already. During the search the GTIdArgs vector is extended
1955 // so we cannot cache the size nor can we use a range based for.
1956 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1957 AddUserArgs(*GTIdArgs[U]);
1958 }
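 // Illustrative sketch (hypothetical names): given
 //   %gtid = call i32 @__kmpc_global_thread_num(ptr @ident)
 //   call void @helper(i32 %gtid)
 // the first parameter of an internal function @helper is added to GTIdArgs,
 // provided every call site of @helper passes a known GTId at that position.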
1959
1960 /// Kernel (=GPU) optimizations and utility functions
1961 ///
1962 ///{{
1963
1964 /// Cache to remember the unique kernel for a function.
1965 DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
1966
1967 /// Find the unique kernel that will execute \p F, if any.
1968 Kernel getUniqueKernelFor(Function &F);
1969
1970 /// Find the unique kernel that will execute \p I, if any.
1971 Kernel getUniqueKernelFor(Instruction &I) {
1972 return getUniqueKernelFor(*I.getFunction());
1973 }
1974
1975 /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1976 /// the cases where we can avoid taking the address of a function.
1977 bool rewriteDeviceCodeStateMachine();
1978
1979 ///
1980 ///}}
1981
1982 /// Emit a remark generically
1983 ///
1984 /// This template function can be used to generically emit a remark. The
1985 /// RemarkKind should be one of the following:
1986 /// - OptimizationRemark to indicate a successful optimization attempt
1987 /// - OptimizationRemarkMissed to report a failed optimization attempt
1988 /// - OptimizationRemarkAnalysis to provide additional information about an
1989 /// optimization attempt
1990 ///
1991 /// The remark is built using a callback function provided by the caller that
1992 /// takes a RemarkKind as input and returns a RemarkKind.
1993 template <typename RemarkKind, typename RemarkCallBack>
1994 void emitRemark(Instruction *I, StringRef RemarkName,
1995 RemarkCallBack &&RemarkCB) const {
1996 Function *F = I->getParent()->getParent();
1997 auto &ORE = OREGetter(F);
1998
1999 if (RemarkName.starts_with("OMP"))
2000 ORE.emit([&]() {
2001 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2002 << " [" << RemarkName << "]";
2003 });
2004 else
2005 ORE.emit(
2006 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2007 }
2008
2009 /// Emit a remark on a function.
2010 template <typename RemarkKind, typename RemarkCallBack>
2011 void emitRemark(Function *F, StringRef RemarkName,
2012 RemarkCallBack &&RemarkCB) const {
2013 auto &ORE = OREGetter(F);
2014
2015 if (RemarkName.starts_with("OMP"))
2016 ORE.emit([&]() {
2017 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2018 << " [" << RemarkName << "]";
2019 });
2020 else
2021 ORE.emit(
2022 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2023 }
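 // Example use (mirroring the call sites in this file): the callback streams
 // extra text into the remark object, e.g.
 //   auto Remark = [&](OptimizationRemarkAnalysis ORA) {
 //     return ORA << "Potentially unknown OpenMP target region caller.";
 //   };
 //   emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);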
2024
2025 /// The underlying module.
2026 Module &M;
2027
2028 /// The SCC we are operating on.
2029 SmallVectorImpl<Function *> &SCC;
2030
2031 /// Callback to update the call graph; the first argument is a removed call,
2032 /// the second an optional replacement call.
2033 CallGraphUpdater &CGUpdater;
2034
2035 /// Callback to get an OptimizationRemarkEmitter from a Function *
2036 OptimizationRemarkGetter OREGetter;
2037
2038 /// OpenMP-specific information cache. Also used for Attributor runs.
2039 OMPInformationCache &OMPInfoCache;
2040
2041 /// Attributor instance.
2042 Attributor &A;
2043
2044 /// Helper function to run Attributor on SCC.
2045 bool runAttributor(bool IsModulePass) {
2046 if (SCC.empty())
2047 return false;
2048
2049 registerAAs(IsModulePass);
2050
2051 ChangeStatus Changed = A.run();
2052
2053 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2054 << " functions, result: " << Changed << ".\n");
2055
2056 if (Changed == ChangeStatus::CHANGED)
2057 OMPInfoCache.invalidateAnalyses();
2058
2059 return Changed == ChangeStatus::CHANGED;
2060 }
2061
2062 void registerFoldRuntimeCall(RuntimeFunction RF);
2063
2064 /// Populate the Attributor with abstract attribute opportunities in the
2065 /// functions.
2066 void registerAAs(bool IsModulePass);
2067
2068public:
2069 /// Callback to register AAs for live functions, including internal functions
2070 /// marked live during the traversal.
2071 static void registerAAsForFunction(Attributor &A, const Function &F);
2072};
2073
2074Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2075 if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&
2076 !OMPInfoCache.CGSCC->contains(&F))
2077 return nullptr;
2078
2079 // Use a scope to keep the lifetime of the CachedKernel short.
2080 {
2081 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2082 if (CachedKernel)
2083 return *CachedKernel;
2084
2085 // TODO: We should use an AA to create an (optimistic and callback
2086 // call-aware) call graph. For now we stick to simple patterns that
2087 // are less powerful, basically the worst fixpoint.
2088 if (isOpenMPKernel(F)) {
2089 CachedKernel = Kernel(&F);
2090 return *CachedKernel;
2091 }
2092
2093 CachedKernel = nullptr;
2094 if (!F.hasLocalLinkage()) {
2095
2096 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2097 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2098 return ORA << "Potentially unknown OpenMP target region caller.";
2099 };
2100 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2101
2102 return nullptr;
2103 }
2104 }
2105
2106 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2107 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2108 // Allow use in equality comparisons.
2109 if (Cmp->isEquality())
2110 return getUniqueKernelFor(*Cmp);
2111 return nullptr;
2112 }
2113 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2114 // Allow direct calls.
2115 if (CB->isCallee(&U))
2116 return getUniqueKernelFor(*CB);
2117
2118 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2119 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2120 // Allow the use in __kmpc_parallel_60 calls.
2121 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2122 return getUniqueKernelFor(*CB);
2123 return nullptr;
2124 }
2125 // Disallow every other use.
2126 return nullptr;
2127 };
2128
2129 // TODO: In the future we want to track more than just a unique kernel.
2130 SmallPtrSet<Kernel, 2> PotentialKernels;
2131 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2132 PotentialKernels.insert(GetUniqueKernelForUse(U));
2133 });
2134
2135 Kernel K = nullptr;
2136 if (PotentialKernels.size() == 1)
2137 K = *PotentialKernels.begin();
2138
2139 // Cache the result.
2140 UniqueKernelMap[&F] = K;
2141
2142 return K;
2143}
2144
2145bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2146 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2147 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
2148
2149 bool Changed = false;
2150 if (!KernelParallelRFI)
2151 return Changed;
2152
2153 // If we have disabled state machine changes, exit.
2154 if (DisableOpenMPOptStateMachineRewrite)
2155 return Changed;
2156
2157 for (Function *F : SCC) {
2158
2159 // Check if the function is used in a __kmpc_parallel_60 call at
2160 // all.
2161 bool UnknownUse = false;
2162 bool KernelParallelUse = false;
2163 unsigned NumDirectCalls = 0;
2164
2165 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2166 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2167 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2168 if (CB->isCallee(&U)) {
2169 ++NumDirectCalls;
2170 return;
2171 }
2172
2173 if (isa<ICmpInst>(U.getUser())) {
2174 ToBeReplacedStateMachineUses.push_back(&U);
2175 return;
2176 }
2177
2178 // Find wrapper functions that represent parallel kernels.
2179 CallInst *CI =
2180 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2181 const unsigned int WrapperFunctionArgNo = 6;
2182 if (!KernelParallelUse && CI &&
2183 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2184 KernelParallelUse = true;
2185 ToBeReplacedStateMachineUses.push_back(&U);
2186 return;
2187 }
2188 UnknownUse = true;
2189 });
2190
2191 // Do not emit a remark if we haven't seen a __kmpc_parallel_60
2192 // use.
2193 if (!KernelParallelUse)
2194 continue;
2195
2196 // If this ever hits, we should investigate.
2197 // TODO: Checking the number of uses is not a necessary restriction and
2198 // should be lifted.
2199 if (UnknownUse || NumDirectCalls != 1 ||
2200 ToBeReplacedStateMachineUses.size() > 2) {
2201 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2202 return ORA << "Parallel region is used in "
2203 << (UnknownUse ? "unknown" : "unexpected")
2204 << " ways. Will not attempt to rewrite the state machine.";
2205 };
2206 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2207 continue;
2208 }
2209
2210 // Even if we have __kmpc_parallel_60 calls, we (for now) give
2211 // up if the function is not called from a unique kernel.
2212 Kernel K = getUniqueKernelFor(*F);
2213 if (!K) {
2214 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2215 return ORA << "Parallel region is not called from a unique kernel. "
2216 "Will not attempt to rewrite the state machine.";
2217 };
2218 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2219 continue;
2220 }
2221
2222 // We now know F is a parallel body function called only from the kernel K.
2223 // We also identified the state machine uses in which we replace the
2224 // function pointer by a new global symbol for identification purposes. This
2225 // ensures only direct calls to the function are left.
2226
2227 Module &M = *F->getParent();
2228 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2229
2230 auto *ID = new GlobalVariable(
2231 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2232 UndefValue::get(Int8Ty), F->getName() + ".ID");
2233
2234 for (Use *U : ToBeReplacedStateMachineUses)
2235 A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2236 ID, U->get()->getType()));
2237
2238 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2239
2240 Changed = true;
2241 }
2242
2243 return Changed;
2244}
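// Illustrative sketch (hypothetical names): for a parallel region function
// @__omp_outlined__wrapper whose address is only used in
//   %cmp = icmp eq ptr %work_fn, @__omp_outlined__wrapper
// and as the wrapper argument of a __kmpc_parallel_60 call, both of these uses
// are rewritten to the new private global @__omp_outlined__wrapper.ID, so the
// only remaining uses of the function are direct calls.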
2245
2246/// Abstract Attribute for tracking ICV values.
2247struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2248 using Base = StateWrapper<BooleanState, AbstractAttribute>;
2249 AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2250
2251 /// Returns true if value is assumed to be tracked.
2252 bool isAssumedTracked() const { return getAssumed(); }
2253
2254 /// Returns true if value is known to be tracked.
2255 bool isKnownTracked() const { return getAssumed(); }
2256
2257 /// Create an abstract attribute view for the position \p IRP.
2258 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2259
2260 /// Return the value with which \p I can be replaced for specific \p ICV.
2261 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2262 const Instruction *I,
2263 Attributor &A) const {
2264 return std::nullopt;
2265 }
2266
2267 /// Return an assumed unique ICV value if a single candidate is found. If
2268 /// there cannot be one, return a nullptr. If it is not clear yet, return
2269 /// std::nullopt.
2270 virtual std::optional<Value *>
2271 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2272
2273 // Currently only nthreads is being tracked.
2274 // This array will only grow with time.
2275 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2276
2277 /// See AbstractAttribute::getName()
2278 StringRef getName() const override { return "AAICVTracker"; }
2279
2280 /// See AbstractAttribute::getIdAddr()
2281 const char *getIdAddr() const override { return &ID; }
2282
2283 /// This function should return true if the type of the \p AA is AAICVTracker
2284 static bool classof(const AbstractAttribute *AA) {
2285 return (AA->getIdAddr() == &ID);
2286 }
2287
2288 static const char ID;
2289};
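// Illustrative sketch of ICV tracking (hypothetical IR): for
//   call void @omp_set_num_threads(i32 4)
//   %n = call i32 @omp_get_max_threads()
// the tracker records the value 4 at the setter call, so the getter can later
// be folded to 4 as long as no intervening call may change the nthreads ICV.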
2290
2291struct AAICVTrackerFunction : public AAICVTracker {
2292 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2293 : AAICVTracker(IRP, A) {}
2294
2295 // FIXME: come up with better string.
2296 const std::string getAsStr(Attributor *) const override {
2297 return "ICVTrackerFunction";
2298 }
2299
2300 // FIXME: come up with some stats.
2301 void trackStatistics() const override {}
2302
2303 /// We don't manifest anything for this AA.
2304 ChangeStatus manifest(Attributor &A) override {
2305 return ChangeStatus::UNCHANGED;
2306 }
2307
2308 // Map of ICV to their values at specific program point.
2309 EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2310 InternalControlVar::ICV___last>
2311 ICVReplacementValuesMap;
2312
2313 ChangeStatus updateImpl(Attributor &A) override {
2314 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2315
2316 Function *F = getAnchorScope();
2317
2318 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2319
2320 for (InternalControlVar ICV : TrackableICVs) {
2321 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2322
2323 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2324 auto TrackValues = [&](Use &U, Function &) {
2325 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2326 if (!CI)
2327 return false;
2328
2329 // FIXME: handle setters with more than one argument.
2330 /// Track new value.
2331 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2332 HasChanged = ChangeStatus::CHANGED;
2333
2334 return false;
2335 };
2336
2337 auto CallCheck = [&](Instruction &I) {
2338 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2339 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2340 HasChanged = ChangeStatus::CHANGED;
2341
2342 return true;
2343 };
2344
2345 // Track all changes of an ICV.
2346 SetterRFI.foreachUse(TrackValues, F);
2347
2348 bool UsedAssumedInformation = false;
2349 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2350 UsedAssumedInformation,
2351 /* CheckBBLivenessOnly */ true);
2352
2353 /// TODO: Figure out a way to avoid adding an entry in
2354 /// ICVReplacementValuesMap.
2355 Instruction *Entry = &F->getEntryBlock().front();
2356 if (HasChanged == ChangeStatus::CHANGED)
2357 ValuesMap.try_emplace(Entry);
2358 }
2359
2360 return HasChanged;
2361 }
2362
2363 /// Helper to check if \p I is a call and get the value for it if it is
2364 /// unique.
2365 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2366 InternalControlVar &ICV) const {
2367
2368 const auto *CB = dyn_cast<CallBase>(&I);
2369 if (!CB || CB->hasFnAttr("no_openmp") ||
2370 CB->hasFnAttr("no_openmp_routines") ||
2371 CB->hasFnAttr("no_openmp_constructs"))
2372 return std::nullopt;
2373
2374 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2375 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2376 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2377 Function *CalledFunction = CB->getCalledFunction();
2378
2379 // Indirect call, assume ICV changes.
2380 if (CalledFunction == nullptr)
2381 return nullptr;
2382 if (CalledFunction == GetterRFI.Declaration)
2383 return std::nullopt;
2384 if (CalledFunction == SetterRFI.Declaration) {
2385 if (ICVReplacementValuesMap[ICV].count(&I))
2386 return ICVReplacementValuesMap[ICV].lookup(&I);
2387
2388 return nullptr;
2389 }
2390
2391 // Since we don't know, assume it changes the ICV.
2392 if (CalledFunction->isDeclaration())
2393 return nullptr;
2394
2395 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2396 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2397
2398 if (ICVTrackingAA->isAssumedTracked()) {
2399 std::optional<Value *> URV =
2400 ICVTrackingAA->getUniqueReplacementValue(ICV);
2401 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2402 OMPInfoCache)))
2403 return URV;
2404 }
2405
2406 // If we don't know, assume it changes.
2407 return nullptr;
2408 }
2409
2410 // We don't check unique value for a function, so return std::nullopt.
2411 std::optional<Value *>
2412 getUniqueReplacementValue(InternalControlVar ICV) const override {
2413 return std::nullopt;
2414 }
2415
2416 /// Return the value with which \p I can be replaced for specific \p ICV.
2417 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2418 const Instruction *I,
2419 Attributor &A) const override {
2420 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2421 if (ValuesMap.count(I))
2422 return ValuesMap.lookup(I);
2423
2424 SmallVector<const Instruction *, 16> Worklist;
2425 SmallPtrSet<const Instruction *, 16> Visited;
2426 Worklist.push_back(I);
2427
2428 std::optional<Value *> ReplVal;
2429
2430 while (!Worklist.empty()) {
2431 const Instruction *CurrInst = Worklist.pop_back_val();
2432 if (!Visited.insert(CurrInst).second)
2433 continue;
2434
2435 const BasicBlock *CurrBB = CurrInst->getParent();
2436
2437 // Go up and look for all potential setters/calls that might change the
2438 // ICV.
2439 while ((CurrInst = CurrInst->getPrevNode())) {
2440 if (ValuesMap.count(CurrInst)) {
2441 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2442 // Unknown value, track new.
2443 if (!ReplVal) {
2444 ReplVal = NewReplVal;
2445 break;
2446 }
2447
2448 // If we found a new value, we can't know the icv value anymore.
2449 if (NewReplVal)
2450 if (ReplVal != NewReplVal)
2451 return nullptr;
2452
2453 break;
2454 }
2455
2456 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2457 if (!NewReplVal)
2458 continue;
2459
2460 // Unknown value, track new.
2461 if (!ReplVal) {
2462 ReplVal = NewReplVal;
2463 break;
2464 }
2465
2466 // NewReplVal holds a value here:
2467 // we found a new value, so we can't know the ICV value anymore.
2468 if (ReplVal != NewReplVal)
2469 return nullptr;
2470 }
2471
2472 // If we are in the same BB and we have a value, we are done.
2473 if (CurrBB == I->getParent() && ReplVal)
2474 return ReplVal;
2475
2476 // Go through all predecessors and add terminators for analysis.
2477 for (const BasicBlock *Pred : predecessors(CurrBB))
2478 if (const Instruction *Terminator = Pred->getTerminator())
2479 Worklist.push_back(Terminator);
2480 }
2481
2482 return ReplVal;
2483 }
2484};
2485
2486struct AAICVTrackerFunctionReturned : AAICVTracker {
2487 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2488 : AAICVTracker(IRP, A) {}
2489
2490 // FIXME: come up with better string.
2491 const std::string getAsStr(Attributor *) const override {
2492 return "ICVTrackerFunctionReturned";
2493 }
2494
2495 // FIXME: come up with some stats.
2496 void trackStatistics() const override {}
2497
2498 /// We don't manifest anything for this AA.
2499 ChangeStatus manifest(Attributor &A) override {
2500 return ChangeStatus::UNCHANGED;
2501 }
2502
2503 // Map of ICV to their values at specific program point.
2504 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2505 InternalControlVar::ICV___last>
2506 ICVReplacementValuesMap;
2507
2508 /// Return the value with which \p I can be replaced for specific \p ICV.
2509 std::optional<Value *>
2510 getUniqueReplacementValue(InternalControlVar ICV) const override {
2511 return ICVReplacementValuesMap[ICV];
2512 }
2513
2514 ChangeStatus updateImpl(Attributor &A) override {
2515 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2516 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2517 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2518
2519 if (!ICVTrackingAA->isAssumedTracked())
2520 return indicatePessimisticFixpoint();
2521
2522 for (InternalControlVar ICV : TrackableICVs) {
2523 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2524 std::optional<Value *> UniqueICVValue;
2525
2526 auto CheckReturnInst = [&](Instruction &I) {
2527 std::optional<Value *> NewReplVal =
2528 ICVTrackingAA->getReplacementValue(ICV, &I, A);
2529
2530 // If we found a second ICV value there is no unique returned value.
2531 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2532 return false;
2533
2534 UniqueICVValue = NewReplVal;
2535
2536 return true;
2537 };
2538
2539 bool UsedAssumedInformation = false;
2540 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2541 UsedAssumedInformation,
2542 /* CheckBBLivenessOnly */ true))
2543 UniqueICVValue = nullptr;
2544
2545 if (UniqueICVValue == ReplVal)
2546 continue;
2547
2548 ReplVal = UniqueICVValue;
2549 Changed = ChangeStatus::CHANGED;
2550 }
2551
2552 return Changed;
2553 }
2554};
2555
2556struct AAICVTrackerCallSite : AAICVTracker {
2557 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2558 : AAICVTracker(IRP, A) {}
2559
2560 void initialize(Attributor &A) override {
2561 assert(getAnchorScope() && "Expected anchor function");
2562
2563 // We only initialize this AA for getters, so we need to know which ICV it
2564 // gets.
2565 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2566 for (InternalControlVar ICV : TrackableICVs) {
2567 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2568 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2569 if (Getter.Declaration == getAssociatedFunction()) {
2570 AssociatedICV = ICVInfo.Kind;
2571 return;
2572 }
2573 }
2574
2575 /// Unknown ICV.
2576 indicatePessimisticFixpoint();
2577 }
2578
2579 ChangeStatus manifest(Attributor &A) override {
2580 if (!ReplVal || !*ReplVal)
2581 return ChangeStatus::UNCHANGED;
2582
2583 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2584 A.deleteAfterManifest(*getCtxI());
2585
2586 return ChangeStatus::CHANGED;
2587 }
2588
2589 // FIXME: come up with better string.
2590 const std::string getAsStr(Attributor *) const override {
2591 return "ICVTrackerCallSite";
2592 }
2593
2594 // FIXME: come up with some stats.
2595 void trackStatistics() const override {}
2596
2597 InternalControlVar AssociatedICV;
2598 std::optional<Value *> ReplVal;
2599
2600 ChangeStatus updateImpl(Attributor &A) override {
2601 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2602 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2603
2604 // We don't have any information, so we assume it changes the ICV.
2605 if (!ICVTrackingAA->isAssumedTracked())
2606 return indicatePessimisticFixpoint();
2607
2608 std::optional<Value *> NewReplVal =
2609 ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);
2610
2611 if (ReplVal == NewReplVal)
2612 return ChangeStatus::UNCHANGED;
2613
2614 ReplVal = NewReplVal;
2615 return ChangeStatus::CHANGED;
2616 }
2617
2618 // Return the value with which the associated value can be replaced for a
2619 // specific \p ICV.
2620 std::optional<Value *>
2621 getUniqueReplacementValue(InternalControlVar ICV) const override {
2622 return ReplVal;
2623 }
2624};
2625
2626struct AAICVTrackerCallSiteReturned : AAICVTracker {
2627 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2628 : AAICVTracker(IRP, A) {}
2629
2630 // FIXME: come up with better string.
2631 const std::string getAsStr(Attributor *) const override {
2632 return "ICVTrackerCallSiteReturned";
2633 }
2634
2635 // FIXME: come up with some stats.
2636 void trackStatistics() const override {}
2637
2638 /// We don't manifest anything for this AA.
2639 ChangeStatus manifest(Attributor &A) override {
2640 return ChangeStatus::UNCHANGED;
2641 }
2642
2643 // Map of ICV to their values at specific program point.
2644 EnumeratedArray<std::optional<Value *>, InternalControlVar,
2645 InternalControlVar::ICV___last>
2646 ICVReplacementValuesMap;
2647
2648 /// Return the value with which the associated value can be replaced for a
2649 /// specific \p ICV.
2650 std::optional<Value *>
2651 getUniqueReplacementValue(InternalControlVar ICV) const override {
2652 return ICVReplacementValuesMap[ICV];
2653 }
2654
2655 ChangeStatus updateImpl(Attributor &A) override {
2656 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2657 const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(
2658 *this, IRPosition::returned(*getAssociatedFunction()),
2659 DepClassTy::REQUIRED);
2660
2661 // We don't have any information, so we assume it changes the ICV.
2662 if (!ICVTrackingAA->isAssumedTracked())
2663 return indicatePessimisticFixpoint();
2664
2665 for (InternalControlVar ICV : TrackableICVs) {
2666 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2667 std::optional<Value *> NewReplVal =
2668 ICVTrackingAA->getUniqueReplacementValue(ICV);
2669
2670 if (ReplVal == NewReplVal)
2671 continue;
2672
2673 ReplVal = NewReplVal;
2674 Changed = ChangeStatus::CHANGED;
2675 }
2676 return Changed;
2677 }
2678};
2679
2680/// Determines if \p BB exits the function unconditionally itself or reaches a
2681/// block that does so through only unique successors.
2682static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {
2683 if (succ_empty(BB))
2684 return true;
2685 const BasicBlock *const Successor = BB->getUniqueSuccessor();
2686 if (!Successor)
2687 return false;
2688 return hasFunctionEndAsUniqueSuccessor(Successor);
2689}
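// Example: a chain BB -> B1 -> B2 where B2 has no successors yields true; any
// block on the path with more than one distinct successor yields false.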
2690
2691struct AAExecutionDomainFunction : public AAExecutionDomain {
2692 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2693 : AAExecutionDomain(IRP, A) {}
2694
2695 ~AAExecutionDomainFunction() override { delete RPOT; }
2696
2697 void initialize(Attributor &A) override {
2698 Function *F = getAnchorScope();
2699 assert(F && "Expected anchor function");
2700 RPOT = new ReversePostOrderTraversal<Function *>(F);
2701 }
2702
2703 const std::string getAsStr(Attributor *) const override {
2704 unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;
2705 for (auto &It : BEDMap) {
2706 if (!It.getFirst())
2707 continue;
2708 TotalBlocks++;
2709 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2710 AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&
2711 It.getSecond().IsReachingAlignedBarrierOnly;
2712 }
2713 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2714 std::to_string(AlignedBlocks) + " of " +
2715 std::to_string(TotalBlocks) +
2716 " executed by initial thread / aligned";
2717 }
2718
2719 /// See AbstractAttribute::trackStatistics().
2720 void trackStatistics() const override {}
2721
2722 ChangeStatus manifest(Attributor &A) override {
2723 LLVM_DEBUG({
2724 for (const BasicBlock &BB : *getAnchorScope()) {
2725 if (!isExecutedByInitialThreadOnly(BB))
2726 continue;
2727 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2728 << BB.getName() << " is executed by a single thread.\n";
2729 }
2730 });
2731
2732 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2733
2734 if (DisableOpenMPOptBarrierElimination)
2735 return Changed;
2736
2737 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2738 auto HandleAlignedBarrier = [&](CallBase *CB) {
2739 const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];
2740 if (!ED.IsReachedFromAlignedBarrierOnly ||
2741 ED.EncounteredNonLocalSideEffect)
2742 return;
2743 if (!ED.EncounteredAssumes.empty() && !A.isModulePass())
2744 return;
2745
2746 // We can remove this barrier, if it is one, or aligned barriers reaching
2747 // the kernel end (if CB is nullptr). Aligned barriers reaching the kernel
2748 // end should only be removed if the kernel end is their unique successor;
2749 // otherwise, they may have side-effects that aren't accounted for in the
2750 // kernel end in their other successors. If those barriers have other
2751 // barriers reaching them, those can be transitively removed as well as
2752 // long as the kernel end is also their unique successor.
2753 if (CB) {
2754 DeletedBarriers.insert(CB);
2755 A.deleteAfterManifest(*CB);
2756 ++NumBarriersEliminated;
2757 Changed = ChangeStatus::CHANGED;
2758 } else if (!ED.AlignedBarriers.empty()) {
2759 Changed = ChangeStatus::CHANGED;
2760 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2761 ED.AlignedBarriers.end());
2762 SmallSetVector<CallBase *, 16> Visited;
2763 while (!Worklist.empty()) {
2764 CallBase *LastCB = Worklist.pop_back_val();
2765 if (!Visited.insert(LastCB))
2766 continue;
2767 if (LastCB->getFunction() != getAnchorScope())
2768 continue;
2769 if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))
2770 continue;
2771 if (!DeletedBarriers.count(LastCB)) {
2772 ++NumBarriersEliminated;
2773 A.deleteAfterManifest(*LastCB);
2774 continue;
2775 }
2776 // The final aligned barrier (LastCB) reaching the kernel end was
2777 // removed already. This means we can go one step further and remove
2778 // the barriers encountered last before it (LastCB).
2779 const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];
2780 Worklist.append(LastED.AlignedBarriers.begin(),
2781 LastED.AlignedBarriers.end());
2782 }
2783 }
2784
2785 // If we actually eliminated a barrier we need to eliminate the associated
2786 // llvm.assumes as well to avoid creating UB.
2787 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2788 for (auto *AssumeCB : ED.EncounteredAssumes)
2789 A.deleteAfterManifest(*AssumeCB);
2790 };
2791
2792 for (auto *CB : AlignedBarriers)
2793 HandleAlignedBarrier(CB);
2794
2795 // Handle the "kernel end barrier" for kernels too.
2796 if (omp::isOpenMPKernel(*getAnchorScope()))
2797 HandleAlignedBarrier(nullptr);
2798
2799 return Changed;
2800 }
2801
2802 bool isNoOpFence(const FenceInst &FI) const override {
2803 return getState().isValidState() && !NonNoOpFences.count(&FI);
2804 }
2805
2806 /// Merge barrier and assumption information from \p PredED into the successor
2807 /// \p ED.
2808 void
2809 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2810 const ExecutionDomainTy &PredED);
2811
2812 /// Merge all information from \p PredED into the successor \p ED. If
2813 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2814 /// represented by \p ED from this predecessor.
2815 bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2816 const ExecutionDomainTy &PredED,
2817 bool InitialEdgeOnly = false);
2818
2819 /// Accumulate information for the entry block in \p EntryBBED.
2820 bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);
2821
2822 /// See AbstractAttribute::updateImpl.
2823 ChangeStatus updateImpl(Attributor &A) override;
2824
2825 /// Query interface, see AAExecutionDomain
2826 ///{
2827 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2828 if (!isValidState())
2829 return false;
2830 assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");
2831 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2832 }
2833
2834 bool isExecutedInAlignedRegion(Attributor &A,
2835 const Instruction &I) const override {
2836 assert(I.getFunction() == getAnchorScope() &&
2837 "Instruction is out of scope!");
2838 if (!isValidState())
2839 return false;
2840
2841 bool ForwardIsOk = true;
2842 const Instruction *CurI;
2843
2844 // Check forward until a call or the block end is reached.
2845 CurI = &I;
2846 do {
2847 auto *CB = dyn_cast<CallBase>(CurI);
2848 if (!CB)
2849 continue;
2850 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2851 return true;
2852 const auto &It = CEDMap.find({CB, PRE});
2853 if (It == CEDMap.end())
2854 continue;
2855 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2856 ForwardIsOk = false;
2857 break;
2858 } while ((CurI = CurI->getNextNode()));
2859
2860 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2861 ForwardIsOk = false;
2862
2863 // Check backward until a call or the block beginning is reached.
2864 CurI = &I;
2865 do {
2866 auto *CB = dyn_cast<CallBase>(CurI);
2867 if (!CB)
2868 continue;
2869 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))
2870 return true;
2871 const auto &It = CEDMap.find({CB, POST});
2872 if (It == CEDMap.end())
2873 continue;
2874 if (It->getSecond().IsReachedFromAlignedBarrierOnly)
2875 break;
2876 return false;
2877 } while ((CurI = CurI->getPrevNode()));
2878
2879 // Delayed decision on the forward pass to allow aligned barrier detection
2880 // in the backwards traversal.
2881 if (!ForwardIsOk)
2882 return false;
2883
2884 if (!CurI) {
2885 const BasicBlock *BB = I.getParent();
2886 if (BB == &BB->getParent()->getEntryBlock())
2887 return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;
2888 if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {
2889 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2890 })) {
2891 return false;
2892 }
2893 }
2894
2895 // On neither traversal did we find anything but aligned barriers.
2896 return true;
2897 }
2898
2899 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2900 assert(isValidState() &&
2901 "No request should be made against an invalid state!");
2902 return BEDMap.lookup(&BB);
2903 }
2904 std::pair<ExecutionDomainTy, ExecutionDomainTy>
2905 getExecutionDomain(const CallBase &CB) const override {
2906 assert(isValidState() &&
2907 "No request should be made against an invalid state!");
2908 return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};
2909 }
2910 ExecutionDomainTy getFunctionExecutionDomain() const override {
2911 assert(isValidState() &&
2912 "No request should be made against an invalid state!");
2913 return InterProceduralED;
2914 }
2915 ///}
2916
2917 // Check if the edge into the successor block contains a condition that only
2918 // lets the main thread execute it.
2919 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2920 BasicBlock &SuccessorBB) {
2921 if (!Edge || !Edge->isConditional())
2922 return false;
2923 if (Edge->getSuccessor(0) != &SuccessorBB)
2924 return false;
2925
2926 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2927 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2928 return false;
2929
2930 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2931 if (!C)
2932 return false;
2933
2934 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2935 if (C->isAllOnesValue()) {
2936 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2937 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2938 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2939 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2940 if (!CB)
2941 return false;
2942 ConstantStruct *KernelEnvC =
2944 ConstantInt *ExecModeC =
2945 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
2946 return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;
2947 }
2948
2949 if (C->isZero()) {
2950 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2951 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2952 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2953 return true;
2954
2955 // Match: 0 == llvm.amdgcn.workitem.id.x()
2956 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2957 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2958 return true;
2959 }
2960
2961 return false;
2962 };
2963
2964 /// Mapping containing information about the function for other AAs.
2965 ExecutionDomainTy InterProceduralED;
2966
2967 enum Direction { PRE = 0, POST = 1 };
2968 /// Mapping containing information per block.
2969 DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
2970 DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>
2971 CEDMap;
2972 SmallSetVector<CallBase *, 16> AlignedBarriers;
2973
2974 ReversePostOrderTraversal<Function *> *RPOT = nullptr;
2975
2976 /// Set \p R to \p V and report true if that changed \p R.
2977 static bool setAndRecord(bool &R, bool V) {
2978 bool Eq = (R == V);
2979 R = V;
2980 return !Eq;
2981 }
2982
2983 /// Collection of fences known to be non-no-op. All fences not in this set
2984 /// can be assumed to be no-ops.
2985 SmallPtrSet<const FenceInst *, 8> NonNoOpFences;
2986};
2987
2988void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2989 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2990 for (auto *EA : PredED.EncounteredAssumes)
2991 ED.addAssumeInst(A, *EA);
2992
2993 for (auto *AB : PredED.AlignedBarriers)
2994 ED.addAlignedBarrier(A, *AB);
2995}
2996
2997bool AAExecutionDomainFunction::mergeInPredecessor(
2998 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2999 bool InitialEdgeOnly) {
3000
3001 bool Changed = false;
3002 Changed |=
3003 setAndRecord(ED.IsExecutedByInitialThreadOnly,
3004 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
3005 ED.IsExecutedByInitialThreadOnly));
3006
3007 Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,
3008 ED.IsReachedFromAlignedBarrierOnly &&
3009 PredED.IsReachedFromAlignedBarrierOnly);
3010 Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,
3011 ED.EncounteredNonLocalSideEffect |
3012 PredED.EncounteredNonLocalSideEffect);
3013 // Do not track assumptions and barriers as part of Changed.
3014 if (ED.IsReachedFromAlignedBarrierOnly)
3015 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
3016 else
3017 ED.clearAssumeInstAndAlignedBarriers();
3018 return Changed;
3019}
3020
3021bool AAExecutionDomainFunction::handleCallees(Attributor &A,
3022 ExecutionDomainTy &EntryBBED) {
3023 SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;
3024 auto PredForCallSite = [&](AbstractCallSite ACS) {
3025 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3026 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
3027 DepClassTy::OPTIONAL);
3028 if (!EDAA || !EDAA->getState().isValidState())
3029 return false;
3030 CallSiteEDs.emplace_back(
3031 EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
3032 return true;
3033 };
3034
3035 ExecutionDomainTy ExitED;
3036 bool AllCallSitesKnown;
3037 if (A.checkForAllCallSites(PredForCallSite, *this,
3038 /* RequiresAllCallSites */ true,
3039 AllCallSitesKnown)) {
3040 for (const auto &[CSInED, CSOutED] : CallSiteEDs) {
3041 mergeInPredecessor(A, EntryBBED, CSInED);
3042 ExitED.IsReachingAlignedBarrierOnly &=
3043 CSOutED.IsReachingAlignedBarrierOnly;
3044 }
3045
3046 } else {
3047 // We could not find all predecessors, so this is either a kernel or a
3048 // function with external linkage (or with some other weird uses).
3049 if (omp::isOpenMPKernel(*getAnchorScope())) {
3050 EntryBBED.IsExecutedByInitialThreadOnly = false;
3051 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
3052 EntryBBED.EncounteredNonLocalSideEffect = false;
3053 ExitED.IsReachingAlignedBarrierOnly = false;
3054 } else {
3055 EntryBBED.IsExecutedByInitialThreadOnly = false;
3056 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
3057 EntryBBED.EncounteredNonLocalSideEffect = true;
3058 ExitED.IsReachingAlignedBarrierOnly = false;
3059 }
3060 }
3061
3062 bool Changed = false;
3063 auto &FnED = BEDMap[nullptr];
3064 Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,
3065 FnED.IsReachedFromAlignedBarrierOnly &
3066 EntryBBED.IsReachedFromAlignedBarrierOnly);
3067 Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,
3068 FnED.IsReachingAlignedBarrierOnly &
3069 ExitED.IsReachingAlignedBarrierOnly);
3070 Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,
3071 EntryBBED.IsExecutedByInitialThreadOnly);
3072 return Changed;
3073}
3074
3075ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
3076
3077 bool Changed = false;
3078
3079 // Helper to deal with an aligned barrier encountered during the forward
3080 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
3081 // it was encountered.
3082 auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {
3083 Changed |= AlignedBarriers.insert(&CB);
3084 // First, update the barrier ED kept in the separate CEDMap.
3085 auto &CallInED = CEDMap[{&CB, PRE}];
3086 Changed |= mergeInPredecessor(A, CallInED, ED);
3087 CallInED.IsReachingAlignedBarrierOnly = true;
3088 // Next adjust the ED we use for the traversal.
3089 ED.EncounteredNonLocalSideEffect = false;
3090 ED.IsReachedFromAlignedBarrierOnly = true;
3091 // Aligned barrier collection has to come last.
3092 ED.clearAssumeInstAndAlignedBarriers();
3093 ED.addAlignedBarrier(A, CB);
3094 auto &CallOutED = CEDMap[{&CB, POST}];
3095 Changed |= mergeInPredecessor(A, CallOutED, ED);
3096 };
3097
3098 auto *LivenessAA =
3099 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
3100
3101 Function *F = getAnchorScope();
3102 BasicBlock &EntryBB = F->getEntryBlock();
3103 bool IsKernel = omp::isOpenMPKernel(*F);
3104
3105 SmallVector<Instruction *> SyncInstWorklist;
3106 for (auto &RIt : *RPOT) {
3107 BasicBlock &BB = *RIt;
3108
3109 bool IsEntryBB = &BB == &EntryBB;
3110 // TODO: We use local reasoning since we don't have a divergence analysis
3111 // running as well. We could basically allow uniform branches here.
3112 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
3113 bool IsExplicitlyAligned = IsEntryBB && IsKernel;
3114 ExecutionDomainTy ED;
3115 // Propagate "incoming edges" into information about this block.
3116 if (IsEntryBB) {
3117 Changed |= handleCallees(A, ED);
3118 } else {
3119 // For live non-entry blocks we only propagate
3120 // information via live edges.
3121 if (LivenessAA && LivenessAA->isAssumedDead(&BB))
3122 continue;
3123
3124 for (auto *PredBB : predecessors(&BB)) {
3125 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))
3126 continue;
3127 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
3128 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
3129 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
3130 }
3131 }
3132
3133 // Now we traverse the block, accumulate effects in ED and attach
3134 // information to calls.
3135 for (Instruction &I : BB) {
3136 bool UsedAssumedInformation;
3137 if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,
3138 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
3139 /* CheckForDeadStore */ true))
3140 continue;
3141
3142 // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled first;
3143 // the former are collected, the latter are ignored.
3144 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
3145 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
3146 ED.addAssumeInst(A, *AI);
3147 continue;
3148 }
3149 // TODO: Should we also collect and delete lifetime markers?
3150 if (II->isAssumeLikeIntrinsic())
3151 continue;
3152 }
3153
3154 if (auto *FI = dyn_cast<FenceInst>(&I)) {
3155 if (!ED.EncounteredNonLocalSideEffect) {
3156 // An aligned fence without non-local side-effects is a no-op.
3157 if (ED.IsReachedFromAlignedBarrierOnly)
3158 continue;
3159 // A non-aligned fence without non-local side-effects is a no-op
3160 // if the ordering only publishes non-local side-effects (or less).
3161 switch (FI->getOrdering()) {
3162 case AtomicOrdering::NotAtomic:
3163 continue;
3164 case AtomicOrdering::Unordered:
3165 continue;
3166 case AtomicOrdering::Monotonic:
3167 continue;
3168 case AtomicOrdering::Acquire:
3169 break;
3170 case AtomicOrdering::Release:
3171 continue;
3172 case AtomicOrdering::AcquireRelease:
3173 break;
3174 case AtomicOrdering::SequentiallyConsistent:
3175 break;
3176 };
3177 }
3178 NonNoOpFences.insert(FI);
3179 }
3180
3181 auto *CB = dyn_cast<CallBase>(&I);
3182 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
3183 bool IsAlignedBarrier =
3184 !IsNoSync && CB &&
3185 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
3186
3187 AlignedBarrierLastInBlock &= IsNoSync;
3188 IsExplicitlyAligned &= IsNoSync;
3189
3190 // Next we check for calls. Aligned barriers are handled
3191 // explicitly, everything else is kept for the backward traversal and will
3192 // also affect our state.
3193 if (CB) {
3194 if (IsAlignedBarrier) {
3195 HandleAlignedBarrier(*CB, ED);
3196 AlignedBarrierLastInBlock = true;
3197 IsExplicitlyAligned = true;
3198 continue;
3199 }
3200
3201 // Check the pointer(s) of a memory intrinsic explicitly.
3202 if (isa<MemIntrinsic>(&I)) {
3203 if (!ED.EncounteredNonLocalSideEffect &&
3204 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3205 ED.EncounteredNonLocalSideEffect = true;
3206 if (!IsNoSync) {
3207 ED.IsReachedFromAlignedBarrierOnly = false;
3208 SyncInstWorklist.push_back(&I);
3209 }
3210 continue;
3211 }
3212
3213 // Record how we entered the call, then accumulate the effect of the
3214 // call in ED for potential use by the callee.
3215 auto &CallInED = CEDMap[{CB, PRE}];
3216 Changed |= mergeInPredecessor(A, CallInED, ED);
3217
3218 // If we have a sync-definition we can check if it starts/ends in an
3219 // aligned barrier. If we are unsure we assume any sync breaks
3220 // alignment.
3221 Function *Callee = CB->getCalledFunction();
3222 if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3223 const auto *EDAA = A.getAAFor<AAExecutionDomain>(
3224 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3225 if (EDAA && EDAA->getState().isValidState()) {
3226 const auto &CalleeED = EDAA->getFunctionExecutionDomain();
3227 ED.IsReachedFromAlignedBarrierOnly =
3228 CalleeED.IsReachedFromAlignedBarrierOnly;
3229 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3230 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3231 ED.EncounteredNonLocalSideEffect |=
3232 CalleeED.EncounteredNonLocalSideEffect;
3233 else
3234 ED.EncounteredNonLocalSideEffect =
3235 CalleeED.EncounteredNonLocalSideEffect;
3236 if (!CalleeED.IsReachingAlignedBarrierOnly) {
3237 Changed |=
3238 setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3239 SyncInstWorklist.push_back(&I);
3240 }
3241 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3242 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3243 auto &CallOutED = CEDMap[{CB, POST}];
3244 Changed |= mergeInPredecessor(A, CallOutED, ED);
3245 continue;
3246 }
3247 }
3248 if (!IsNoSync) {
3249 ED.IsReachedFromAlignedBarrierOnly = false;
3250 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3251 SyncInstWorklist.push_back(&I);
3252 }
3253 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3254 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3255 auto &CallOutED = CEDMap[{CB, POST}];
3256 Changed |= mergeInPredecessor(A, CallOutED, ED);
3257 }
3258
3259 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3260 continue;
3261
3262 // If we have a callee we try to use fine-grained information to
3263 // determine local side-effects.
3264 if (CB) {
3265 const auto *MemAA = A.getAAFor<AAMemoryLocation>(
3266 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3267
3268 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
3269 AAMemoryLocation::AccessKind,
3270 AAMemoryLocation::MemoryLocationsKind) {
3271 return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3272 };
3273 if (MemAA && MemAA->getState().isValidState() &&
3274 MemAA->checkForAllAccessesToMemoryKind(
3275 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
3276 continue;
3277 }
3278
3279 auto &InfoCache = A.getInfoCache();
3280 if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))
3281 continue;
3282
3283 if (auto *LI = dyn_cast<LoadInst>(&I))
3284 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3285 continue;
3286
3287 if (!ED.EncounteredNonLocalSideEffect &&
3288 AA::isPotentiallyAffectedByBarrier(A, I, *this))
3289 ED.EncounteredNonLocalSideEffect = true;
3290 }
3291
3292 bool IsEndAndNotReachingAlignedBarriersOnly = false;
3293 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3294 !BB.getTerminator()->getNumSuccessors()) {
3295
3296 Changed |= mergeInPredecessor(A, InterProceduralED, ED);
3297
3298 auto &FnED = BEDMap[nullptr];
3299 if (IsKernel && !IsExplicitlyAligned)
3300 FnED.IsReachingAlignedBarrierOnly = false;
3301 Changed |= mergeInPredecessor(A, FnED, ED);
3302
3303 if (!FnED.IsReachingAlignedBarrierOnly) {
3304 IsEndAndNotReachingAlignedBarriersOnly = true;
3305 SyncInstWorklist.push_back(BB.getTerminator());
3306 auto &BBED = BEDMap[&BB];
3307 Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);
3308 }
3309 }
3310
3311 ExecutionDomainTy &StoredED = BEDMap[&BB];
3312 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &
3313 !IsEndAndNotReachingAlignedBarriersOnly;
3314
3315 // Check if we computed anything different as part of the forward
3316 // traversal. We do not take assumptions and aligned barriers into account
3317 // as they do not influence the state we iterate. Backward traversal values
3318 // are handled later on.
3319 if (ED.IsExecutedByInitialThreadOnly !=
3320 StoredED.IsExecutedByInitialThreadOnly ||
3321 ED.IsReachedFromAlignedBarrierOnly !=
3322 StoredED.IsReachedFromAlignedBarrierOnly ||
3323 ED.EncounteredNonLocalSideEffect !=
3324 StoredED.EncounteredNonLocalSideEffect)
3325 Changed = true;
3326
3327 // Update the state with the new value.
3328 StoredED = std::move(ED);
3329 }
3330
3331 // Propagate (non-aligned) sync instruction effects backwards until the
3332 // entry is hit or an aligned barrier.
3333 SmallSetVector<BasicBlock *, 16> Visited;
3334 while (!SyncInstWorklist.empty()) {
3335 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3336 Instruction *CurInst = SyncInst;
3337 bool HitAlignedBarrierOrKnownEnd = false;
3338 while ((CurInst = CurInst->getPrevNode())) {
3339 auto *CB = dyn_cast<CallBase>(CurInst);
3340 if (!CB)
3341 continue;
3342 auto &CallOutED = CEDMap[{CB, POST}];
3343 Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);
3344 auto &CallInED = CEDMap[{CB, PRE}];
3345 HitAlignedBarrierOrKnownEnd =
3346 AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;
3347 if (HitAlignedBarrierOrKnownEnd)
3348 break;
3349 Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);
3350 }
3351 if (HitAlignedBarrierOrKnownEnd)
3352 continue;
3353 BasicBlock *SyncBB = SyncInst->getParent();
3354 for (auto *PredBB : predecessors(SyncBB)) {
3355 if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))
3356 continue;
3357 if (!Visited.insert(PredBB))
3358 continue;
3359 auto &PredED = BEDMap[PredBB];
3360 if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {
3361 Changed = true;
3362 SyncInstWorklist.push_back(PredBB->getTerminator());
3363 }
3364 }
3365 if (SyncBB != &EntryBB)
3366 continue;
3367 Changed |=
3368 setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);
3369 }
3370
3371 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3372}
3373
3374/// Try to replace memory allocation calls called by a single thread with a
3375/// static buffer of shared memory.
3376struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
3377 using Base = StateWrapper<BooleanState, AbstractAttribute>;
3378 AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3379
3380 /// Create an abstract attribute view for the position \p IRP.
3381 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3382 Attributor &A);
3383
3384 /// Returns true if HeapToShared conversion is assumed to be possible.
3385 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3386
3387 /// Returns true if HeapToShared conversion is assumed and the CB is a
3388 /// callsite to a free operation to be removed.
3389 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3390
3391 /// See AbstractAttribute::getName().
3392 StringRef getName() const override { return "AAHeapToShared"; }
3393
3394 /// See AbstractAttribute::getIdAddr().
3395 const char *getIdAddr() const override { return &ID; }
3396
3397 /// This function should return true if the type of the \p AA is
3398 /// AAHeapToShared.
3399 static bool classof(const AbstractAttribute *AA) {
3400 return (AA->getIdAddr() == &ID);
3401 }
3402
3403 /// Unique ID (due to the unique address)
3404 static const char ID;
3405};
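// A minimal sketch (illustration only; names and the size are made up) of the
// rewrite AAHeapToSharedFunction performs below, assuming a constant
// allocation size and a single free call executed by the main thread only:
//
//   %p = call ptr @__kmpc_alloc_shared(i64 16)
//   ...                                           ; uses of %p
//   call void @__kmpc_free_shared(ptr %p, i64 16)
//
// roughly becomes
//
//   @p_shared = internal addrspace(3) global [16 x i8] poison
//   ...                                           ; uses rewritten to a
//                                                 ; pointer cast of @p_shared
//
// with both runtime calls deleted after manifest.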
3406
3407struct AAHeapToSharedFunction : public AAHeapToShared {
3408 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3409 : AAHeapToShared(IRP, A) {}
3410
3411 const std::string getAsStr(Attributor *) const override {
3412 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3413 " malloc calls eligible.";
3414 }
3415
3416 /// See AbstractAttribute::trackStatistics().
3417 void trackStatistics() const override {}
3418
3419 /// This function finds free calls that will be removed by the
3420 /// HeapToShared transformation.
3421 void findPotentialRemovedFreeCalls(Attributor &A) {
3422 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3423 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3424
3425 PotentialRemovedFreeCalls.clear();
3426 // Update free call users of found malloc calls.
3427 for (CallBase *CB : MallocCalls) {
3428 SmallVector<CallBase *, 4> FreeCalls;
3429 for (auto *U : CB->users()) {
3430 CallBase *C = dyn_cast<CallBase>(U);
3431 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3432 FreeCalls.push_back(C);
3433 }
3434
3435 if (FreeCalls.size() != 1)
3436 continue;
3437
3438 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3439 }
3440 }
3441
3442 void initialize(Attributor &A) override {
3443 if (DisableOpenMPOptDeglobalization) {
3444 indicatePessimisticFixpoint();
3445 return;
3446 }
3447
3448 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3449 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3450 if (!RFI.Declaration)
3451 return;
3452
3453 Attributor::SimplifictionCallbackTy SCB =
3454 [](const IRPosition &, const AbstractAttribute *,
3455 bool &) -> std::optional<Value *> { return nullptr; };
3456
3457 Function *F = getAnchorScope();
3458 for (User *U : RFI.Declaration->users())
3459 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3460 if (CB->getFunction() != F)
3461 continue;
3462 MallocCalls.insert(CB);
3463 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3464 SCB);
3465 }
3466
3467 findPotentialRemovedFreeCalls(A);
3468 }
3469
3470 bool isAssumedHeapToShared(CallBase &CB) const override {
3471 return isValidState() && MallocCalls.count(&CB);
3472 }
3473
3474 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3475 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3476 }
3477
3478 ChangeStatus manifest(Attributor &A) override {
3479 if (MallocCalls.empty())
3480 return ChangeStatus::UNCHANGED;
3481
3482 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3483 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3484
3485 Function *F = getAnchorScope();
3486 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3487 DepClassTy::OPTIONAL);
3488
3489 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3490 for (CallBase *CB : MallocCalls) {
3491 // Skip replacing this if HeapToStack has already claimed it.
3492 if (HS && HS->isAssumedHeapToStack(*CB))
3493 continue;
3494
3495 // Find the unique free call to remove it.
3496 SmallVector<CallBase *, 4> FreeCalls;
3497 for (auto *U : CB->users()) {
3498 CallBase *C = dyn_cast<CallBase>(U);
3499 if (C && C->getCalledFunction() == FreeCall.Declaration)
3500 FreeCalls.push_back(C);
3501 }
3502 if (FreeCalls.size() != 1)
3503 continue;
3504
3505 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3506
3507 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3508 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3509 << " with shared memory."
3510 << " Shared memory usage is limited to "
3511 << SharedMemoryLimit << " bytes\n");
3512 continue;
3513 }
3514
3515 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3516 << " with " << AllocSize->getZExtValue()
3517 << " bytes of shared memory\n");
3518
3519 // Create a new shared memory buffer of the same size as the allocation
3520 // and replace all the uses of the original allocation with it.
3521 Module *M = CB->getModule();
3522 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3523 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3524 auto *SharedMem = new GlobalVariable(
3525 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3526 PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3527 GlobalValue::NotThreadLocal,
3528 static_cast<unsigned>(AddressSpace::Shared));
3529 auto *NewBuffer = ConstantExpr::getPointerCast(
3530 SharedMem, PointerType::getUnqual(M->getContext()));
3531
3532 auto Remark = [&](OptimizationRemark OR) {
3533 return OR << "Replaced globalized variable with "
3534 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3535 << (AllocSize->isOne() ? " byte " : " bytes ")
3536 << "of shared memory.";
3537 };
3538 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3539
3540 MaybeAlign Alignment = CB->getRetAlign();
3541 assert(Alignment &&
3542 "HeapToShared on allocation without alignment attribute");
3543 SharedMem->setAlignment(*Alignment);
3544
3545 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3546 A.deleteAfterManifest(*CB);
3547 A.deleteAfterManifest(*FreeCalls.front());
3548
3549 SharedMemoryUsed += AllocSize->getZExtValue();
3550 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3551 Changed = ChangeStatus::CHANGED;
3552 }
3553
3554 return Changed;
3555 }
3556
3557 ChangeStatus updateImpl(Attributor &A) override {
3558 if (MallocCalls.empty())
3559 return indicatePessimisticFixpoint();
3560 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3561 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3562 if (!RFI.Declaration)
3563 return ChangeStatus::UNCHANGED;
3564
3565 Function *F = getAnchorScope();
3566
3567 auto NumMallocCalls = MallocCalls.size();
3568
3569 // Only consider malloc calls executed by a single thread with a constant size.
3570 for (User *U : RFI.Declaration->users()) {
3571 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3572 if (CB->getCaller() != F)
3573 continue;
3574 if (!MallocCalls.count(CB))
3575 continue;
3576 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3577 MallocCalls.remove(CB);
3578 continue;
3579 }
3580 const auto *ED = A.getAAFor<AAExecutionDomain>(
3581 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3582 if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))
3583 MallocCalls.remove(CB);
3584 }
3585 }
3586
3587 findPotentialRemovedFreeCalls(A);
3588
3589 if (NumMallocCalls != MallocCalls.size())
3590 return ChangeStatus::CHANGED;
3591
3592 return ChangeStatus::UNCHANGED;
3593 }
3594
3595 /// Collection of all malloc calls in a function.
3596 SmallSetVector<CallBase *, 4> MallocCalls;
3597 /// Collection of potentially removed free calls in a function.
3598 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3599 /// The total amount of shared memory that has been used for HeapToShared.
3600 unsigned SharedMemoryUsed = 0;
3601};
3602
3603struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3604 using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3605 AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3606
3607 /// The callee value is tracked beyond a simple stripPointerCasts, so we allow
3608 /// unknown callees.
3609 static bool requiresCalleeForCallBase() { return false; }
3610
3611 /// Statistics are tracked as part of manifest for now.
3612 void trackStatistics() const override {}
3613
3614 /// See AbstractAttribute::getAsStr()
3615 const std::string getAsStr(Attributor *) const override {
3616 if (!isValidState())
3617 return "<invalid>";
3618 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3619 : "generic") +
3620 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3621 : "") +
3622 std::string(" #PRs: ") +
3623 (ReachedKnownParallelRegions.isValidState()
3624 ? std::to_string(ReachedKnownParallelRegions.size())
3625 : "<invalid>") +
3626 ", #Unknown PRs: " +
3627 (ReachedUnknownParallelRegions.isValidState()
3628 ? std::to_string(ReachedUnknownParallelRegions.size())
3629 : "<invalid>") +
3630 ", #Reaching Kernels: " +
3631 (ReachingKernelEntries.isValidState()
3632 ? std::to_string(ReachingKernelEntries.size())
3633 : "<invalid>") +
3634 ", #ParLevels: " +
3635 (ParallelLevels.isValidState()
3636 ? std::to_string(ParallelLevels.size())
3637 : "<invalid>") +
3638 ", NestedPar: " + (NestedParallelism ? "yes" : "no");
3639 }
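  // For example, a function in a valid, fixed generic state with two known
  // parallel regions might print (illustrative output only):
  //   "generic [FIX] #PRs: 2, #Unknown PRs: 0, #Reaching Kernels: 1, #ParLevels: 1, NestedPar: no"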
3640
3641 /// Create an abstract attribute view for the position \p IRP.
3642 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3643
3644 /// See AbstractAttribute::getName()
3645 StringRef getName() const override { return "AAKernelInfo"; }
3646
3647 /// See AbstractAttribute::getIdAddr()
3648 const char *getIdAddr() const override { return &ID; }
3649
3650 /// This function should return true if the type of the \p AA is AAKernelInfo
3651 static bool classof(const AbstractAttribute *AA) {
3652 return (AA->getIdAddr() == &ID);
3653 }
3654
3655 static const char ID;
3656};
3657
3658/// The function kernel info abstract attribute, basically, what can we say
3659/// about a function with regards to the KernelInfoState.
3660struct AAKernelInfoFunction : AAKernelInfo {
3661 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3662 : AAKernelInfo(IRP, A) {}
3663
3664 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3665
3666 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3667 return GuardedInstructions;
3668 }
3669
3670 void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {
3671 Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(
3672 KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});
3673 assert(NewKernelEnvC && "Failed to create new kernel environment");
3674 KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);
3675 }
3676
3677#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \
3678 void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \
3679 ConstantStruct *ConfigC = \
3680 KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \
3681 Constant *NewConfigC = ConstantFoldInsertValueInstruction( \
3682 ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \
3683 assert(NewConfigC && "Failed to create new configuration environment"); \
3684 setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \
3685 }
3686
3687 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)
3688 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)
3689 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
3690 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)
3691 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)
3692 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)
3693 KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)
3694
3695#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER
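  // For illustration, KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)
  // expands (modulo the assert) to roughly:
  //
  //   void setExecModeOfKernelEnvironment(ConstantInt *NewVal) {
  //     ConstantStruct *ConfigC =
  //         KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC);
  //     Constant *NewConfigC = ConstantFoldInsertValueInstruction(
  //         ConfigC, NewVal, {KernelInfo::ExecModeIdx});
  //     setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC));
  //   }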
3696
3697 /// See AbstractAttribute::initialize(...).
3698 void initialize(Attributor &A) override {
3699 // This is a high-level transform that might change the constant arguments
3700 // of the init and deinit calls. We need to tell the Attributor about this
3701 // to avoid other parts using the current constant value for simplification.
3702 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3703
3704 Function *Fn = getAnchorScope();
3705
3706 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3707 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3708 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3709 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3710
3711 // For kernels we perform more initialization work, first we find the init
3712 // and deinit calls.
3713 auto StoreCallBase = [](Use &U,
3714 OMPInformationCache::RuntimeFunctionInfo &RFI,
3715 CallBase *&Storage) {
3716 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3717 assert(CB &&
3718 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3719 assert(!Storage &&
3720 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3721 Storage = CB;
3722 return false;
3723 };
3724 InitRFI.foreachUse(
3725 [&](Use &U, Function &) {
3726 StoreCallBase(U, InitRFI, KernelInitCB);
3727 return false;
3728 },
3729 Fn);
3730 DeinitRFI.foreachUse(
3731 [&](Use &U, Function &) {
3732 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3733 return false;
3734 },
3735 Fn);
3736
3737 // Ignore kernels without initializers such as global constructors.
3738 if (!KernelInitCB || !KernelDeinitCB)
3739 return;
3740
3741 // Add itself to the reaching kernel and set IsKernelEntry.
3742 ReachingKernelEntries.insert(Fn);
3743 IsKernelEntry = true;
3744
3745 KernelEnvC =
3746 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3747 GlobalVariable *KernelEnvGV =
3748 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3749
3750 Attributor::GlobalVariableSimplifictionCallbackTy
3751 KernelConfigurationSimplifyCB =
3752 [&](const GlobalVariable &GV, const AbstractAttribute *AA,
3753 bool &UsedAssumedInformation) -> std::optional<Constant *> {
3754 if (!isAtFixpoint()) {
3755 if (!AA)
3756 return nullptr;
3757 UsedAssumedInformation = true;
3758 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3759 }
3760 return KernelEnvC;
3761 };
3762
3763 A.registerGlobalVariableSimplificationCallback(
3764 *KernelEnvGV, KernelConfigurationSimplifyCB);
3765
3766 // We cannot change to SPMD mode if the runtime functions aren't available.
3767 bool CanChangeToSPMD = OMPInfoCache.runtimeFnsAvailable(
3768 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3769 OMPRTL___kmpc_barrier_simple_spmd});
3770
3771 // Check if we know we are in SPMD-mode already.
3772 ConstantInt *ExecModeC =
3773 KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);
3774 ConstantInt *AssumedExecModeC = ConstantInt::get(
3775 ExecModeC->getIntegerType(),
3776 ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);
3777 if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
3778 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3779 else if (DisableOpenMPOptSPMDization || !CanChangeToSPMD)
3780 // This is a generic region but SPMDization is disabled so stop
3781 // tracking.
3782 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3783 else
3784 setExecModeOfKernelEnvironment(AssumedExecModeC);
3785
3786 const Triple T(Fn->getParent()->getTargetTriple());
3787 auto *Int32Ty = Type::getInt32Ty(Fn->getContext());
3788 auto [MinThreads, MaxThreads] =
3789 OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);
3790 if (MinThreads)
3791 setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));
3792 if (MaxThreads)
3793 setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));
3794 auto [MinTeams, MaxTeams] =
3795 OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);
3796 if (MinTeams)
3797 setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));
3798 if (MaxTeams)
3799 setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));
3800
3801 ConstantInt *MayUseNestedParallelismC =
3802 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);
3803 ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(
3804 MayUseNestedParallelismC->getIntegerType(), NestedParallelism);
3805 setMayUseNestedParallelismOfKernelEnvironment(
3806 AssumedMayUseNestedParallelismC);
3807
3808 if (!DisableOpenMPOptStateMachineRewrite) {
3809 ConstantInt *UseGenericStateMachineC =
3810 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3811 KernelEnvC);
3812 ConstantInt *AssumedUseGenericStateMachineC =
3813 ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);
3814 setUseGenericStateMachineOfKernelEnvironment(
3815 AssumedUseGenericStateMachineC);
3816 }
3817
3818 // Register virtual uses of functions we might need to preserve.
3819 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3820 Attributor::VirtualUseCallbackTy &CB) {
3821 if (!OMPInfoCache.RFIs[RFKind].Declaration)
3822 return;
3823 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3824 };
3825
3826 // Add a dependence to ensure updates if the state changes.
3827 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3828 const AbstractAttribute *QueryingAA) {
3829 if (QueryingAA) {
3830 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3831 }
3832 return true;
3833 };
3834
3835 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3836 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3837 // Whenever we create a custom state machine we will insert calls to
3838 // __kmpc_get_hardware_num_threads_in_block,
3839 // __kmpc_get_warp_size,
3840 // __kmpc_barrier_simple_generic,
3841 // __kmpc_kernel_parallel, and
3842 // __kmpc_kernel_end_parallel.
3843 // Not needed if we are on track for SPMDzation.
3844 if (SPMDCompatibilityTracker.isValidState())
3845 return AddDependence(A, this, QueryingAA);
3846 // Not needed if we can't rewrite due to an invalid state.
3847 if (!ReachedKnownParallelRegions.isValidState())
3848 return AddDependence(A, this, QueryingAA);
3849 return false;
3850 };
3851
3852 // Not needed if we are pre-runtime merge.
3853 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3854 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3855 CustomStateMachineUseCB);
3856 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3857 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3858 CustomStateMachineUseCB);
3859 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3860 CustomStateMachineUseCB);
3861 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3862 CustomStateMachineUseCB);
3863 }
3864
3865 // If we do not perform SPMDzation we do not need the virtual uses below.
3866 if (SPMDCompatibilityTracker.isAtFixpoint())
3867 return;
3868
3869 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3870 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3871 // Whenever we perform SPMDzation we will insert
3872 // __kmpc_get_hardware_thread_id_in_block calls.
3873 if (!SPMDCompatibilityTracker.isValidState())
3874 return AddDependence(A, this, QueryingAA);
3875 return false;
3876 };
3877 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3878 HWThreadIdUseCB);
3879
3880 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3881 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3882 // Whenever we perform SPMDzation with guarding we will insert
3883 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3884 // nothing to guard, or there are no parallel regions, we don't need
3885 // the calls.
3886 if (!SPMDCompatibilityTracker.isValidState())
3887 return AddDependence(A, this, QueryingAA);
3888 if (SPMDCompatibilityTracker.empty())
3889 return AddDependence(A, this, QueryingAA);
3890 if (!mayContainParallelRegion())
3891 return AddDependence(A, this, QueryingAA);
3892 return false;
3893 };
3894 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3895 }
3896
3897 /// Sanitize the string \p S such that it is a suitable global symbol name.
3898 static std::string sanitizeForGlobalName(std::string S) {
3899 std::replace_if(
3900 S.begin(), S.end(),
3901 [](const char C) {
3902 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3903 (C >= '0' && C <= '9') || C == '_');
3904 },
3905 '.');
3906 return S;
3907 }
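  // For example, sanitizeForGlobalName("x.y[3]") yields "x.y.3." because every
  // character outside [A-Za-z0-9_] is replaced with '.'.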
3908
3909 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3910 /// finished now.
3911 ChangeStatus manifest(Attributor &A) override {
3912 // If we are not looking at a kernel with __kmpc_target_init and
3913 // __kmpc_target_deinit calls we cannot actually manifest the information.
3914 if (!KernelInitCB || !KernelDeinitCB)
3915 return ChangeStatus::UNCHANGED;
3916
3917 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3918
3919 bool HasBuiltStateMachine = true;
3920 if (!changeToSPMDMode(A, Changed)) {
3921 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3922 HasBuiltStateMachine = buildCustomStateMachine(A, Changed);
3923 else
3924 HasBuiltStateMachine = false;
3925 }
3926
3927 // We need to reset KernelEnvC if specific rewriting is not done.
3928 ConstantStruct *ExistingKernelEnvC =
3929 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
3930 ConstantInt *OldUseGenericStateMachineVal =
3931 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
3932 ExistingKernelEnvC);
3933 if (!HasBuiltStateMachine)
3934 setUseGenericStateMachineOfKernelEnvironment(
3935 OldUseGenericStateMachineVal);
3936
3937 // Finally, update the KernelEnvC.
3938 GlobalVariable *KernelEnvGV =
3939 KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);
3940 if (KernelEnvGV->getInitializer() != KernelEnvC) {
3941 KernelEnvGV->setInitializer(KernelEnvC);
3942 Changed = ChangeStatus::CHANGED;
3943 }
3944
3945 return Changed;
3946 }
3947
3948 void insertInstructionGuardsHelper(Attributor &A) {
3949 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3950
3951 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3952 Instruction *RegionEndI) {
3953 LoopInfo *LI = nullptr;
3954 DominatorTree *DT = nullptr;
3955 MemorySSAUpdater *MSU = nullptr;
3956 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3957
3958 BasicBlock *ParentBB = RegionStartI->getParent();
3959 Function *Fn = ParentBB->getParent();
3960 Module &M = *Fn->getParent();
3961
3962 // Create all the blocks and logic.
3963 // ParentBB:
3964 // goto RegionCheckTidBB
3965 // RegionCheckTidBB:
3966 // Tid = __kmpc_hardware_thread_id()
3967 // if (Tid != 0)
3968 // goto RegionBarrierBB
3969 // RegionStartBB:
3970 // <execute instructions guarded>
3971 // goto RegionEndBB
3972 // RegionEndBB:
3973 // <store escaping values to shared mem>
3974 // goto RegionBarrierBB
3975 // RegionBarrierBB:
3976 // __kmpc_simple_barrier_spmd()
3977 // // second barrier is omitted if lacking escaping values.
3978 // <load escaping values from shared mem>
3979 // __kmpc_simple_barrier_spmd()
3980 // goto RegionExitBB
3981 // RegionExitBB:
3982 // <execute rest of instructions>
3983
3984 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3985 DT, LI, MSU, "region.guarded.end");
3986 BasicBlock *RegionBarrierBB =
3987 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3988 MSU, "region.barrier");
3989 BasicBlock *RegionExitBB =
3990 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3991 DT, LI, MSU, "region.exit");
3992 BasicBlock *RegionStartBB =
3993 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3994
3995 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3996 "Expected a different CFG");
3997
3998 BasicBlock *RegionCheckTidBB = SplitBlock(
3999 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
4000
4001 // Register basic blocks with the Attributor.
4002 A.registerManifestAddedBasicBlock(*RegionEndBB);
4003 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
4004 A.registerManifestAddedBasicBlock(*RegionExitBB);
4005 A.registerManifestAddedBasicBlock(*RegionStartBB);
4006 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
4007
4008 bool HasBroadcastValues = false;
4009 // Find escaping outputs from the guarded region to outside users and
4010 // broadcast their values to them.
4011 for (Instruction &I : *RegionStartBB) {
4012 SmallVector<Use *, 4> OutsideUses;
4013 for (Use &U : I.uses()) {
4014 Instruction &UsrI = *cast<Instruction>(U.getUser());
4015 if (UsrI.getParent() != RegionStartBB)
4016 OutsideUses.push_back(&U);
4017 }
4018
4019 if (OutsideUses.empty())
4020 continue;
4021
4022 HasBroadcastValues = true;
4023
4024 // Emit a global variable in shared memory to store the broadcasted
4025 // value.
4026 auto *SharedMem = new GlobalVariable(
4027 M, I.getType(), /* IsConstant */ false,
4028 GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
4029 sanitizeForGlobalName(
4030 (I.getName() + ".guarded.output.alloc").str()),
4031 nullptr, GlobalValue::NotThreadLocal,
4032 static_cast<unsigned>(AddressSpace::Shared));
4033
4034 // Emit a store instruction to update the value.
4035 new StoreInst(&I, SharedMem,
4036 RegionEndBB->getTerminator()->getIterator());
4037
4038 LoadInst *LoadI = new LoadInst(
4039 I.getType(), SharedMem, I.getName() + ".guarded.output.load",
4040 RegionBarrierBB->getTerminator()->getIterator());
4041
4042 // Emit a load instruction and replace uses of the output value.
4043 for (Use *U : OutsideUses)
4044 A.changeUseAfterManifest(*U, *LoadI);
4045 }
4046
4047 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4048
4049 // Go to tid check BB in ParentBB.
4050 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
4051 ParentBB->getTerminator()->eraseFromParent();
4052 OpenMPIRBuilder::LocationDescription Loc(
4053 InsertPointTy(ParentBB, ParentBB->end()), DL);
4054 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
4055 uint32_t SrcLocStrSize;
4056 auto *SrcLocStr =
4057 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
4058 Value *Ident =
4059 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
4060 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
4061
4062 // Add check for Tid in RegionCheckTidBB
4063 RegionCheckTidBB->getTerminator()->eraseFromParent();
4064 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
4065 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
4066 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
4067 FunctionCallee HardwareTidFn =
4068 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4069 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4070 CallInst *Tid =
4071 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
4072 Tid->setDebugLoc(DL);
4073 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
4074 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
4075 OMPInfoCache.OMPBuilder.Builder
4076 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
4077 ->setDebugLoc(DL);
4078
4079 // First barrier for synchronization, ensures main thread has updated
4080 // values.
4081 FunctionCallee BarrierFn =
4082 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4083 M, OMPRTL___kmpc_barrier_simple_spmd);
4084 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
4085 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
4086 CallInst *Barrier =
4087 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
4088 Barrier->setDebugLoc(DL);
4089 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4090
4091 // Second barrier ensures workers have read broadcast values.
4092 if (HasBroadcastValues) {
4093 CallInst *Barrier =
4094 CallInst::Create(BarrierFn, {Ident, Tid}, "",
4095 RegionBarrierBB->getTerminator()->getIterator());
4096 Barrier->setDebugLoc(DL);
4097 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4098 }
4099 };
4100
4101 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4102 SmallPtrSet<BasicBlock *, 8> Visited;
4103 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4104 BasicBlock *BB = GuardedI->getParent();
4105 if (!Visited.insert(BB).second)
4106 continue;
4107
4108 SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
4109 Instruction *LastEffect = nullptr;
4110 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
4111 while (++IP != IPEnd) {
4112 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
4113 continue;
4114 Instruction *I = &*IP;
4115 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
4116 continue;
4117 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
4118 LastEffect = nullptr;
4119 continue;
4120 }
4121 if (LastEffect)
4122 Reorders.push_back({I, LastEffect});
4123 LastEffect = &*IP;
4124 }
4125 for (auto &Reorder : Reorders)
4126 Reorder.first->moveBefore(Reorder.second->getIterator());
4127 }
4128
4129 SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
4130
4131 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
4132 BasicBlock *BB = GuardedI->getParent();
4133 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
4134 IRPosition::function(*GuardedI->getFunction()), nullptr,
4135 DepClassTy::NONE);
4136 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
4137 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
4138 // Continue if instruction is already guarded.
4139 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
4140 continue;
4141
4142 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
4143 for (Instruction &I : *BB) {
4144 // If instruction I needs to be guarded update the guarded region
4145 // bounds.
4146 if (SPMDCompatibilityTracker.contains(&I)) {
4147 CalleeAAFunction.getGuardedInstructions().insert(&I);
4148 if (GuardedRegionStart)
4149 GuardedRegionEnd = &I;
4150 else
4151 GuardedRegionStart = GuardedRegionEnd = &I;
4152
4153 continue;
4154 }
4155
4156 // Instruction I does not need guarding, store
4157 // any region found and reset bounds.
4158 if (GuardedRegionStart) {
4159 GuardedRegions.push_back(
4160 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
4161 GuardedRegionStart = nullptr;
4162 GuardedRegionEnd = nullptr;
4163 }
4164 }
4165 }
4166
4167 for (auto &GR : GuardedRegions)
4168 CreateGuardedRegion(GR.first, GR.second);
4169 }
4170
4171 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
4172 // Only allow 1 thread per workgroup to continue executing the user code.
4173 //
4174 // InitCB = __kmpc_target_init(...)
4175 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
4176 // if (ThreadIdInBlock != 0) return;
4177 // UserCode:
4178 // // user code
4179 //
4180 auto &Ctx = getAnchorValue().getContext();
4181 Function *Kernel = getAssociatedFunction();
4182 assert(Kernel && "Expected an associated function!");
4183
4184 // Create block for user code to branch to from initial block.
4185 BasicBlock *InitBB = KernelInitCB->getParent();
4186 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
4187 KernelInitCB->getNextNode(), "main.thread.user_code");
4188 BasicBlock *ReturnBB =
4189 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
4190
4191 // Register blocks with attributor:
4192 A.registerManifestAddedBasicBlock(*InitBB);
4193 A.registerManifestAddedBasicBlock(*UserCodeBB);
4194 A.registerManifestAddedBasicBlock(*ReturnBB);
4195
4196 // Debug location:
4197 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4198 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
4199 InitBB->getTerminator()->eraseFromParent();
4200
4201 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
4202 Module &M = *Kernel->getParent();
4203 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4204 FunctionCallee ThreadIdInBlockFn =
4205 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4206 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
4207
4208 // Get thread ID in block.
4209 CallInst *ThreadIdInBlock =
4210 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
4211 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
4212 ThreadIdInBlock->setDebugLoc(DLoc);
4213
4214 // Eliminate all threads in the block with ID not equal to 0:
4215 Instruction *IsMainThread =
4216 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
4217 ConstantInt::get(ThreadIdInBlock->getType(), 0),
4218 "thread.is_main", InitBB);
4219 IsMainThread->setDebugLoc(DLoc);
4220 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
4221 }
4222
4223 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
4224 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4225
4226 if (!SPMDCompatibilityTracker.isAssumed()) {
4227 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
4228 if (!NonCompatibleI)
4229 continue;
4230
4231 // Skip diagnostics on calls to known OpenMP runtime functions for now.
4232 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
4233 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
4234 continue;
4235
4236 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4237 ORA << "Value has potential side effects preventing SPMD-mode "
4238 "execution";
4239 if (isa<CallBase>(NonCompatibleI)) {
4240 ORA << ". Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "
4241 "the called function to override";
4242 }
4243 return ORA << ".";
4244 };
4245 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
4246 Remark);
4247
4248 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
4249 << *NonCompatibleI << "\n");
4250 }
4251
4252 return false;
4253 }
4254
4255 // Get the actual kernel; it could be the caller of the anchor scope if we
4256 // have a debug wrapper.
4257 Function *Kernel = getAnchorScope();
4258 if (Kernel->hasLocalLinkage()) {
4259 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
4260 auto *CB = cast<CallBase>(Kernel->user_back());
4261 Kernel = CB->getCaller();
4262 }
4263 assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");
4264
4265 // Check if the kernel is already in SPMD mode, if so, return success.
4266 ConstantStruct *ExistingKernelEnvC =
4267 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4268 auto *ExecModeC =
4269 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4270 const int8_t ExecModeVal = ExecModeC->getSExtValue();
4271 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4272 return true;
4273
4274 // We will now unconditionally modify the IR, indicate a change.
4275 Changed = ChangeStatus::CHANGED;
4276
4277 // Do not use instruction guards when no parallel region is present inside
4278 // the target region.
4279 if (mayContainParallelRegion())
4280 insertInstructionGuardsHelper(A);
4281 else
4282 forceSingleThreadPerWorkgroupHelper(A);
4283
4284 // Adjust the global exec mode flag that tells the runtime what mode this
4285 // kernel is executed in.
4286 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4287 "Initially non-SPMD kernel has SPMD exec mode!");
4288 setExecModeOfKernelEnvironment(
4289 ConstantInt::get(ExecModeC->getIntegerType(),
4290 ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
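    // For illustration (assuming the usual encoding GENERIC = 1, SPMD = 2,
    // GENERIC_SPMD = 3): a generic kernel enters with ExecModeVal == 1 and
    // leaves with 1 | OMP_TGT_EXEC_MODE_GENERIC_SPMD == 3, telling the device
    // runtime it is a generic kernel now executed in SPMD fashion.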
4291
4292 ++NumOpenMPTargetRegionKernelsSPMD;
4293
4294 auto Remark = [&](OptimizationRemark OR) {
4295 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4296 };
4297 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4298 return true;
4299 };
4300
4301 bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {
4302 // If we have disabled state machine rewrites, don't make a custom one.
4303 if (DisableOpenMPOptStateMachineRewrite)
4304 return false;
4305
4306 // Don't rewrite the state machine if we are not in a valid state.
4307 if (!ReachedKnownParallelRegions.isValidState())
4308 return false;
4309
4310 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4311 if (!OMPInfoCache.runtimeFnsAvailable(
4312 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4313 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4314 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4315 return false;
4316
4317 ConstantStruct *ExistingKernelEnvC =
4318 KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);
4319
4320 // Check if the current configuration is non-SPMD and generic state machine.
4321 // If we already have SPMD mode or a custom state machine we do not need to
4322 // go any further. If it is anything but a constant something is weird and
4323 // we give up.
4324 ConstantInt *UseStateMachineC =
4325 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4326 ExistingKernelEnvC);
4327 ConstantInt *ModeC =
4328 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);
4329
4330 // If we are stuck with generic mode, try to create a custom device (=GPU)
4331 // state machine which is specialized for the parallel regions that are
4332 // reachable by the kernel.
4333 if (UseStateMachineC->isZero() ||
4334 ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)
4335 return false;
4336
4337 Changed = ChangeStatus::CHANGED;
4338
4339 // If not SPMD mode, indicate we use a custom state machine now.
4340 setUseGenericStateMachineOfKernelEnvironment(
4341 ConstantInt::get(UseStateMachineC->getIntegerType(), false));
4342
4343 // If we don't actually need a state machine we are done here. This can
4344 // happen if there simply are no parallel regions. In the resulting kernel
4345 // all worker threads will simply exit right away, leaving the main thread
4346 // to do the work alone.
4347 if (!mayContainParallelRegion()) {
4348 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4349
4350 auto Remark = [&](OptimizationRemark OR) {
4351 return OR << "Removing unused state machine from generic-mode kernel.";
4352 };
4353 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4354
4355 return true;
4356 }
4357
4358 // Keep track in the statistics of our new shiny custom state machine.
4359 if (ReachedUnknownParallelRegions.empty()) {
4360 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4361
4362 auto Remark = [&](OptimizationRemark OR) {
4363 return OR << "Rewriting generic-mode kernel with a customized state "
4364 "machine.";
4365 };
4366 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4367 } else {
4368 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4369
4370 auto Remark = [&](OptimizationRemarkAnalysis OR) {
4371 return OR << "Generic-mode kernel is executed with a customized state "
4372 "machine that requires a fallback.";
4373 };
4374 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4375
4376 // Tell the user why we ended up with a fallback.
4377 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4378 if (!UnknownParallelRegionCB)
4379 continue;
4380 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4381 return ORA << "Call may contain unknown parallel regions. Use "
4382 << "`[[omp::assume(\"omp_no_parallelism\")]]` to "
4383 "override.";
4384 };
4385 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4386 "OMP133", Remark);
4387 }
4388 }
4389
4390 // Create all the blocks:
4391 //
4392 // InitCB = __kmpc_target_init(...)
4393 // BlockHwSize =
4394 // __kmpc_get_hardware_num_threads_in_block();
4395 // WarpSize = __kmpc_get_warp_size();
4396 // BlockSize = BlockHwSize - WarpSize;
4397 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4398 // if (IsWorker) {
4399 // if (InitCB >= BlockSize) return;
4400 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4401 // void *WorkFn;
4402 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4403 // if (!WorkFn) return;
4404 // SMIsActiveCheckBB: if (Active) {
4405 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4406 // ParFn0(...);
4407 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4408 // ParFn1(...);
4409 // ...
4410 // SMIfCascadeCurrentBB: else
4411 // ((WorkFnTy*)WorkFn)(...);
4412 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4413 // }
4414 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4415 // goto SMBeginBB;
4416 // }
4417 // UserCodeEntryBB: // user code
4418 // __kmpc_target_deinit(...)
4419 //
4420 auto &Ctx = getAnchorValue().getContext();
4421 Function *Kernel = getAssociatedFunction();
4422 assert(Kernel && "Expected an associated function!");
4423
4424 BasicBlock *InitBB = KernelInitCB->getParent();
4425 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4426 KernelInitCB->getNextNode(), "thread.user_code.check");
4427 BasicBlock *IsWorkerCheckBB =
4428 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4429 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4430 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4431 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4432 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4433 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4434 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4435 BasicBlock *StateMachineIfCascadeCurrentBB =
4436 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4437 Kernel, UserCodeEntryBB);
4438 BasicBlock *StateMachineEndParallelBB =
4439 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4440 Kernel, UserCodeEntryBB);
4441 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4442 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4443 A.registerManifestAddedBasicBlock(*InitBB);
4444 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4445 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4446 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4447 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4448 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4449 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4450 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4451 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4452
4453 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4454 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4455 InitBB->getTerminator()->eraseFromParent();
4456
4457 Instruction *IsWorker =
4458 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4459 ConstantInt::getAllOnesValue(KernelInitCB->getType()),
4460 "thread.is_worker", InitBB);
4461 IsWorker->setDebugLoc(DLoc);
4462 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4463
4464 Module &M = *Kernel->getParent();
4465 FunctionCallee BlockHwSizeFn =
4466 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4467 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4468 FunctionCallee WarpSizeFn =
4469 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4470 M, OMPRTL___kmpc_get_warp_size);
4471 CallInst *BlockHwSize =
4472 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4473 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4474 BlockHwSize->setDebugLoc(DLoc);
4475 CallInst *WarpSize =
4476 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4477 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4478 WarpSize->setDebugLoc(DLoc);
4479 Instruction *BlockSize = BinaryOperator::CreateSub(
4480 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4481 BlockSize->setDebugLoc(DLoc);
4482 Instruction *IsMainOrWorker = ICmpInst::Create(
4483 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4484 "thread.is_main_or_worker", IsWorkerCheckBB);
4485 IsMainOrWorker->setDebugLoc(DLoc);
4486 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4487 IsMainOrWorker, IsWorkerCheckBB);
4488
4489 // Create local storage for the work function pointer.
4490 const DataLayout &DL = M.getDataLayout();
4491 Type *VoidPtrTy = PointerType::getUnqual(Ctx);
4492 Instruction *WorkFnAI =
4493 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4494 "worker.work_fn.addr", Kernel->getEntryBlock().begin());
4495 WorkFnAI->setDebugLoc(DLoc);
4496
4497 OMPInfoCache.OMPBuilder.updateToLocation(
4498 OpenMPIRBuilder::LocationDescription(
4499 IRBuilder<>::InsertPoint(StateMachineBeginBB,
4500 StateMachineBeginBB->end()),
4501 DLoc));
4502
4503 Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);
4504 Value *GTid = KernelInitCB;
4505
4506 FunctionCallee BarrierFn =
4507 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4508 M, OMPRTL___kmpc_barrier_simple_generic);
4509 CallInst *Barrier =
4510 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4511 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4512 Barrier->setDebugLoc(DLoc);
4513
4514 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4515 (unsigned int)AddressSpace::Generic) {
4516 WorkFnAI = new AddrSpaceCastInst(
4517 WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),
4518 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4519 WorkFnAI->setDebugLoc(DLoc);
4520 }
4521
4522 FunctionCallee KernelParallelFn =
4523 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4524 M, OMPRTL___kmpc_kernel_parallel);
4525 CallInst *IsActiveWorker = CallInst::Create(
4526 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4527 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4528 IsActiveWorker->setDebugLoc(DLoc);
4529 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4530 StateMachineBeginBB);
4531 WorkFn->setDebugLoc(DLoc);
4532
4533 FunctionType *ParallelRegionFnTy = FunctionType::get(
4534 Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4535 false);
4536
4537 Instruction *IsDone =
4538 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4539 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4540 StateMachineBeginBB);
4541 IsDone->setDebugLoc(DLoc);
4542 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4543 IsDone, StateMachineBeginBB)
4544 ->setDebugLoc(DLoc);
4545
4546 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4547 StateMachineDoneBarrierBB, IsActiveWorker,
4548 StateMachineIsActiveCheckBB)
4549 ->setDebugLoc(DLoc);
4550
4551 Value *ZeroArg =
4552 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4553
4554 const unsigned int WrapperFunctionArgNo = 6;
4555
4556 // Now that we have most of the CFG skeleton it is time for the if-cascade
4557 // that checks the function pointer we got from the runtime against the
4558 // parallel regions we expect, if there are any.
4559 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4560 auto *CB = ReachedKnownParallelRegions[I];
4561 auto *ParallelRegion = dyn_cast<Function>(
4562 CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());
4563 BasicBlock *PRExecuteBB = BasicBlock::Create(
4564 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4565 StateMachineEndParallelBB);
4566 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4567 ->setDebugLoc(DLoc);
4568 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4569 ->setDebugLoc(DLoc);
4570
4571 BasicBlock *PRNextBB =
4572 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4573 Kernel, StateMachineEndParallelBB);
4574 A.registerManifestAddedBasicBlock(*PRExecuteBB);
4575 A.registerManifestAddedBasicBlock(*PRNextBB);
4576
4577 // Check if we need to compare the pointer at all or if we can just
4578 // call the parallel region function.
4579 Value *IsPR;
4580 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4581 Instruction *CmpI = ICmpInst::Create(
4582 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,
4583 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4584 CmpI->setDebugLoc(DLoc);
4585 IsPR = CmpI;
4586 } else {
4587 IsPR = ConstantInt::getTrue(Ctx);
4588 }
4589
4590 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4591 StateMachineIfCascadeCurrentBB)
4592 ->setDebugLoc(DLoc);
4593 StateMachineIfCascadeCurrentBB = PRNextBB;
4594 }
4595
4596 // At the end of the if-cascade we place the indirect function pointer call
4597 // in case we might need it, that is if there can be parallel regions we
4598 // have not handled in the if-cascade above.
4599 if (!ReachedUnknownParallelRegions.empty()) {
4600 StateMachineIfCascadeCurrentBB->setName(
4601 "worker_state_machine.parallel_region.fallback.execute");
4602 CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",
4603 StateMachineIfCascadeCurrentBB)
4604 ->setDebugLoc(DLoc);
4605 }
4606 BranchInst::Create(StateMachineEndParallelBB,
4607 StateMachineIfCascadeCurrentBB)
4608 ->setDebugLoc(DLoc);
4609
4610 FunctionCallee EndParallelFn =
4611 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4612 M, OMPRTL___kmpc_kernel_end_parallel);
4613 CallInst *EndParallel =
4614 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4615 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4616 EndParallel->setDebugLoc(DLoc);
4617 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4618 ->setDebugLoc(DLoc);
4619
4620 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4621 ->setDebugLoc(DLoc);
4622 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4623 ->setDebugLoc(DLoc);
4624
4625 return true;
4626 }
4627
4628 /// Fixpoint iteration update function. Will be called every time a dependence
4629 /// changed its state (and in the beginning).
4630 ChangeStatus updateImpl(Attributor &A) override {
4631 KernelInfoState StateBefore = getState();
4632
4633 // When we leave this function this RAII will make sure the member
4634 // KernelEnvC is updated properly depending on the state. That member is
4635 // used for simplification of values and needs to be up to date at all
4636 // times.
4637 struct UpdateKernelEnvCRAII {
4638 AAKernelInfoFunction &AA;
4639
4640 UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}
4641
4642 ~UpdateKernelEnvCRAII() {
4643 if (!AA.KernelEnvC)
4644 return;
4645
4646 ConstantStruct *ExistingKernelEnvC =
4647 KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);
4648
4649 if (!AA.isValidState()) {
4650 AA.KernelEnvC = ExistingKernelEnvC;
4651 return;
4652 }
4653
4654 if (!AA.ReachedKnownParallelRegions.isValidState())
4655 AA.setUseGenericStateMachineOfKernelEnvironment(
4656 KernelInfo::getUseGenericStateMachineFromKernelEnvironment(
4657 ExistingKernelEnvC));
4658
4659 if (!AA.SPMDCompatibilityTracker.isValidState())
4660 AA.setExecModeOfKernelEnvironment(
4661 KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));
4662
4663 ConstantInt *MayUseNestedParallelismC =
4664 KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(
4665 AA.KernelEnvC);
4666 ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(
4667 MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);
4668 AA.setMayUseNestedParallelismOfKernelEnvironment(
4669 NewMayUseNestedParallelismC);
4670 }
4671 } RAII(*this);
4672
4673 // Callback to check a read/write instruction.
4674 auto CheckRWInst = [&](Instruction &I) {
4675 // We handle calls later.
4676 if (isa<CallBase>(I))
4677 return true;
4678 // We only care about write effects.
4679 if (!I.mayWriteToMemory())
4680 return true;
4681 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4682 const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4683 *this, IRPosition::value(*SI->getPointerOperand()),
4684 DepClassTy::OPTIONAL);
4685 auto *HS = A.getAAFor<AAHeapToStack>(
4686 *this, IRPosition::function(*I.getFunction()),
4687 DepClassTy::OPTIONAL);
4688 if (UnderlyingObjsAA &&
4689 UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {
4690 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4691 return true;
4692 // Check for AAHeapToStack moved objects which must not be
4693 // guarded.
4694 auto *CB = dyn_cast<CallBase>(&Obj);
4695 return CB && HS && HS->isAssumedHeapToStack(*CB);
4696 }))
4697 return true;
4698 }
4699
4700 // Insert instruction that needs guarding.
4701 SPMDCompatibilityTracker.insert(&I);
4702 return true;
4703 };
4704
4705 bool UsedAssumedInformationInCheckRWInst = false;
4706 if (!SPMDCompatibilityTracker.isAtFixpoint())
4707 if (!A.checkForAllReadWriteInstructions(
4708 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4709 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4710
4711 bool UsedAssumedInformationFromReachingKernels = false;
4712 if (!IsKernelEntry) {
4713 updateParallelLevels(A);
4714
4715 bool AllReachingKernelsKnown = true;
4716 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4717 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4718
4719 if (!SPMDCompatibilityTracker.empty()) {
4720 if (!ParallelLevels.isValidState())
4721 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4722 else if (!ReachingKernelEntries.isValidState())
4723 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4724 else {
4725 // Check if all reaching kernels agree on the mode as we can otherwise
4726 // not guard instructions. We might not be sure about the mode so we
4727 // cannot fix the internal spmd-zation state either.
4728 int SPMD = 0, Generic = 0;
4729 for (auto *Kernel : ReachingKernelEntries) {
4730 auto *CBAA = A.getAAFor<AAKernelInfo>(
4731 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4732 if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
4733 CBAA->SPMDCompatibilityTracker.isAssumed())
4734 ++SPMD;
4735 else
4736 ++Generic;
4737 if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
4738 UsedAssumedInformationFromReachingKernels = true;
4739 }
4740 if (SPMD != 0 && Generic != 0)
4741 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4742 }
4743 }
4744 }
4745
4746 // Callback to check a call instruction.
4747 bool AllParallelRegionStatesWereFixed = true;
4748 bool AllSPMDStatesWereFixed = true;
4749 auto CheckCallInst = [&](Instruction &I) {
4750 auto &CB = cast<CallBase>(I);
4751 auto *CBAA = A.getAAFor<AAKernelInfo>(
4752 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4753 if (!CBAA)
4754 return false;
4755 getState() ^= CBAA->getState();
4756 AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
4757 AllParallelRegionStatesWereFixed &=
4758 CBAA->ReachedKnownParallelRegions.isAtFixpoint();
4759 AllParallelRegionStatesWereFixed &=
4760 CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
4761 return true;
4762 };
4763
4764 bool UsedAssumedInformationInCheckCallInst = false;
4765 if (!A.checkForAllCallLikeInstructions(
4766 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4767 LLVM_DEBUG(dbgs() << TAG
4768 << "Failed to visit all call-like instructions!\n";);
4769 return indicatePessimisticFixpoint();
4770 }
4771
4772 // If we haven't used any assumed information for the reached parallel
4773 // region states we can fix it.
4774 if (!UsedAssumedInformationInCheckCallInst &&
4775 AllParallelRegionStatesWereFixed) {
4776 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4777 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4778 }
4779
4780 // If we haven't used any assumed information for the SPMD state we can fix
4781 // it.
4782 if (!UsedAssumedInformationInCheckRWInst &&
4783 !UsedAssumedInformationInCheckCallInst &&
4784 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4785 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4786
4787 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4788 : ChangeStatus::CHANGED;
4789 }
4790
4791private:
4792 /// Update info regarding reaching kernels.
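 /// Walks all call sites of the associated function and accumulates the
 /// callers' ReachingKernelEntries; if any caller is unknown, or not all call
 /// sites are visible, fall back to "any kernel may reach this function".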
4793 void updateReachingKernelEntries(Attributor &A,
4794 bool &AllReachingKernelsKnown) {
4795 auto PredCallSite = [&](AbstractCallSite ACS) {
4796 Function *Caller = ACS.getInstruction()->getFunction();
4797
4798 assert(Caller && "Caller is nullptr");
4799
4800 auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(
4801 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4802 if (CAA && CAA->ReachingKernelEntries.isValidState()) {
4803 ReachingKernelEntries ^= CAA->ReachingKernelEntries;
4804 return true;
4805 }
4806
4807 // We lost track of the callers of the associated function, so any
4808 // kernel could reach it now.
4809 ReachingKernelEntries.indicatePessimisticFixpoint();
4810
4811 return true;
4812 };
4813
4814 if (!A.checkForAllCallSites(PredCallSite, *this,
4815 true /* RequireAllCallSites */,
4816 AllReachingKernelsKnown))
4817 ReachingKernelEntries.indicatePessimisticFixpoint();
4818 }
4819
4820 /// Update info regarding parallel levels.
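 /// Same call-site walk as above, but for the ParallelLevels state; callers
 /// reached through __kmpc_parallel_60 are treated conservatively since the
 /// runtime updates the parallel level there.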
4821 void updateParallelLevels(Attributor &A) {
4822 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4823 OMPInformationCache::RuntimeFunctionInfo &Parallel60RFI =
4824 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_60];
4825
4826 auto PredCallSite = [&](AbstractCallSite ACS) {
4827 Function *Caller = ACS.getInstruction()->getFunction();
4828
4829 assert(Caller && "Caller is nullptr");
4830
4831 auto *CAA =
4832 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4833 if (CAA && CAA->ParallelLevels.isValidState()) {
4834 // Any function that is called by `__kmpc_parallel_60` will not be
4835 // folded as the parallel level in the function is updated. Getting
4836 // this right would tie the analysis to the runtime implementation,
4837 // and any future change to that implementation could silently make
4838 // the analysis wrong. As a consequence, we are just conservative here.
4839 if (Caller == Parallel60RFI.Declaration) {
4840 ParallelLevels.indicatePessimisticFixpoint();
4841 return true;
4842 }
4843
4844 ParallelLevels ^= CAA->ParallelLevels;
4845
4846 return true;
4847 }
4848
4849 // We lost track of the callers of the associated function, so any
4850 // kernel could reach it now.
4851 ParallelLevels.indicatePessimisticFixpoint();
4852
4853 return true;
4854 };
4855
4856 bool AllCallSitesKnown = true;
4857 if (!A.checkForAllCallSites(PredCallSite, *this,
4858 true /* RequireAllCallSites */,
4859 AllCallSitesKnown))
4860 ParallelLevels.indicatePessimisticFixpoint();
4861 }
4862};
4863
4864/// The call site kernel info abstract attribute, basically, what can we say
4865/// about a call site with regards to the KernelInfoState. For now this simply
4866/// forwards the information from the callee.
4867struct AAKernelInfoCallSite : AAKernelInfo {
4868 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4869 : AAKernelInfo(IRP, A) {}
4870
4871 /// See AbstractAttribute::initialize(...).
4872 void initialize(Attributor &A) override {
4873 AAKernelInfo::initialize(A);
4874
4875 CallBase &CB = cast<CallBase>(getAssociatedValue());
4876 auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4877 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4878
4879 // Check for SPMD-mode assumptions.
4880 if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {
4881 indicateOptimisticFixpoint();
4882 return;
4883 }
4884
4885 // First weed out calls we do not care about, that is readonly/readnone
4886 // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4887 // parallel region or anything else we are looking for.
4888 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4889 indicateOptimisticFixpoint();
4890 return;
4891 }
4892
4893 // Next we check if we know the callee. If it is a known OpenMP function
4894 // we will handle them explicitly in the switch below. If it is not, we
4895 // will use an AAKernelInfo object on the callee to gather information and
4896 // merge that into the current state. The latter happens in the updateImpl.
4897 auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {
4898 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4899 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4900 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4901 // Unknown callees or declarations are not analyzable, we give up.
4902 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4903
4904 // Unknown callees might contain parallel regions, except if they have
4905 // an appropriate assumption attached.
4906 if (!AssumptionAA ||
4907 !(AssumptionAA->hasAssumption("omp_no_openmp") ||
4908 AssumptionAA->hasAssumption("omp_no_parallelism")))
4909 ReachedUnknownParallelRegions.insert(&CB);
4910
4911 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4912 // idea we can run something unknown in SPMD-mode.
4913 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4914 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4915 SPMDCompatibilityTracker.insert(&CB);
4916 }
4917
4918 // We have updated the state for this unknown call properly, there
4919 // won't be any change so we indicate a fixpoint.
4920 indicateOptimisticFixpoint();
4921 }
4922 // If the callee is known and can be used in IPO, we will update the
4923 // state based on the callee state in updateImpl.
4924 return;
4925 }
4926 if (NumCallees > 1) {
4927 indicatePessimisticFixpoint();
4928 return;
4929 }
4930
4931 RuntimeFunction RF = It->getSecond();
4932 switch (RF) {
4933 // All the functions we know are compatible with SPMD mode.
4934 case OMPRTL___kmpc_is_spmd_exec_mode:
4935 case OMPRTL___kmpc_distribute_static_fini:
4936 case OMPRTL___kmpc_for_static_fini:
4937 case OMPRTL___kmpc_global_thread_num:
4938 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4939 case OMPRTL___kmpc_get_hardware_num_blocks:
4940 case OMPRTL___kmpc_single:
4941 case OMPRTL___kmpc_end_single:
4942 case OMPRTL___kmpc_master:
4943 case OMPRTL___kmpc_end_master:
4944 case OMPRTL___kmpc_barrier:
4945 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4946 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4947 case OMPRTL___kmpc_error:
4948 case OMPRTL___kmpc_flush:
4949 case OMPRTL___kmpc_get_hardware_thread_id_in_block:
4950 case OMPRTL___kmpc_get_warp_size:
4951 case OMPRTL_omp_get_thread_num:
4952 case OMPRTL_omp_get_num_threads:
4953 case OMPRTL_omp_get_max_threads:
4954 case OMPRTL_omp_in_parallel:
4955 case OMPRTL_omp_get_dynamic:
4956 case OMPRTL_omp_get_cancellation:
4957 case OMPRTL_omp_get_nested:
4958 case OMPRTL_omp_get_schedule:
4959 case OMPRTL_omp_get_thread_limit:
4960 case OMPRTL_omp_get_supported_active_levels:
4961 case OMPRTL_omp_get_max_active_levels:
4962 case OMPRTL_omp_get_level:
4963 case OMPRTL_omp_get_ancestor_thread_num:
4964 case OMPRTL_omp_get_team_size:
4965 case OMPRTL_omp_get_active_level:
4966 case OMPRTL_omp_in_final:
4967 case OMPRTL_omp_get_proc_bind:
4968 case OMPRTL_omp_get_num_places:
4969 case OMPRTL_omp_get_num_procs:
4970 case OMPRTL_omp_get_place_proc_ids:
4971 case OMPRTL_omp_get_place_num:
4972 case OMPRTL_omp_get_partition_num_places:
4973 case OMPRTL_omp_get_partition_place_nums:
4974 case OMPRTL_omp_get_wtime:
4975 break;
4976 case OMPRTL___kmpc_distribute_static_init_4:
4977 case OMPRTL___kmpc_distribute_static_init_4u:
4978 case OMPRTL___kmpc_distribute_static_init_8:
4979 case OMPRTL___kmpc_distribute_static_init_8u:
4980 case OMPRTL___kmpc_for_static_init_4:
4981 case OMPRTL___kmpc_for_static_init_4u:
4982 case OMPRTL___kmpc_for_static_init_8:
4983 case OMPRTL___kmpc_for_static_init_8u: {
4984 // Check the schedule and allow static schedule in SPMD mode.
4985 unsigned ScheduleArgOpNo = 2;
4986 auto *ScheduleTypeCI =
4987 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4988 unsigned ScheduleTypeVal =
4989 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4990 switch (OMPScheduleType(ScheduleTypeVal)) {
4991 case OMPScheduleType::UnorderedStatic:
4992 case OMPScheduleType::UnorderedStaticChunked:
4993 case OMPScheduleType::OrderedDistribute:
4994 case OMPScheduleType::OrderedDistributeChunked:
4995 break;
4996 default:
4997 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4998 SPMDCompatibilityTracker.insert(&CB);
4999 break;
5000 };
5001 } break;
5002 case OMPRTL___kmpc_target_init:
5003 KernelInitCB = &CB;
5004 break;
5005 case OMPRTL___kmpc_target_deinit:
5006 KernelDeinitCB = &CB;
5007 break;
5008 case OMPRTL___kmpc_parallel_60:
5009 if (!handleParallel60(A, CB))
5010 indicatePessimisticFixpoint();
5011 return;
5012 case OMPRTL___kmpc_omp_task:
5013 // We do not look into tasks right now, just give up.
5014 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5015 SPMDCompatibilityTracker.insert(&CB);
5016 ReachedUnknownParallelRegions.insert(&CB);
5017 break;
5018 case OMPRTL___kmpc_alloc_shared:
5019 case OMPRTL___kmpc_free_shared:
5020 // Return without setting a fixpoint, to be resolved in updateImpl.
5021 return;
5022 default:
5023 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
5024 // generally. However, they do not hide parallel regions.
5025 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5026 SPMDCompatibilityTracker.insert(&CB);
5027 break;
5028 }
5029 // All other OpenMP runtime calls will not reach parallel regions so they
5030 // can be safely ignored for now. Since it is a known OpenMP runtime call
5031 // we have now modeled all effects and there is no need for any update.
5032 indicateOptimisticFixpoint();
5033 };
5034
5035 const auto *AACE =
5036 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5037 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5038 CheckCallee(getAssociatedFunction(), 1);
5039 return;
5040 }
5041 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5042 for (auto *Callee : OptimisticEdges) {
5043 CheckCallee(Callee, OptimisticEdges.size());
5044 if (isAtFixpoint())
5045 break;
5046 }
5047 }
5048
5049 ChangeStatus updateImpl(Attributor &A) override {
5050 // TODO: Once we have call site specific value information we can provide
5051 // call site specific liveness information and then it makes
5052 // sense to specialize attributes for call site arguments instead of
5053 // redirecting requests to the callee argument.
5054 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5055 KernelInfoState StateBefore = getState();
5056
5057 auto CheckCallee = [&](Function *F, int NumCallees) {
5058 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
5059
5060 // If F is not a runtime function, propagate the AAKernelInfo of the
5061 // callee.
5062 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
5063 const IRPosition &FnPos = IRPosition::function(*F);
5064 auto *FnAA =
5065 A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
5066 if (!FnAA)
5067 return indicatePessimisticFixpoint();
5068 if (getState() == FnAA->getState())
5069 return ChangeStatus::UNCHANGED;
5070 getState() = FnAA->getState();
5071 return ChangeStatus::CHANGED;
5072 }
5073 if (NumCallees > 1)
5074 return indicatePessimisticFixpoint();
5075
5076 CallBase &CB = cast<CallBase>(getAssociatedValue());
5077 if (It->getSecond() == OMPRTL___kmpc_parallel_60) {
5078 if (!handleParallel60(A, CB))
5079 return indicatePessimisticFixpoint();
5080 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5081 : ChangeStatus::CHANGED;
5082 }
5083
5084 // F is a runtime function that allocates or frees memory, check
5085 // AAHeapToStack and AAHeapToShared.
5086 assert(
5087 (It->getSecond() == OMPRTL___kmpc_alloc_shared ||
5088 It->getSecond() == OMPRTL___kmpc_free_shared) &&
5089 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
5090
5091 auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(
5092 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5093 auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(
5094 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
5095
5096 RuntimeFunction RF = It->getSecond();
5097
5098 switch (RF) {
5099 // If neither HeapToStack nor HeapToShared assume the call is removed,
5100 // assume SPMD incompatibility.
5101 case OMPRTL___kmpc_alloc_shared:
5102 if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&
5103 (!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))
5104 SPMDCompatibilityTracker.insert(&CB);
5105 break;
5106 case OMPRTL___kmpc_free_shared:
5107 if ((!HeapToStackAA ||
5108 !HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&
5109 (!HeapToSharedAA ||
5110 !HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))
5111 SPMDCompatibilityTracker.insert(&CB);
5112 break;
5113 default:
5114 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
5115 SPMDCompatibilityTracker.insert(&CB);
5116 }
5117 return ChangeStatus::CHANGED;
5118 };
5119
5120 const auto *AACE =
5121 A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
5122 if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {
5123 if (Function *F = getAssociatedFunction())
5124 CheckCallee(F, /*NumCallees=*/1);
5125 } else {
5126 const auto &OptimisticEdges = AACE->getOptimisticEdges();
5127 for (auto *Callee : OptimisticEdges) {
5128 CheckCallee(Callee, OptimisticEdges.size());
5129 if (isAtFixpoint())
5130 break;
5131 }
5132 }
5133
5134 return StateBefore == getState() ? ChangeStatus::UNCHANGED
5135 : ChangeStatus::CHANGED;
5136 }
5137
5138 /// Deal with a __kmpc_parallel_60 call (\p CB). Returns true if the call was
5139 /// handled; if a problem occurred, false is returned.
5140 bool handleParallel60(Attributor &A, CallBase &CB) {
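 // Illustrative shape of the call this handles (argument names are
 // assumptions for exposition; only the operand indices below are relied
 // upon, the authoritative signature lives in the OpenMP runtime
 // definitions, see OMPKinds.def):
 //   __kmpc_parallel_60(..., fn /*operand 5*/, wrapper_fn /*operand 6*/, ...)
 // In SPMD mode the outlined region (operand 5) is the interesting callee,
 // otherwise the generic-mode wrapper (operand 6) is.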
5141 const unsigned int NonWrapperFunctionArgNo = 5;
5142 const unsigned int WrapperFunctionArgNo = 6;
5143 auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
5144 ? NonWrapperFunctionArgNo
5145 : WrapperFunctionArgNo;
5146
5147 auto *ParallelRegion = dyn_cast<Function>(
5148 CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
5149 if (!ParallelRegion)
5150 return false;
5151
5152 ReachedKnownParallelRegions.insert(&CB);
5153 /// Check nested parallelism
5154 auto *FnAA = A.getAAFor<AAKernelInfo>(
5155 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
5156 NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
5157 !FnAA->ReachedKnownParallelRegions.empty() ||
5158 !FnAA->ReachedKnownParallelRegions.isValidState() ||
5159 !FnAA->ReachedUnknownParallelRegions.isValidState() ||
5160 !FnAA->ReachedUnknownParallelRegions.empty();
5161 return true;
5162 }
5163};
5164
5165struct AAFoldRuntimeCall
5166 : public StateWrapper<BooleanState, AbstractAttribute> {
5167 using Base = StateWrapper<BooleanState, AbstractAttribute>;
5168
5169 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
5170
5171 /// Statistics are tracked as part of manifest for now.
5172 void trackStatistics() const override {}
5173
5174 /// Create an abstract attribute view for the position \p IRP.
5175 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
5176 Attributor &A);
5177
5178 /// See AbstractAttribute::getName()
5179 StringRef getName() const override { return "AAFoldRuntimeCall"; }
5180
5181 /// See AbstractAttribute::getIdAddr()
5182 const char *getIdAddr() const override { return &ID; }
5183
5184 /// This function should return true if the type of the \p AA is
5185 /// AAFoldRuntimeCall
5186 static bool classof(const AbstractAttribute *AA) {
5187 return (AA->getIdAddr() == &ID);
5188 }
5189
5190 static const char ID;
5191};
5192
5193struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
5194 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
5195 : AAFoldRuntimeCall(IRP, A) {}
5196
5197 /// See AbstractAttribute::getAsStr()
5198 const std::string getAsStr(Attributor *) const override {
5199 if (!isValidState())
5200 return "<invalid>";
5201
5202 std::string Str("simplified value: ");
5203
5204 if (!SimplifiedValue)
5205 return Str + std::string("none");
5206
5207 if (!*SimplifiedValue)
5208 return Str + std::string("nullptr");
5209
5210 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
5211 return Str + std::to_string(CI->getSExtValue());
5212
5213 return Str + std::string("unknown");
5214 }
5215
5216 void initialize(Attributor &A) override {
5217 if (DisableOpenMPOptFolding)
5218 indicatePessimisticFixpoint();
5219
5220 Function *Callee = getAssociatedFunction();
5221
5222 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
5223 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
5224 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
5225 "Expected a known OpenMP runtime function");
5226
5227 RFKind = It->getSecond();
5228
5229 CallBase &CB = cast<CallBase>(getAssociatedValue());
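 // Let the Attributor query us for the simplified return value of this
 // call. As long as we are not at a fixpoint the value is only assumed, so
 // we record a dependence on the querying AA.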
5230 A.registerSimplificationCallback(
5231 IRPosition::callsite_returned(CB),
5232 [&](const IRPosition &IRP, const AbstractAttribute *AA,
5233 bool &UsedAssumedInformation) -> std::optional<Value *> {
5234 assert((isValidState() || SimplifiedValue == nullptr) &&
5235 "Unexpected invalid state!");
5236
5237 if (!isAtFixpoint()) {
5238 UsedAssumedInformation = true;
5239 if (AA)
5240 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
5241 }
5242 return SimplifiedValue;
5243 });
5244 }
5245
5246 ChangeStatus updateImpl(Attributor &A) override {
5247 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5248 switch (RFKind) {
5249 case OMPRTL___kmpc_is_spmd_exec_mode:
5250 Changed |= foldIsSPMDExecMode(A);
5251 break;
5252 case OMPRTL___kmpc_parallel_level:
5253 Changed |= foldParallelLevel(A);
5254 break;
5255 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
5256 Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
5257 break;
5258 case OMPRTL___kmpc_get_hardware_num_blocks:
5259 Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
5260 break;
5261 default:
5262 llvm_unreachable("Unhandled OpenMP runtime function!");
5263 }
5264
5265 return Changed;
5266 }
5267
5268 ChangeStatus manifest(Attributor &A) override {
5269 ChangeStatus Changed = ChangeStatus::UNCHANGED;
5270
5271 if (SimplifiedValue && *SimplifiedValue) {
5272 Instruction &I = *getCtxI();
5273 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
5274 A.deleteAfterManifest(I);
5275
5276 CallBase *CB = dyn_cast<CallBase>(&I);
5277 auto Remark = [&](OptimizationRemark OR) {
5278 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
5279 return OR << "Replacing OpenMP runtime call "
5280 << CB->getCalledFunction()->getName() << " with "
5281 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
5282 return OR << "Replacing OpenMP runtime call "
5283 << CB->getCalledFunction()->getName() << ".";
5284 };
5285
5286 if (CB && EnableVerboseRemarks)
5287 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
5288
5289 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
5290 << **SimplifiedValue << "\n");
5291
5292 Changed = ChangeStatus::CHANGED;
5293 }
5294
5295 return Changed;
5296 }
5297
5298 ChangeStatus indicatePessimisticFixpoint() override {
5299 SimplifiedValue = nullptr;
5300 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
5301 }
5302
5303private:
5304 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
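 /// Illustrative IR (assuming every reaching kernel is known or assumed to
 /// run in SPMD mode):
 ///   %mode = call i8 @__kmpc_is_spmd_exec_mode()  ; folds to i8 1
 /// If all reaching kernels are generic-mode instead, the call folds to i8 0.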
5305 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
5306 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5307
5308 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5309 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5310 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5311 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5312
5313 if (!CallerKernelInfoAA ||
5314 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5315 return indicatePessimisticFixpoint();
5316
5317 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5318 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5319 DepClassTy::REQUIRED);
5320
5321 if (!AA || !AA->isValidState()) {
5322 SimplifiedValue = nullptr;
5323 return indicatePessimisticFixpoint();
5324 }
5325
5326 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5327 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5328 ++KnownSPMDCount;
5329 else
5330 ++AssumedSPMDCount;
5331 } else {
5332 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5333 ++KnownNonSPMDCount;
5334 else
5335 ++AssumedNonSPMDCount;
5336 }
5337 }
5338
5339 if ((AssumedSPMDCount + KnownSPMDCount) &&
5340 (AssumedNonSPMDCount + KnownNonSPMDCount))
5341 return indicatePessimisticFixpoint();
5342
5343 auto &Ctx = getAnchorValue().getContext();
5344 if (KnownSPMDCount || AssumedSPMDCount) {
5345 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5346 "Expected only SPMD kernels!");
5347 // All reaching kernels are in SPMD mode. Update all function calls to
5348 // __kmpc_is_spmd_exec_mode to 1.
5349 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
5350 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
5351 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5352 "Expected only non-SPMD kernels!");
5353 // All reaching kernels are in non-SPMD mode. Update all function
5354 // calls to __kmpc_is_spmd_exec_mode to 0.
5355 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
5356 } else {
5357 // We have empty reaching kernels, therefore we cannot tell if the
5358 // associated call site can be folded. At this moment, SimplifiedValue
5359 // must be none.
5360 assert(!SimplifiedValue && "SimplifiedValue should be none");
5361 }
5362
5363 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5364 : ChangeStatus::CHANGED;
5365 }
5366
5367 /// Fold __kmpc_parallel_level into a constant if possible.
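 /// If the surrounding function is only reached from SPMD kernels, the level
 /// folds to i8 1; if it is only reached from generic-mode kernels, it folds
 /// to i8 0. A mix of modes (or no known reaching kernel) blocks the fold.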
5368 ChangeStatus foldParallelLevel(Attributor &A) {
5369 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5370
5371 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5372 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5373
5374 if (!CallerKernelInfoAA ||
5375 !CallerKernelInfoAA->ParallelLevels.isValidState())
5376 return indicatePessimisticFixpoint();
5377
5378 if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5379 return indicatePessimisticFixpoint();
5380
5381 if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
5382 assert(!SimplifiedValue &&
5383 "SimplifiedValue should keep none at this point");
5384 return ChangeStatus::UNCHANGED;
5385 }
5386
5387 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5388 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5389 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5390 auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5391 DepClassTy::REQUIRED);
5392 if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
5393 return indicatePessimisticFixpoint();
5394
5395 if (AA->SPMDCompatibilityTracker.isAssumed()) {
5396 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5397 ++KnownSPMDCount;
5398 else
5399 ++AssumedSPMDCount;
5400 } else {
5401 if (AA->SPMDCompatibilityTracker.isAtFixpoint())
5402 ++KnownNonSPMDCount;
5403 else
5404 ++AssumedNonSPMDCount;
5405 }
5406 }
5407
5408 if ((AssumedSPMDCount + KnownSPMDCount) &&
5409 (AssumedNonSPMDCount + KnownNonSPMDCount))
5410 return indicatePessimisticFixpoint();
5411
5412 auto &Ctx = getAnchorValue().getContext();
5413 // If the caller can only be reached by SPMD kernel entries, the parallel
5414 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5415 // kernel entries, it is 0.
5416 if (AssumedSPMDCount || KnownSPMDCount) {
5417 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5418 "Expected only SPMD kernels!");
5419 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5420 } else {
5421 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5422 "Expected only non-SPMD kernels!");
5423 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5424 }
5425 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5426 : ChangeStatus::CHANGED;
5427 }
5428
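 /// Fold a runtime call to the value of the kernel function attribute \p Attr
 /// (e.g. "omp_target_thread_limit"), but only if every kernel reaching this
 /// function carries the same attribute value.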
5429 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5430 // Specialize only if all the calls agree with the attribute constant value
5431 int32_t CurrentAttrValue = -1;
5432 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5433
5434 auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5435 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5436
5437 if (!CallerKernelInfoAA ||
5438 !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
5439 return indicatePessimisticFixpoint();
5440
5441 // Iterate over the kernels that reach this function
5442 for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
5443 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5444
5445 if (NextAttrVal == -1 ||
5446 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5447 return indicatePessimisticFixpoint();
5448 CurrentAttrValue = NextAttrVal;
5449 }
5450
5451 if (CurrentAttrValue != -1) {
5452 auto &Ctx = getAnchorValue().getContext();
5453 SimplifiedValue =
5454 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5455 }
5456 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5457 : ChangeStatus::CHANGED;
5458 }
5459
5460 /// An optional value the associated value is assumed to fold to. That is, we
5461 /// assume the associated value (which is a call) can be replaced by this
5462 /// simplified value.
5463 std::optional<Value *> SimplifiedValue;
5464
5465 /// The runtime function kind of the callee of the associated call site.
5466 RuntimeFunction RFKind;
5467};
5468
5469} // namespace
5470
5471 /// Register an AAFoldRuntimeCall attribute for every call site of the runtime function \p RF.
5472void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5473 auto &RFI = OMPInfoCache.RFIs[RF];
5474 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5475 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5476 if (!CI)
5477 return false;
5478 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5479 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5480 DepClassTy::NONE, /* ForceUpdate */ false,
5481 /* UpdateAfterInit */ false);
5482 return false;
5483 });
5484}
5485
5486void OpenMPOpt::registerAAs(bool IsModulePass) {
5487 if (SCC.empty())
5488 return;
5489
5490 if (IsModulePass) {
5491 // Ensure we create the AAKernelInfo AAs first and without triggering an
5492 // update. This will make sure we register all value simplification
5493 // callbacks before any other AA has the chance to create an AAValueSimplify
5494 // or similar.
5495 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5496 A.getOrCreateAAFor<AAKernelInfo>(
5497 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5498 DepClassTy::NONE, /* ForceUpdate */ false,
5499 /* UpdateAfterInit */ false);
5500 return false;
5501 };
5502 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5503 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5504 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5505
5506 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5507 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5508 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5509 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5510 }
5511
5512 // Create CallSite AA for all Getters.
5513 if (DeduceICVValues) {
5514 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5515 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5516
5517 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5518
5519 auto CreateAA = [&](Use &U, Function &Caller) {
5520 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5521 if (!CI)
5522 return false;
5523
5524 auto &CB = cast<CallBase>(*CI);
5525
5526 IRPosition CBPos = IRPosition::callsite_function(CB);
5527 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5528 return false;
5529 };
5530
5531 GetterRFI.foreachUse(SCC, CreateAA);
5532 }
5533 }
5534
5535 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5536 // every function if there is a device kernel.
5537 if (!isOpenMPDevice(M))
5538 return;
5539
5540 for (auto *F : SCC) {
5541 if (F->isDeclaration())
5542 continue;
5543
5544 // We look at internal functions only on-demand but if any use is not a
5545 // direct call or outside the current set of analyzed functions, we have
5546 // to do it eagerly.
5547 if (F->hasLocalLinkage()) {
5548 if (llvm::all_of(F->uses(), [this](const Use &U) {
5549 const auto *CB = dyn_cast<CallBase>(U.getUser());
5550 return CB && CB->isCallee(&U) &&
5551 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5552 }))
5553 continue;
5554 }
5555 registerAAsForFunction(A, *F);
5556 }
5557}
5558
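/// Seed the default AAs for \p F: deglobalization (AAHeapToShared /
/// AAHeapToStack) unless disabled, execution-domain information, and
/// per-instruction helpers (address spaces, dead stores/fences, indirect call
/// info, and assumed values feeding llvm.assume).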
5559 void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5560 if (!DisableOpenMPOptDeglobalization)
5561 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5562 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5563 if (!DisableOpenMPOptDeglobalization)
5564 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5565 if (F.hasFnAttribute(Attribute::Convergent))
5566 A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));
5567
5568 for (auto &I : instructions(F)) {
5569 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5570 bool UsedAssumedInformation = false;
5571 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5572 UsedAssumedInformation, AA::Interprocedural);
5573 A.getOrCreateAAFor<AAAddressSpace>(
5574 IRPosition::value(*LI->getPointerOperand()));
5575 continue;
5576 }
5577 if (auto *CI = dyn_cast<CallBase>(&I)) {
5578 if (CI->isIndirectCall())
5579 A.getOrCreateAAFor<AAIndirectCallInfo>(
5580 IRPosition::callsite_function(*CI));
5581 }
5582 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5583 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5584 A.getOrCreateAAFor<AAAddressSpace>(
5585 IRPosition::value(*SI->getPointerOperand()));
5586 continue;
5587 }
5588 if (auto *FI = dyn_cast<FenceInst>(&I)) {
5589 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
5590 continue;
5591 }
5592 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5593 if (II->getIntrinsicID() == Intrinsic::assume) {
5594 A.getOrCreateAAFor<AAPotentialValues>(
5595 IRPosition::value(*II->getArgOperand(0)));
5596 continue;
5597 }
5598 }
5599 }
5600}
5601
5602const char AAICVTracker::ID = 0;
5603const char AAKernelInfo::ID = 0;
5604const char AAExecutionDomain::ID = 0;
5605const char AAHeapToShared::ID = 0;
5606const char AAFoldRuntimeCall::ID = 0;
5607
5608AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5609 Attributor &A) {
5610 AAICVTracker *AA = nullptr;
5611 switch (IRP.getPositionKind()) {
5612 case IRPosition::IRP_INVALID:
5613 case IRPosition::IRP_FLOAT:
5614 case IRPosition::IRP_ARGUMENT:
5615 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5616 llvm_unreachable("ICVTracker can only be created for function position!");
5617 case IRPosition::IRP_RETURNED:
5618 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5619 break;
5620 case IRPosition::IRP_CALL_SITE_RETURNED:
5621 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5622 break;
5623 case IRPosition::IRP_CALL_SITE:
5624 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5625 break;
5626 case IRPosition::IRP_FUNCTION:
5627 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5628 break;
5629 }
5630
5631 return *AA;
5632}
5633
5634 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5635 Attributor &A) {
5636 AAExecutionDomainFunction *AA = nullptr;
5637 switch (IRP.getPositionKind()) {
5638 case IRPosition::IRP_INVALID:
5639 case IRPosition::IRP_FLOAT:
5640 case IRPosition::IRP_ARGUMENT:
5641 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5642 case IRPosition::IRP_RETURNED:
5643 case IRPosition::IRP_CALL_SITE_RETURNED:
5644 case IRPosition::IRP_CALL_SITE:
5645 llvm_unreachable(
5646 "AAExecutionDomain can only be created for function position!");
5647 case IRPosition::IRP_FUNCTION:
5648 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5649 break;
5650 }
5651
5652 return *AA;
5653}
5654
5655AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5656 Attributor &A) {
5657 AAHeapToSharedFunction *AA = nullptr;
5658 switch (IRP.getPositionKind()) {
5659 case IRPosition::IRP_INVALID:
5660 case IRPosition::IRP_FLOAT:
5661 case IRPosition::IRP_ARGUMENT:
5662 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5663 case IRPosition::IRP_RETURNED:
5664 case IRPosition::IRP_CALL_SITE_RETURNED:
5665 case IRPosition::IRP_CALL_SITE:
5666 llvm_unreachable(
5667 "AAHeapToShared can only be created for function position!");
5668 case IRPosition::IRP_FUNCTION:
5669 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5670 break;
5671 }
5672
5673 return *AA;
5674}
5675
5676AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5677 Attributor &A) {
5678 AAKernelInfo *AA = nullptr;
5679 switch (IRP.getPositionKind()) {
5680 case IRPosition::IRP_INVALID:
5681 case IRPosition::IRP_FLOAT:
5682 case IRPosition::IRP_ARGUMENT:
5683 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5684 case IRPosition::IRP_RETURNED:
5685 case IRPosition::IRP_CALL_SITE_RETURNED:
5686 llvm_unreachable("KernelInfo can only be created for function position!");
5687 case IRPosition::IRP_CALL_SITE:
5688 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5689 break;
5690 case IRPosition::IRP_FUNCTION:
5691 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5692 break;
5693 }
5694
5695 return *AA;
5696}
5697
5698AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5699 Attributor &A) {
5700 AAFoldRuntimeCall *AA = nullptr;
5701 switch (IRP.getPositionKind()) {
5702 case IRPosition::IRP_INVALID:
5703 case IRPosition::IRP_FLOAT:
5704 case IRPosition::IRP_ARGUMENT:
5705 case IRPosition::IRP_RETURNED:
5706 case IRPosition::IRP_FUNCTION:
5707 case IRPosition::IRP_CALL_SITE:
5708 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5709 llvm_unreachable("KernelInfo can only be created for call site position!");
5710 case IRPosition::IRP_CALL_SITE_RETURNED:
5711 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5712 break;
5713 }
5714
5715 return *AA;
5716}
5717
5718 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5719 if (!containsOpenMP(M))
5720 return PreservedAnalyses::all();
5721 if (DisableOpenMPOptimizations)
5722 return PreservedAnalyses::all();
5723
5724 FunctionAnalysisManager &FAM =
5725 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5726 KernelSet Kernels = getDeviceKernels(M);
5727
5728 if (PrintModuleBeforeOptimizations)
5729 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5730
5731 auto IsCalled = [&](Function &F) {
5732 if (Kernels.contains(&F))
5733 return true;
5734 return !F.use_empty();
5735 };
5736
5737 auto EmitRemark = [&](Function &F) {
5738 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5739 ORE.emit([&]() {
5740 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5741 return ORA << "Could not internalize function. "
5742 << "Some optimizations may not be possible. [OMP140]";
5743 });
5744 };
5745
5746 bool Changed = false;
5747
5748 // Create internal copies of each function if this is a kernel Module. This
5749 // allows interprocedural passes to see every call edge.
5750 DenseMap<Function *, Function *> InternalizedMap;
5751 if (isOpenMPDevice(M)) {
5752 SmallPtrSet<Function *, 16> InternalizeFns;
5753 for (Function &F : M)
5754 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5755 !DisableInternalization) {
5756 if (Attributor::isInternalizable(F)) {
5757 InternalizeFns.insert(&F);
5758 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5759 EmitRemark(F);
5760 }
5761 }
5762
5763 Changed |=
5764 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5765 }
5766
5767 // Look at every function in the Module unless it was internalized.
5768 SetVector<Function *> Functions;
5769 SmallVector<Function *, 16> SCC;
5770 for (Function &F : M)
5771 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5772 SCC.push_back(&F);
5773 Functions.insert(&F);
5774 }
5775
5776 if (SCC.empty())
5777 return PreservedAnalyses::all();
5778
5779 AnalysisGetter AG(FAM);
5780
5781 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5782 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5783 };
5784
5785 BumpPtrAllocator Allocator;
5786 CallGraphUpdater CGUpdater;
5787
5788 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5789 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5790
5791 OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);
5792
5793 unsigned MaxFixpointIterations =
5794 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5795
5796 AttributorConfig AC(CGUpdater);
5797 AC.DefaultInitializeLiveInternals = false;
5798 AC.IsModulePass = true;
5799 AC.RewriteSignatures = false;
5800 AC.MaxFixpointIterations = MaxFixpointIterations;
5801 AC.OREGetter = OREGetter;
5802 AC.PassName = DEBUG_TYPE;
5803 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5804 AC.IPOAmendableCB = [](const Function &F) {
5805 return F.hasFnAttribute("kernel");
5806 };
5807
5808 Attributor A(Functions, InfoCache, AC);
5809
5810 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5811 Changed |= OMPOpt.run(true);
5812
5813 // Optionally inline device functions for potentially better performance.
5814 if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5815 for (Function &F : M)
5816 if (!F.isDeclaration() && !Kernels.contains(&F) &&
5817 !F.hasFnAttribute(Attribute::NoInline))
5818 F.addFnAttr(Attribute::AlwaysInline);
5819
5820 if (PrintModuleAfterOptimizations)
5821 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5822
5823 if (Changed)
5824 return PreservedAnalyses::none();
5825
5826 return PreservedAnalyses::all();
5827}
5828
5829 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5830 CGSCCAnalysisManager &AM,
5831 LazyCallGraph &CG,
5832 CGSCCUpdateResult &UR) {
5833 if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5834 return PreservedAnalyses::all();
5835 if (DisableOpenMPOptimizations)
5836 return PreservedAnalyses::all();
5837
5838 SmallVector<Function *, 16> SCC;
5839 // If there are kernels in the module, we have to run on all SCC's.
5840 for (LazyCallGraph::Node &N : C) {
5841 Function *Fn = &N.getFunction();
5842 SCC.push_back(Fn);
5843 }
5844
5845 if (SCC.empty())
5846 return PreservedAnalyses::all();
5847
5848 Module &M = *C.begin()->getFunction().getParent();
5849
5850 if (PrintModuleBeforeOptimizations)
5851 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5852
5853 FunctionAnalysisManager &FAM =
5854 AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5855
5856 AnalysisGetter AG(FAM);
5857
5858 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5859 return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5860 };
5861
5862 BumpPtrAllocator Allocator;
5863 CallGraphUpdater CGUpdater;
5864 CGUpdater.initialize(CG, C, AM, UR);
5865
5866 bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
5867 LTOPhase == ThinOrFullLTOPhase::ThinLTOPostLink;
5868
5869 SetVector<Function *> Functions(SCC.begin(), SCC.end());
5870 OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5871 /*CGSCC*/ &Functions, PostLink);
5872
5873 unsigned MaxFixpointIterations =
5874 (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5875
5876 AttributorConfig AC(CGUpdater);
5877 AC.DefaultInitializeLiveInternals = false;
5878 AC.IsModulePass = false;
5879 AC.RewriteSignatures = false;
5880 AC.MaxFixpointIterations = MaxFixpointIterations;
5881 AC.OREGetter = OREGetter;
5882 AC.PassName = DEBUG_TYPE;
5883 AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
5884
5885 Attributor A(Functions, InfoCache, AC);
5886
5887 OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5888 bool Changed = OMPOpt.run(false);
5889
5890 if (PrintModuleAfterOptimizations)
5891 LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5892
5893 if (Changed)
5894 return PreservedAnalyses::none();
5895
5896 return PreservedAnalyses::all();
5897}
5898
5899 bool llvm::omp::isOpenMPKernel(Function &Fn) {
5900 return Fn.hasFnAttribute("kernel");
5901}
5902
5903 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5904 KernelSet Kernels;
5905
5906 for (Function &F : M)
5907 if (F.hasKernelCallingConv()) {
5908 // We are only interested in OpenMP target regions. Others, such as
5909 // kernels generated by CUDA but linked together, are not interesting to
5910 // this pass.
5911 if (isOpenMPKernel(F)) {
5912 ++NumOpenMPTargetRegionKernels;
5913 Kernels.insert(&F);
5914 } else
5915 ++NumNonOpenMPTargetRegionKernels;
5916 }
5917
5918 return Kernels;
5919}
5920
5921 bool llvm::omp::containsOpenMP(Module &M) {
5922 Metadata *MD = M.getModuleFlag("openmp");
5923 if (!MD)
5924 return false;
5925
5926 return true;
5927}
5928
5929 bool llvm::omp::isOpenMPDevice(Module &M) {
5930 Metadata *MD = M.getModuleFlag("openmp-device");
5931 if (!MD)
5932 return false;
5933
5934 return true;
5935}
@ Generic
assert(UImm &&(UImm !=~static_cast< T >(0)) &&"Invalid immediate!")
MachineBasicBlock MachineBasicBlock::iterator DebugLoc DL
Expand Atomic instructions
static cl::opt< unsigned > SetFixpointIterations("attributor-max-iterations", cl::Hidden, cl::desc("Maximal number of fixpoint iterations."), cl::init(32))
static const Function * getParent(const Value *V)
static GCRegistry::Add< ErlangGC > A("erlang", "erlang-compatible garbage collector")
static GCRegistry::Add< CoreCLRGC > E("coreclr", "CoreCLR-compatible GC")
This file provides interfaces used to manipulate a call graph, regardless if it is a "old style" Call...
This file provides interfaces used to build and manipulate a call graph, which is a very useful tool ...
This file contains the declarations for the subclasses of Constant, which represent the different fla...
dxil pretty DXIL Metadata Pretty Printer
This file defines the DenseSet and SmallDenseSet classes.
This file defines an array type that can be indexed using scoped enum values.
#define DEBUG_TYPE
static void emitRemark(const Function &F, OptimizationRemarkEmitter &ORE, bool Skip)
Loop::LoopBounds::Direction Direction
Definition LoopInfo.cpp:231
#define F(x, y, z)
Definition MD5.cpp:54
#define I(x, y, z)
Definition MD5.cpp:57
Machine Check Debug Module
This file provides utility analysis objects describing memory locations.
#define T
uint64_t IntrinsicInst * II
This file defines constans and helpers used when dealing with OpenMP.
This file defines constans that will be used by both host and device compilation.
static constexpr auto TAG
static cl::opt< bool > HideMemoryTransferLatency("openmp-hide-memory-transfer-latency", cl::desc("[WIP] Tries to hide the latency of host to device memory" " transfers"), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptStateMachineRewrite("openmp-opt-disable-state-machine-rewrite", cl::desc("Disable OpenMP optimizations that replace the state machine."), cl::Hidden, cl::init(false))
static cl::opt< bool > EnableParallelRegionMerging("openmp-opt-enable-merging", cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleAfterOptimizations("openmp-opt-print-module-after", cl::desc("Print the current module after OpenMP optimizations."), cl::Hidden, cl::init(false))
#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER)
#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER)
static cl::opt< bool > PrintOpenMPKernels("openmp-print-gpu-kernels", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptFolding("openmp-opt-disable-folding", cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintModuleBeforeOptimizations("openmp-opt-print-module-before", cl::desc("Print the current module before OpenMP optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden, cl::desc("Maximal number of attributor iterations."), cl::init(256))
static cl::opt< bool > DisableInternalization("openmp-opt-disable-internalization", cl::desc("Disable function internalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > PrintICVValues("openmp-print-icv-values", cl::init(false), cl::Hidden)
static cl::opt< bool > DisableOpenMPOptimizations("openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."), cl::Hidden, cl::init(false))
static cl::opt< unsigned > SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden, cl::desc("Maximum amount of shared memory to use."), cl::init(std::numeric_limits< unsigned >::max()))
static cl::opt< bool > EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::desc("Enables more verbose remarks."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptDeglobalization("openmp-opt-disable-deglobalization", cl::desc("Disable OpenMP optimizations involving deglobalization."), cl::Hidden, cl::init(false))
static cl::opt< bool > DisableOpenMPOptBarrierElimination("openmp-opt-disable-barrier-elimination", cl::desc("Disable OpenMP optimizations that eliminate barriers."), cl::Hidden, cl::init(false))
#define DEBUG_TYPE
Definition OpenMPOpt.cpp:67
static cl::opt< bool > DeduceICVValues("openmp-deduce-icv-values", cl::init(false), cl::Hidden)
#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX)
#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE)
static cl::opt< bool > DisableOpenMPOptSPMDization("openmp-opt-disable-spmdization", cl::desc("Disable OpenMP optimizations involving SPMD-ization."), cl::Hidden, cl::init(false))
static cl::opt< bool > AlwaysInlineDeviceFunctions("openmp-opt-inline-device", cl::desc("Inline all applicable functions on the device."), cl::Hidden, cl::init(false))
#define P(N)
FunctionAnalysisManager FAM
This file builds on the ADT/GraphTraits.h file to build a generic graph post order iterator.
static StringRef getName(Value *V)
R600 Clause Merge
Basic Register Allocator
std::pair< BasicBlock *, BasicBlock * > Edge
static bool contains(SmallPtrSetImpl< ConstantExpr * > &Cache, ConstantExpr *Expr, Constant *C)
Definition Value.cpp:487
This file implements a set that has insertion order iteration characteristics.
This file defines the SmallPtrSet class.
This file defines the SmallVector class.
This file defines the 'Statistic' class, which is designed to be an easy way to expose various metric...
#define STATISTIC(VARNAME, DESC)
Definition Statistic.h:171
This file contains some functions that are useful when dealing with strings.
#define LLVM_DEBUG(...)
Definition Debug.h:114
static const int BlockSize
Definition TarWriter.cpp:33
static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, const llvm::StringTable &StandardNames, VectorLibrary VecLib)
Initialize the set of available library functions based on the specified target triple.
Value * RHS
static cl::opt< unsigned > MaxThreads("xcore-max-threads", cl::Optional, cl::desc("Maximum number of threads (for emulation thread-local storage)"), cl::Hidden, cl::value_desc("number"), cl::init(8))
PassT::Result & getResult(IRUnitT &IR, ExtraArgTs... ExtraArgs)
Get the result of an analysis pass for a given IR unit.
size_t size() const
size - Get the array size.
Definition ArrayRef.h:142
iterator end()
Definition BasicBlock.h:483
iterator begin()
Instruction iterator methods.
Definition BasicBlock.h:470
LLVM_ABI const_iterator getFirstInsertionPt() const
Returns an iterator to the first instruction in this block that is suitable for inserting a non-PHI i...
const Function * getParent() const
Return the enclosing method, or null if none.
Definition BasicBlock.h:213
reverse_iterator rbegin()
Definition BasicBlock.h:486
static BasicBlock * Create(LLVMContext &Context, const Twine &Name="", Function *Parent=nullptr, BasicBlock *InsertBefore=nullptr)
Creates a new BasicBlock.
Definition BasicBlock.h:206
LLVM_ABI BasicBlock * splitBasicBlock(iterator I, const Twine &BBName="", bool Before=false)
Split the basic block into two basic blocks at the specified instruction.
LLVM_ABI const BasicBlock * getUniqueSuccessor() const
Return the successor of this block if it has a unique successor.
InstListType::reverse_iterator reverse_iterator
Definition BasicBlock.h:172
reverse_iterator rend()
Definition BasicBlock.h:488
const Instruction * getTerminator() const LLVM_READONLY
Returns the terminator instruction if the block is well formed or null if the block is not well forme...
Definition BasicBlock.h:233
static BranchInst * Create(BasicBlock *IfTrue, InsertPosition InsertBefore=nullptr)
Base class for all callable instructions (InvokeInst and CallInst) Holds everything related to callin...
void setCallingConv(CallingConv::ID CC)
bool arg_empty() const
Function * getCalledFunction() const
Returns the function called, or null if this is an indirect function invocation or the function signa...
bool doesNotAccessMemory(unsigned OpNo) const
LLVM_ABI bool isIndirectCall() const
Return true if the callsite is an indirect call.
bool isCallee(Value::const_user_iterator UI) const
Determine whether the passed iterator points to the callee operand's Use.
Value * getArgOperand(unsigned i) const
void setArgOperand(unsigned i, Value *v)
iterator_range< User::op_iterator > args()
Iteration adapter for range-for loops.
unsigned getArgOperandNo(const Use *U) const
Given a use for a arg operand, get the arg operand number that corresponds to it.
unsigned arg_size() const
AttributeList getAttributes() const
Return the attributes for this call.
void addParamAttr(unsigned ArgNo, Attribute::AttrKind Kind)
Adds the attribute to the indicated argument.
bool isArgOperand(const Use *U) const
bool hasOperandBundles() const
Return true if this User has any operand bundles.
LLVM_ABI Function * getCaller()
Helper to get the caller (the parent function).
Wrapper to unify "old style" CallGraph and "new style" LazyCallGraph.
void initialize(LazyCallGraph &LCG, LazyCallGraph::SCC &SCC, CGSCCAnalysisManager &AM, CGSCCUpdateResult &UR)
Initializers for usage outside of a CGSCC pass, inside a CGSCC pass in the old and new pass manager (...
static CallInst * Create(FunctionType *Ty, Value *F, const Twine &NameStr="", InsertPosition InsertBefore=nullptr)
@ ICMP_SLT
signed less than
Definition InstrTypes.h:705
@ ICMP_NE
not equal
Definition InstrTypes.h:698
static LLVM_ABI Constant * getPointerCast(Constant *C, Type *Ty)
Create a BitCast, AddrSpaceCast, or a PtrToInt cast constant expression.
static LLVM_ABI Constant * getPointerBitCastOrAddrSpaceCast(Constant *C, Type *Ty)
Create a BitCast or AddrSpaceCast for a pointer type depending on the address space.
This is the shared class of boolean and integer constants.
Definition Constants.h:87
IntegerType * getIntegerType() const
Variant of the getType() method to always return an IntegerType, which reduces the amount of casting ...
Definition Constants.h:198
static LLVM_ABI ConstantInt * getTrue(LLVMContext &Context)
bool isZero() const
This is just a convenience method to make client code smaller for a common code.
Definition Constants.h:219
int64_t getSExtValue() const
Return the constant as a 64-bit integer value after it has been sign extended as appropriate for the ...
Definition Constants.h:174
This is an important base class in LLVM.
Definition Constant.h:43
static LLVM_ABI Constant * getNullValue(Type *Ty)
Constructor to create a '0' constant of arbitrary type.
ValueT lookup(const_arg_type_t< KeyT > Val) const
lookup - Return the entry for the specified key, or a default constructed value if no such entry exis...
Definition DenseMap.h:205
std::pair< iterator, bool > insert(const std::pair< KeyT, ValueT > &KV)
Definition DenseMap.h:241
LLVM_ABI Instruction * findNearestCommonDominator(Instruction *I1, Instruction *I2) const
Find the nearest instruction I that dominates both I1 and I2, in the sense that a result produced bef...
static ErrorSuccess success()
Create a success value.
Definition Error.h:336
AtomicOrdering getOrdering() const
Returns the ordering constraint of this fence instruction.
A proxy from a FunctionAnalysisManager to an SCC.
const BasicBlock & getEntryBlock() const
Definition Function.h:813
const BasicBlock & front() const
Definition Function.h:864
LLVMContext & getContext() const
getContext - Return a reference to the LLVMContext associated with this function.
Definition Function.cpp:358
Argument * getArg(unsigned i) const
Definition Function.h:890
bool hasFnAttribute(Attribute::AttrKind Kind) const
Return true if the function has the attribute.
Definition Function.cpp:729
LLVM_ABI bool isDeclaration() const
Return true if the primary definition of this global value is outside of the current translation unit...
Definition Globals.cpp:329
bool hasLocalLinkage() const
Module * getParent()
Get the module that this global value is contained inside of...
@ PrivateLinkage
Like Internal, but omit from symbol table.
Definition GlobalValue.h:61
@ InternalLinkage
Rename collisions when linking (static functions).
Definition GlobalValue.h:60
const Constant * getInitializer() const
getInitializer - Return the initializer for this global variable.
LLVM_ABI void setInitializer(Constant *InitVal)
setInitializer - Sets the initializer for this global variable, removing any existing initializer if ...
Definition Globals.cpp:534
BasicBlock * getBlock() const
Definition IRBuilder.h:306
BranchInst * CreateCondBr(Value *Cond, BasicBlock *True, BasicBlock *False, MDNode *BranchWeights=nullptr, MDNode *Unpredictable=nullptr)
Create a conditional 'br Cond, TrueDest, FalseDest' instruction.
Definition IRBuilder.h:1201
CallInst * CreateCall(FunctionType *FTy, Value *Callee, ArrayRef< Value * > Args={}, const Twine &Name="", MDNode *FPMathTag=nullptr)
Definition IRBuilder.h:2481
Value * CreateIsNull(Value *Arg, const Twine &Name="")
Return a boolean value testing if Arg == 0.
Definition IRBuilder.h:2630
LLVM_ABI bool isLifetimeStartOrEnd() const LLVM_READONLY
Return true if the instruction is a llvm.lifetime.start or llvm.lifetime.end marker.
LLVM_ABI bool mayWriteToMemory() const LLVM_READONLY
Return true if this instruction may modify memory.
const DebugLoc & getDebugLoc() const
Return the debug location for this node as a DebugLoc.
LLVM_ABI const Module * getModule() const
Return the module owning the function this instruction belongs to or nullptr it the function does not...
LLVM_ABI InstListType::iterator eraseFromParent()
This method unlinks 'this' from the containing basic block and deletes it.
LLVM_ABI const Function * getFunction() const
Return the function this instruction belongs to.
LLVM_ABI bool mayHaveSideEffects() const LLVM_READONLY
Return true if the instruction may have side effects.
LLVM_ABI bool mayReadFromMemory() const LLVM_READONLY
Return true if this instruction may read memory.
void setDebugLoc(DebugLoc Loc)
Set the debug location information for this instruction.
LLVM_ABI void setSuccessor(unsigned Idx, BasicBlock *BB)
Update the specified successor to point at the provided block.
A node in the call graph.
An SCC of the call graph.
A lazily constructed view of the call graph of a module.
LLVM_ABI void eraseFromParent()
This method unlinks 'this' from the containing function and deletes it.
LLVM_ABI StringRef getName() const
Return the name of the corresponding LLVM basic block, or an empty string.
Root of the metadata hierarchy.
Definition Metadata.h:64
A Module instance is used to store all the information related to an LLVM module.
Definition Module.h:67
const Triple & getTargetTriple() const
Get the target triple which is a string describing the target host.
Definition Module.h:281
LLVM_ABI Constant * getOrCreateIdent(Constant *SrcLocStr, uint32_t SrcLocStrSize, omp::IdentFlag Flags=omp::IdentFlag(0), unsigned Reserve2Flags=0)
Return an ident_t* encoding the source location SrcLocStr and Flags.
LLVM_ABI FunctionCallee getOrCreateRuntimeFunction(Module &M, omp::RuntimeFunction FnID)
Return the function declaration for the runtime function with FnID.
static LLVM_ABI std::pair< int32_t, int32_t > readThreadBoundsForKernel(const Triple &T, Function &Kernel)
}
LLVM_ABI Constant * getOrCreateSrcLocStr(StringRef LocStr, uint32_t &SrcLocStrSize)
Return the (LLVM-IR) string describing the source location LocStr.
IRBuilder<>::InsertPoint InsertPointTy
Type used throughout for insertion points.
IRBuilder Builder
The LLVM-IR Builder used to create IR.
static LLVM_ABI std::pair< int32_t, int32_t > readTeamBoundsForKernel(const Triple &T, Function &Kernel)
Read/write a bounds on teams for Kernel.
bool updateToLocation(const LocationDescription &Loc)
Update the internal location to Loc.
PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR)
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM)
Diagnostic information for optimization analysis remarks.
The optimization diagnostic interface.
static LLVM_ABI PoisonValue * get(Type *T)
Static factory methods - Return an 'poison' object of the specified type.
A set of analyses that are preserved following a run of a transformation pass.
Definition Analysis.h:112
static PreservedAnalyses none()
Convenience factory function for the empty preserved set.
Definition Analysis.h:115
static PreservedAnalyses all()
Construct a special preserved set that preserves all passes.
Definition Analysis.h:118
static ReturnInst * Create(LLVMContext &C, Value *retVal=nullptr, InsertPosition InsertBefore=nullptr)
A vector that has set insertion semantics.
Definition SetVector.h:57
size_type size() const
Determine the number of elements in the SetVector.
Definition SetVector.h:103
size_type count(const_arg_type key) const
Count the number of elements of a given key in the SetVector.
Definition SetVector.h:262
bool insert(const value_type &X)
Insert a new element into the SetVector.
Definition SetVector.h:151
size_type size() const
Definition SmallPtrSet.h:99
size_type count(ConstPtrType Ptr) const
count - Return 1 if the specified pointer is in the set, 0 otherwise.
std::pair< iterator, bool > insert(PtrType Ptr)
Inserts Ptr if and only if there is no element in the container equal to Ptr.
iterator begin() const
SmallPtrSet - This class implements a set which is optimized for holding SmallSize or less elements.
reference emplace_back(ArgTypes &&... Args)
void append(ItTy in_start, ItTy in_end)
Add the specified range to the end of the SmallVector.
void push_back(const T &Elt)
This is a 'vector' (really, a variable-sized array), optimized for the case when the array is small.
bool starts_with(StringRef Prefix) const
Check if this string starts with the given Prefix.
Definition StringRef.h:261
Triple - Helper class for working with autoconf configuration names.
Definition Triple.h:47
static LLVM_ABI IntegerType * getInt32Ty(LLVMContext &C)
Definition Type.cpp:296
LLVM_ABI unsigned getPointerAddressSpace() const
Get the address space of this pointer or pointer vector type.
static LLVM_ABI UndefValue * get(Type *T)
Static factory methods - Return an 'undef' object of the specified type.
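A minimal sketch of the two factory methods above; the chosen type is illustrative:

#include "llvm/IR/Constants.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Type.h"
using namespace llvm;

static void makePlaceholders(LLVMContext &Ctx) {
  Type *Int32 = Type::getInt32Ty(Ctx);
  Constant *P = PoisonValue::get(Int32); // preferred placeholder for dead values
  Constant *U = UndefValue::get(Int32);  // legacy alternative
  (void)P;
  (void)U;
}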
A Use represents the edge between a Value definition and its users.
Definition Use.h:35
LLVM_ABI bool replaceUsesOfWith(Value *From, Value *To)
Replace uses of one Value with another.
Definition User.cpp:25
Type * getType() const
All values are typed, get the type of this value.
Definition Value.h:256
LLVM_ABI void setName(const Twine &Name)
Change the name of the value.
Definition Value.cpp:397
bool hasOneUse() const
Return true if there is exactly one use of this value.
Definition Value.h:439
LLVM_ABI void replaceAllUsesWith(Value *V)
Change all uses of this to point to a new Value.
Definition Value.cpp:553
iterator_range< user_iterator > users()
Definition Value.h:426
User * user_back()
Definition Value.h:412
LLVM_ABI const Value * stripPointerCasts() const
Strip off pointer casts, all-zero GEPs and address space casts.
Definition Value.cpp:708
LLVM_ABI StringRef getName() const
Return a constant reference to the value's name.
Definition Value.cpp:322
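A sketch of the common replacement pattern built from the Value/User entries above; the renaming step is optional and may be uniquified by the symbol table:

#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
using namespace llvm;

static void replaceValueWith(Value &Old, Value &New) {
  if (Old.hasOneUse())
    Old.user_back()->replaceUsesOfWith(&Old, &New); // patch the single user
  else
    Old.replaceAllUsesWith(&New);                   // rewrite every use edge
  New.setName(Old.getName());                       // optional, for readability
}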
const ParentTy * getParent() const
Definition ilist_node.h:34
self_iterator getIterator()
Definition ilist_node.h:123
NodeTy * getNextNode()
Get the next node, or nullptr for the list tail.
Definition ilist_node.h:348
Changed
#define llvm_unreachable(msg)
Marks that the current location is not supposed to be reachable.
GlobalVariable * getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB)
ConstantStruct * getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB)
Abstract Attribute helper functions.
Definition Attributor.h:165
LLVM_ABI bool isValidAtPosition(const ValueAndContext &VAC, InformationCache &InfoCache)
Return true if the value of VAC is valid at the position of VAC, that is a constant,...
LLVM_ABI bool isPotentiallyAffectedByBarrier(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is potentially affected by a barrier.
@ Interprocedural
Definition Attributor.h:184
LLVM_ABI bool isNoSyncInst(Attributor &A, const Instruction &I, const AbstractAttribute &QueryingAA)
Return true if I is a nosync instruction.
constexpr char Args[]
Key for Kernel::Metadata::mArgs.
E & operator^=(E &LHS, E RHS)
@ Entry
Definition COFF.h:862
unsigned ID
LLVM IR allows the use of arbitrary numbers as calling convention identifiers.
Definition CallingConv.h:24
@ C
The default llvm calling convention, compatible with C.
Definition CallingConv.h:34
@ BasicBlock
Various leaf nodes.
Definition ISDOpcodes.h:81
initializer< Ty > init(const Ty &Val)
PointerTypeMap run(const Module &M)
Compute the PointerTypeMap for the module M.
bool isOpenMPDevice(Module &M)
Helper to determine if M is an OpenMP target offloading device module.
bool containsOpenMP(Module &M)
Helper to determine if M contains OpenMP.
InternalControlVar
IDs for all Internal Control Variables (ICVs).
RuntimeFunction
IDs for all omp runtime library (RTL) functions.
KernelSet getDeviceKernels(Module &M)
Get OpenMP device kernels in M.
@ OMP_TGT_EXEC_MODE_GENERIC_SPMD
SetVector< Kernel > KernelSet
Set of kernels in the module.
Definition OpenMPOpt.h:24
Function * Kernel
Summary of a kernel (=entry point for target offloading).
Definition OpenMPOpt.h:21
bool isOpenMPKernel(Function &Fn)
Return true iff Fn is an OpenMP GPU kernel; Fn has the "kernel" attribute.
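A hedged sketch of how these module-level helpers gate the pass: skip modules without OpenMP and collect device kernels once up front. The function name is illustrative:

#include "llvm/IR/Module.h"
#include "llvm/Transforms/IPO/OpenMPOpt.h"
using namespace llvm;

static bool worthRunningOpenMPOpt(Module &M) {
  if (!omp::containsOpenMP(M))
    return false; // nothing OpenMP-related to optimize
  omp::KernelSet Kernels = omp::getDeviceKernels(M);
  // Device modules and modules with offloaded kernels are the interesting cases.
  return omp::isOpenMPDevice(M) || !Kernels.empty();
}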
DiagnosticInfoOptimizationBase::Argument NV
NodeAddr< UseNode * > Use
Definition RDFGraph.h:385
bool empty() const
Definition BasicBlock.h:101
iterator end() const
Definition BasicBlock.h:89
friend class Instruction
Iterator for Instructions in a BasicBlock.
Definition BasicBlock.h:73
LLVM_ABI iterator begin() const
This is an optimization pass for GlobalISel generic memory operations.
Definition Types.h:26
auto drop_begin(T &&RangeOrContainer, size_t N=1)
Return a range covering RangeOrContainer with the first N elements excluded.
Definition STLExtras.h:316
@ Offset
Definition DWP.cpp:532
FunctionAddr VTableAddr Value
Definition InstrProf.h:137
bool all_of(R &&range, UnaryPredicate P)
Provide wrappers to std::all_of which take ranges instead of having to pass begin/end explicitly.
Definition STLExtras.h:1737
auto size(R &&Range, std::enable_if_t< std::is_base_of< std::random_access_iterator_tag, typename std::iterator_traits< decltype(Range.begin())>::iterator_category >::value, void > *=nullptr)
Get the size of a range.
Definition STLExtras.h:1667
bool succ_empty(const Instruction *I)
Definition CFG.h:257
decltype(auto) dyn_cast(const From &Val)
dyn_cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:643
FunctionAddr VTableAddr uintptr_t uintptr_t Int32Ty
Definition InstrProf.h:296
bool operator!=(uint64_t V1, const APInt &V2)
Definition APInt.h:2122
constexpr from_range_t from_range
Value * GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL, bool AllowNonInbounds=true)
Analyze the specified pointer to see if it can be expressed as a base pointer plus a constant offset.
InnerAnalysisManagerProxy< FunctionAnalysisManager, Module > FunctionAnalysisManagerModuleProxy
Provide the FunctionAnalysisManager to Module proxy.
bool operator==(const AddressRangeValuePair &LHS, const AddressRangeValuePair &RHS)
AnalysisManager< LazyCallGraph::SCC, LazyCallGraph & > CGSCCAnalysisManager
The CGSCC analysis manager.
@ ThinLTOPostLink
ThinLTO postlink (backend compile) phase.
Definition Pass.h:83
@ FullLTOPostLink
Full LTO postlink (backend compile) phase.
Definition Pass.h:87
@ ThinLTOPreLink
ThinLTO prelink (summary) phase.
Definition Pass.h:81
auto dyn_cast_or_null(const Y &Val)
Definition Casting.h:753
LLVM_ABI raw_ostream & dbgs()
dbgs() - This returns a reference to a raw_ostream for debugging messages.
Definition Debug.cpp:207
class LLVM_GSL_OWNER SmallVector
Forward declaration of SmallVector so that calculateSmallVectorDefaultInlinedElements can reference s...
bool isa(const From &Val)
isa<X> - Return true if the parameter to the template is an instance of one of the template type argu...
Definition Casting.h:547
MutableArrayRef(T &OneElt) -> MutableArrayRef< T >
void cantFail(Error Err, const char *Msg=nullptr)
Report a fatal error if Err is a failure value.
Definition Error.h:769
IRBuilder(LLVMContext &, FolderTy, InserterTy, MDNode *, ArrayRef< OperandBundleDef >) -> IRBuilder< FolderTy, InserterTy >
bool operator&=(SparseBitVector< ElementSize > *LHS, const SparseBitVector< ElementSize > &RHS)
auto count(R &&Range, const E &Element)
Wrapper function around std::count to count the number of times an element Element occurs in the give...
Definition STLExtras.h:2002
ArrayRef(const T &OneElt) -> ArrayRef< T >
std::string toString(const APInt &I, unsigned Radix, bool Signed, bool formatAsCLiteral=false, bool UpperCase=true, bool InsertSeparators=false)
decltype(auto) cast(const From &Val)
cast<X> - Return the argument parameter cast to the specified type.
Definition Casting.h:559
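A small sketch of the isa/cast/dyn_cast idiom described above, applied to call sites; the helper name is illustrative:

#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/Casting.h"
using namespace llvm;

static Function *getDirectCallee(Instruction &I) {
  if (!isa<CallBase>(I))        // pure type test
    return nullptr;
  auto &CB = cast<CallBase>(I); // type already established, no check
  // dyn_cast_or_null tests and converts; returns nullptr for indirect calls.
  return dyn_cast_or_null<Function>(CB.getCalledOperand()->stripPointerCasts());
}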
LLVM_ABI BasicBlock * SplitBlock(BasicBlock *Old, BasicBlock::iterator SplitPt, DominatorTree *DT, LoopInfo *LI=nullptr, MemorySSAUpdater *MSSAU=nullptr, const Twine &BBName="", bool Before=false)
Split the specified block at the specified instruction.
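A brief sketch of SplitBlock at a call site, keeping the dominator tree updated, as region-splitting rewrites typically require; names are illustrative:

#include "llvm/IR/Dominators.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;

static BasicBlock *splitAtCall(CallBase &CB, DominatorTree *DT) {
  // Everything from CB onwards moves into the returned block; the original
  // block keeps the instructions before CB and branches to the new one.
  return SplitBlock(CB.getParent(), CB.getIterator(), DT);
}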
auto predecessors(const MachineBasicBlock *BB)
ChangeStatus
Definition Attributor.h:496
LLVM_ABI Constant * ConstantFoldInsertValueInstruction(Constant *Agg, Constant *Val, ArrayRef< unsigned > Idxs)
Attempt to constant fold an insertvalue instruction with the specified operands and indices.
@ OPTIONAL
The target may be valid if the source is not.
Definition Attributor.h:508
AnalysisManager< Function > FunctionAnalysisManager
Convenience typedef for the Function analysis manager.
BumpPtrAllocatorImpl<> BumpPtrAllocator
The standard BumpPtrAllocator which just uses the default template parameters.
Definition Allocator.h:383
LLVM_ABI const Value * getUnderlyingObject(const Value *V, unsigned MaxLookup=MaxLookupSearchDepth)
This method strips off any GEP address adjustments, pointer casts or llvm.threadlocal....
AnalysisManager< Module > ModuleAnalysisManager
Convenience typedef for the Module analysis manager.
Definition MIRParser.h:39
#define N
static LLVM_ABI AAExecutionDomain & createForPosition(const IRPosition &IRP, Attributor &A)
Create an abstract attribute view for the position IRP.
AAExecutionDomain(const IRPosition &IRP, Attributor &A)
static LLVM_ABI const char ID
Unique ID (due to the unique address)
AccessKind
Simple enum to distinguish read/write/read-write accesses.
StateType::base_t MemoryLocationsKind
static LLVM_ABI bool isAlignedBarrier(const CallBase &CB, bool ExecutedAligned)
Helper function to determine if CB is an aligned (GPU) barrier.
Base struct for all "concrete attribute" deductions.
virtual const char * getIdAddr() const =0
This function should return the address of the ID of the AbstractAttribute.
An interface to query the internal state of an abstract attribute.
Wrapper for FunctionAnalysisManager.
Configuration for the Attributor.
std::function< void(Attributor &A, const Function &F)> InitializationCallback
Callback function to be invoked on internal functions marked live.
std::optional< unsigned > MaxFixpointIterations
Maximum number of iterations to run until fixpoint.
bool RewriteSignatures
Flag to determine if we rewrite function signatures.
const char * PassName
OptimizationRemarkGetter OREGetter
IPOAmendableCBTy IPOAmendableCB
bool IsModulePass
Is the user of the Attributor a module pass or not.
bool DefaultInitializeLiveInternals
Flag to determine if we want to initialize all default AAs for an internal function marked live.
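A minimal sketch of wiring up an AttributorConfig with the fields listed above; the concrete values are illustrative, not the ones this pass uses:

#include "llvm/Transforms/IPO/Attributor.h"
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
using namespace llvm;

static void runAttributorSketch(SetVector<Function *> &Functions,
                                InformationCache &InfoCache,
                                CallGraphUpdater &CGUpdater) {
  AttributorConfig AC(CGUpdater);
  AC.IsModulePass = true;                  // driven from a module pass
  AC.RewriteSignatures = false;            // keep function signatures stable
  AC.MaxFixpointIterations = 32;           // illustrative iteration cap
  AC.DefaultInitializeLiveInternals = false;
  AC.PassName = "openmp-opt";              // shows up in remarks/debug output
  Attributor A(Functions, InfoCache, AC);
  (void)A.run();                           // iterate the deduction to a fixpoint
}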
The fixpoint analysis framework that orchestrates the attribute deduction.
static LLVM_ABI bool isInternalizable(Function &F)
Returns true if the function F can be internalized.
std::function< std::optional< Value * >( const IRPosition &, const AbstractAttribute *, bool &)> SimplifictionCallbackTy
Register CB as a simplification callback.
std::function< std::optional< Constant * >( const GlobalVariable &, const AbstractAttribute *, bool &)> GlobalVariableSimplifictionCallbackTy
Register CB as a simplification callback.
std::function< bool(Attributor &, const AbstractAttribute *)> VirtualUseCallbackTy
static LLVM_ABI bool internalizeFunctions(SmallPtrSetImpl< Function * > &FnSet, DenseMap< Function *, Function * > &FnMap)
Make copies of each function in the set FnSet such that the copied version has internal linkage after...
Simple wrapper for a single bit (boolean) state.
Support structure for SCC passes to communicate updates to the call graph back to the CGSCC pass manager...
Helper to describe and deal with positions in the LLVM-IR.
Definition Attributor.h:593
static const IRPosition callsite_returned(const CallBase &CB)
Create a position describing the returned value of CB.
Definition Attributor.h:661
static const IRPosition returned(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the returned value of F.
Definition Attributor.h:643
static const IRPosition value(const Value &V, const CallBaseContext *CBContext=nullptr)
Create a position describing the value of V.
Definition Attributor.h:617
static const IRPosition inst(const Instruction &I, const CallBaseContext *CBContext=nullptr)
Create a position describing the instruction I.
Definition Attributor.h:629
@ IRP_ARGUMENT
An attribute for a function argument.
Definition Attributor.h:607
@ IRP_RETURNED
An attribute for the function return value.
Definition Attributor.h:603
@ IRP_CALL_SITE
An attribute for a call site (function scope).
Definition Attributor.h:606
@ IRP_CALL_SITE_RETURNED
An attribute for a call site return value.
Definition Attributor.h:604
@ IRP_FUNCTION
An attribute for a function (scope).
Definition Attributor.h:605
@ IRP_FLOAT
A position that is not associated with a spot suitable for attributes.
Definition Attributor.h:601
@ IRP_CALL_SITE_ARGUMENT
An attribute for a call site argument.
Definition Attributor.h:608
@ IRP_INVALID
An invalid position.
Definition Attributor.h:600
static const IRPosition function(const Function &F, const CallBaseContext *CBContext=nullptr)
Create a position describing the function scope of F.
Definition Attributor.h:636
Kind getPositionKind() const
Return the associated position kind.
Definition Attributor.h:889
static const IRPosition callsite_function(const CallBase &CB)
Create a position describing the function scope of CB.
Definition Attributor.h:656
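A short sketch of the IRPosition factories above; which one to use depends on where an attribute should be queried or attached. The function and variable names are illustrative:

#include "llvm/Transforms/IPO/Attributor.h"
using namespace llvm;

static void describePositions(Function &F, CallBase &CB, Instruction &I) {
  IRPosition FnPos = IRPosition::function(F);              // function scope
  IRPosition RetPos = IRPosition::returned(F);             // return value of F
  IRPosition CSRetPos = IRPosition::callsite_returned(CB); // value returned by CB
  IRPosition ValPos = IRPosition::value(I);                // arbitrary IR value
  (void)RetPos;
  (void)CSRetPos;
  (void)ValPos;
  (void)FnPos.getPositionKind(); // e.g. IRP_FUNCTION for FnPos
}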
Data structure to hold cached (LLVM-IR) information.
Defines various target-specific GPU grid values that must be consistent between host RTL (plugin),...