1//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// OpenMP specific optimizations:
10//
11// - Deduplication of runtime calls, e.g., omp_get_thread_num.
12// - Replacing globalized device memory with stack memory.
13// - Replacing globalized device memory with shared memory.
14// - Parallel region merging.
15// - Transforming generic-mode device kernels to SPMD mode.
16// - Specializing the state machine for generic-mode device kernels.
17//
18//===----------------------------------------------------------------------===//
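// Illustrative sketch (value names are made up, not from this pass): the
// runtime call deduplication listed above turns repeated queries of the same
// unchanging OpenMP state, e.g.
//
//   %tid0 = call i32 @omp_get_thread_num()
//   ...
//   %tid1 = call i32 @omp_get_thread_num()
//
// into a single call whose result replaces both %tid0 and %tid1.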
19
21
24#include "llvm/ADT/SetVector.h"
26#include "llvm/ADT/Statistic.h"
27#include "llvm/ADT/StringRef.h"
35#include "llvm/IR/Assumptions.h"
36#include "llvm/IR/BasicBlock.h"
37#include "llvm/IR/Constants.h"
39#include "llvm/IR/GlobalValue.h"
41#include "llvm/IR/Instruction.h"
44#include "llvm/IR/IntrinsicsAMDGPU.h"
45#include "llvm/IR/IntrinsicsNVPTX.h"
46#include "llvm/IR/LLVMContext.h"
49#include "llvm/Support/Debug.h"
53
54#include <algorithm>
55#include <optional>
56#include <string>
57
58using namespace llvm;
59using namespace omp;
60
61#define DEBUG_TYPE "openmp-opt"
62
64 "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
65 cl::Hidden, cl::init(false));
66
68 "openmp-opt-enable-merging",
69 cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
70 cl::init(false));
71
72static cl::opt<bool>
73 DisableInternalization("openmp-opt-disable-internalization",
74 cl::desc("Disable function internalization."),
75 cl::Hidden, cl::init(false));
76
77static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
78 cl::init(false), cl::Hidden);
79static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
80 cl::Hidden);
81static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
82 cl::init(false), cl::Hidden);
83
85 "openmp-hide-memory-transfer-latency",
86 cl::desc("[WIP] Tries to hide the latency of host to device memory"
87 " transfers"),
88 cl::Hidden, cl::init(false));
89
91 "openmp-opt-disable-deglobalization",
92 cl::desc("Disable OpenMP optimizations involving deglobalization."),
93 cl::Hidden, cl::init(false));
94
96 "openmp-opt-disable-spmdization",
97 cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
98 cl::Hidden, cl::init(false));
99
101 "openmp-opt-disable-folding",
102 cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
103 cl::init(false));
104
106 "openmp-opt-disable-state-machine-rewrite",
107 cl::desc("Disable OpenMP optimizations that replace the state machine."),
108 cl::Hidden, cl::init(false));
109
111 "openmp-opt-disable-barrier-elimination",
112 cl::desc("Disable OpenMP optimizations that eliminate barriers."),
113 cl::Hidden, cl::init(false));
114
116 "openmp-opt-print-module-after",
117 cl::desc("Print the current module after OpenMP optimizations."),
118 cl::Hidden, cl::init(false));
119
121 "openmp-opt-print-module-before",
122 cl::desc("Print the current module before OpenMP optimizations."),
123 cl::Hidden, cl::init(false));
124
126 "openmp-opt-inline-device",
127 cl::desc("Inline all applicible functions on the device."), cl::Hidden,
128 cl::init(false));
129
130static cl::opt<bool>
131 EnableVerboseRemarks("openmp-opt-verbose-remarks",
132 cl::desc("Enables more verbose remarks."), cl::Hidden,
133 cl::init(false));
134
135static cl::opt<unsigned>
136 SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
137 cl::desc("Maximal number of attributor iterations."),
138 cl::init(256));
139
140static cl::opt<unsigned>
141 SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
142 cl::desc("Maximum amount of shared memory to use."),
143 cl::init(std::numeric_limits<unsigned>::max()));
144
145STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
146 "Number of OpenMP runtime calls deduplicated");
147STATISTIC(NumOpenMPParallelRegionsDeleted,
148 "Number of OpenMP parallel regions deleted");
149STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
150 "Number of OpenMP runtime functions identified");
151STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
152 "Number of OpenMP runtime function uses identified");
153STATISTIC(NumOpenMPTargetRegionKernels,
154 "Number of OpenMP target region entry points (=kernels) identified");
155STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
156 "Number of OpenMP target region entry points (=kernels) executed in "
157 "SPMD-mode instead of generic-mode");
158STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
159 "Number of OpenMP target region entry points (=kernels) executed in "
160 "generic-mode without a state machines");
161STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
162 "Number of OpenMP target region entry points (=kernels) executed in "
163 "generic-mode with customized state machines with fallback");
164STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
165 "Number of OpenMP target region entry points (=kernels) executed in "
166 "generic-mode with customized state machines without fallback");
167STATISTIC(
168 NumOpenMPParallelRegionsReplacedInGPUStateMachine,
169 "Number of OpenMP parallel regions replaced with ID in GPU state machines");
170STATISTIC(NumOpenMPParallelRegionsMerged,
171 "Number of OpenMP parallel regions merged");
172STATISTIC(NumBytesMovedToSharedMemory,
173 "Amount of memory pushed to shared memory");
174STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
175
176#if !defined(NDEBUG)
177static constexpr auto TAG = "[" DEBUG_TYPE "]";
178#endif
179
180namespace {
181
182struct AAHeapToShared;
183
184struct AAICVTracker;
185
186/// OpenMP specific information. For now, stores RFIs and ICVs also needed for
187/// Attributor runs.
188struct OMPInformationCache : public InformationCache {
189 OMPInformationCache(Module &M, AnalysisGetter &AG,
190 BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
191 KernelSet &Kernels, bool OpenMPPostLink)
192 : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
193 Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
194
195 OMPBuilder.initialize();
196 initializeRuntimeFunctions(M);
197 initializeInternalControlVars();
198 }
199
200 /// Generic information that describes an internal control variable.
201 struct InternalControlVarInfo {
202 /// The kind, as described by InternalControlVar enum.
203 InternalControlVar Kind;
204
205 /// The name of the ICV.
206 StringRef Name;
207
208 /// Environment variable associated with this ICV.
209 StringRef EnvVarName;
210
211 /// Initial value kind.
212 ICVInitValue InitKind;
213
214 /// Initial value.
215 ConstantInt *InitValue;
216
217 /// Setter RTL function associated with this ICV.
218 RuntimeFunction Setter;
219
220 /// Getter RTL function associated with this ICV.
221 RuntimeFunction Getter;
222
223 /// RTL Function corresponding to the override clause of this ICV
224 RuntimeFunction Clause;
225 };
226
227 /// Generic information that describes a runtime function
228 struct RuntimeFunctionInfo {
229
230 /// The kind, as described by the RuntimeFunction enum.
231 RuntimeFunction Kind;
232
233 /// The name of the function.
234 StringRef Name;
235
236 /// Flag to indicate a variadic function.
237 bool IsVarArg;
238
239 /// The return type of the function.
240 Type *ReturnType;
241
242 /// The argument types of the function.
243 SmallVector<Type *, 8> ArgumentTypes;
244
245 /// The declaration if available.
246 Function *Declaration = nullptr;
247
248 /// Uses of this runtime function per function containing the use.
249 using UseVector = SmallVector<Use *, 16>;
250
251 /// Clear UsesMap for runtime function.
252 void clearUsesMap() { UsesMap.clear(); }
253
254 /// Boolean conversion that is true if the runtime function was found.
255 operator bool() const { return Declaration; }
256
257 /// Return the vector of uses in function \p F.
258 UseVector &getOrCreateUseVector(Function *F) {
259 std::shared_ptr<UseVector> &UV = UsesMap[F];
260 if (!UV)
261 UV = std::make_shared<UseVector>();
262 return *UV;
263 }
264
265 /// Return the vector of uses in function \p F or `nullptr` if there are
266 /// none.
267 const UseVector *getUseVector(Function &F) const {
268 auto I = UsesMap.find(&F);
269 if (I != UsesMap.end())
270 return I->second.get();
271 return nullptr;
272 }
273
274 /// Return how many functions contain uses of this runtime function.
275 size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
276
277 /// Return the number of arguments (or the minimal number for variadic
278 /// functions).
279 size_t getNumArgs() const { return ArgumentTypes.size(); }
280
281 /// Run the callback \p CB on each use and forget the use if the result is
282 /// true. The callback will be fed the function in which the use was
283 /// encountered as second argument.
284 void foreachUse(SmallVectorImpl<Function *> &SCC,
285 function_ref<bool(Use &, Function &)> CB) {
286 for (Function *F : SCC)
287 foreachUse(CB, F);
288 }
289
290 /// Run the callback \p CB on each use within the function \p F and forget
291 /// the use if the result is true.
292 void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
293 SmallVector<unsigned, 8> ToBeDeleted;
294 ToBeDeleted.clear();
295
296 unsigned Idx = 0;
297 UseVector &UV = getOrCreateUseVector(F);
298
299 for (Use *U : UV) {
300 if (CB(*U, *F))
301 ToBeDeleted.push_back(Idx);
302 ++Idx;
303 }
304
305 // Remove the to-be-deleted indices in reverse order as prior
306 // modifications will not modify the smaller indices.
307 while (!ToBeDeleted.empty()) {
308 unsigned Idx = ToBeDeleted.pop_back_val();
309 UV[Idx] = UV.back();
310 UV.pop_back();
311 }
312 }
313
314 private:
315 /// Map from functions to all uses of this runtime function contained in
316 /// them.
317 DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
318
319 public:
320 /// Iterators for the uses of this runtime function.
321 decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
322 decltype(UsesMap)::iterator end() { return UsesMap.end(); }
323 };
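// Typical use of a RuntimeFunctionInfo, sketched after the patterns found
// later in this file (the lambda body is illustrative):
//
//   auto &RFI = RFIs[OMPRTL___kmpc_barrier];
//   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
//     if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
//       // Inspect or rewrite CI here.
//     }
//     return false; // Return true instead to forget this use.
//   });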
324
325 /// An OpenMP-IR-Builder instance
326 OpenMPIRBuilder OMPBuilder;
327
328 /// Map from runtime function kind to the runtime function description.
329 EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
330 RuntimeFunction::OMPRTL___last>
331 RFIs;
332
333 /// Map from function declarations/definitions to their runtime enum type.
334 DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
335
336 /// Map from ICV kind to the ICV description.
337 EnumeratedArray<InternalControlVarInfo, InternalControlVar,
338 InternalControlVar::ICV___last>
339 ICVs;
340
341 /// Helper to initialize all internal control variable information for those
342 /// defined in OMPKinds.def.
343 void initializeInternalControlVars() {
344#define ICV_RT_SET(_Name, RTL) \
345 { \
346 auto &ICV = ICVs[_Name]; \
347 ICV.Setter = RTL; \
348 }
349#define ICV_RT_GET(Name, RTL) \
350 { \
351 auto &ICV = ICVs[Name]; \
352 ICV.Getter = RTL; \
353 }
354#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
355 { \
356 auto &ICV = ICVs[Enum]; \
357 ICV.Name = _Name; \
358 ICV.Kind = Enum; \
359 ICV.InitKind = Init; \
360 ICV.EnvVarName = _EnvVarName; \
361 switch (ICV.InitKind) { \
362 case ICV_IMPLEMENTATION_DEFINED: \
363 ICV.InitValue = nullptr; \
364 break; \
365 case ICV_ZERO: \
366 ICV.InitValue = ConstantInt::get( \
367 Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
368 break; \
369 case ICV_FALSE: \
370 ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
371 break; \
372 case ICV_LAST: \
373 break; \
374 } \
375 }
376#include "llvm/Frontend/OpenMP/OMPKinds.def"
377 }
378
379 /// Returns true if the function declaration \p F matches the runtime
380 /// function types, that is, return type \p RTFRetType, and argument types
381 /// \p RTFArgTypes.
382 static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
383 SmallVector<Type *, 8> &RTFArgTypes) {
384 // TODO: We should output information to the user (under debug output
385 // and via remarks).
386
387 if (!F)
388 return false;
389 if (F->getReturnType() != RTFRetType)
390 return false;
391 if (F->arg_size() != RTFArgTypes.size())
392 return false;
393
394 auto *RTFTyIt = RTFArgTypes.begin();
395 for (Argument &Arg : F->args()) {
396 if (Arg.getType() != *RTFTyIt)
397 return false;
398
399 ++RTFTyIt;
400 }
401
402 return true;
403 }
404
405 // Helper to collect all uses of the declaration in the UsesMap.
406 unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
407 unsigned NumUses = 0;
408 if (!RFI.Declaration)
409 return NumUses;
410 OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
411
412 if (CollectStats) {
413 NumOpenMPRuntimeFunctionsIdentified += 1;
414 NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
415 }
416
417 // TODO: We directly convert uses into proper calls and unknown uses.
418 for (Use &U : RFI.Declaration->uses()) {
419 if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
420 if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
421 RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
422 ++NumUses;
423 }
424 } else {
425 RFI.getOrCreateUseVector(nullptr).push_back(&U);
426 ++NumUses;
427 }
428 }
429 return NumUses;
430 }
431
432 // Helper function to recollect uses of a runtime function.
433 void recollectUsesForFunction(RuntimeFunction RTF) {
434 auto &RFI = RFIs[RTF];
435 RFI.clearUsesMap();
436 collectUses(RFI, /*CollectStats*/ false);
437 }
438
439 // Helper function to recollect uses of all runtime functions.
440 void recollectUses() {
441 for (int Idx = 0; Idx < RFIs.size(); ++Idx)
442 recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
443 }
444
445 // Helper function to inherit the calling convention of the function callee.
446 void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
447 if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
448 CI->setCallingConv(Fn->getCallingConv());
449 }
450
451 // Helper function to determine if it's legal to create a call to the runtime
452 // functions.
453 bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
454 // We can always emit calls if we haven't yet linked in the runtime.
455 if (!OpenMPPostLink)
456 return true;
457
458 // Once the runtime has already been linked in we cannot emit calls to
459 // any undefined functions.
460 for (RuntimeFunction Fn : Fns) {
461 RuntimeFunctionInfo &RFI = RFIs[Fn];
462
463 if (RFI.Declaration && RFI.Declaration->isDeclaration())
464 return false;
465 }
466 return true;
467 }
468
469 /// Helper to initialize all runtime function information for those defined
470 /// in OpenMPKinds.def.
471 void initializeRuntimeFunctions(Module &M) {
472
473 // Helper macros for handling __VA_ARGS__ in OMP_RTL
474#define OMP_TYPE(VarName, ...) \
475 Type *VarName = OMPBuilder.VarName; \
476 (void)VarName;
477
478#define OMP_ARRAY_TYPE(VarName, ...) \
479 ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
480 (void)VarName##Ty; \
481 PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
482 (void)VarName##PtrTy;
483
484#define OMP_FUNCTION_TYPE(VarName, ...) \
485 FunctionType *VarName = OMPBuilder.VarName; \
486 (void)VarName; \
487 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
488 (void)VarName##Ptr;
489
490#define OMP_STRUCT_TYPE(VarName, ...) \
491 StructType *VarName = OMPBuilder.VarName; \
492 (void)VarName; \
493 PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
494 (void)VarName##Ptr;
495
496#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
497 { \
498 SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
499 Function *F = M.getFunction(_Name); \
500 RTLFunctions.insert(F); \
501 if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
502 RuntimeFunctionIDMap[F] = _Enum; \
503 auto &RFI = RFIs[_Enum]; \
504 RFI.Kind = _Enum; \
505 RFI.Name = _Name; \
506 RFI.IsVarArg = _IsVarArg; \
507 RFI.ReturnType = OMPBuilder._ReturnType; \
508 RFI.ArgumentTypes = std::move(ArgsTypes); \
509 RFI.Declaration = F; \
510 unsigned NumUses = collectUses(RFI); \
511 (void)NumUses; \
512 LLVM_DEBUG({ \
513 dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
514 << " found\n"; \
515 if (RFI.Declaration) \
516 dbgs() << TAG << "-> got " << NumUses << " uses in " \
517 << RFI.getNumFunctionsWithUses() \
518 << " different functions.\n"; \
519 }); \
520 } \
521 }
522#include "llvm/Frontend/OpenMP/OMPKinds.def"
523
524 // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
525 // functions, except if `optnone` is present.
526 if (isOpenMPDevice(M)) {
527 for (Function &F : M) {
528 for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
529 if (F.hasFnAttribute(Attribute::NoInline) &&
530 F.getName().startswith(Prefix) &&
531 !F.hasFnAttribute(Attribute::OptimizeNone))
532 F.removeFnAttr(Attribute::NoInline);
533 }
534 }
535
536 // TODO: We should attach the attributes defined in OMPKinds.def.
537 }
538
539 /// Collection of known kernels (\see Kernel) in the module.
540 KernelSet &Kernels;
541
542 /// Collection of known OpenMP runtime functions.
543 DenseSet<const Function *> RTLFunctions;
544
545 /// Indicates if we have already linked in the OpenMP device library.
546 bool OpenMPPostLink = false;
547};
548
549template <typename Ty, bool InsertInvalidates = true>
550struct BooleanStateWithSetVector : public BooleanState {
551 bool contains(const Ty &Elem) const { return Set.contains(Elem); }
552 bool insert(const Ty &Elem) {
553 if (InsertInvalidates)
554 BooleanState::indicatePessimisticFixpoint();
555 return Set.insert(Elem);
556 }
557
558 const Ty &operator[](int Idx) const { return Set[Idx]; }
559 bool operator==(const BooleanStateWithSetVector &RHS) const {
560 return BooleanState::operator==(RHS) && Set == RHS.Set;
561 }
562 bool operator!=(const BooleanStateWithSetVector &RHS) const {
563 return !(*this == RHS);
564 }
565
566 bool empty() const { return Set.empty(); }
567 size_t size() const { return Set.size(); }
568
569 /// "Clamp" this state with \p RHS.
570 BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
571 BooleanState::operator^=(RHS);
572 Set.insert(RHS.Set.begin(), RHS.Set.end());
573 return *this;
574 }
575
576private:
577 /// A set to keep track of elements.
578 SetVector<Ty> Set;
579
580public:
581 typename decltype(Set)::iterator begin() { return Set.begin(); }
582 typename decltype(Set)::iterator end() { return Set.end(); }
583 typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
584 typename decltype(Set)::const_iterator end() const { return Set.end(); }
585};
586
587template <typename Ty, bool InsertInvalidates = true>
588using BooleanStateWithPtrSetVector =
589 BooleanStateWithSetVector<Ty *, InsertInvalidates>;
590
591struct KernelInfoState : AbstractState {
592 /// Flag to track if we reached a fixpoint.
593 bool IsAtFixpoint = false;
594
595 /// The parallel regions (identified by the outlined parallel functions) that
596 /// can be reached from the associated function.
597 BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
598 ReachedKnownParallelRegions;
599
600 /// State to track what parallel region we might reach.
601 BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
602
603 /// State to track if we are in SPMD-mode, assumed or known, and why we decided
604 /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
605 /// false.
606 BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
607
608 /// The __kmpc_target_init call in this kernel, if any. If we find more than
609 /// one we abort as the kernel is malformed.
610 CallBase *KernelInitCB = nullptr;
611
612 /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
613 /// one we abort as the kernel is malformed.
614 CallBase *KernelDeinitCB = nullptr;
615
616 /// Flag to indicate if the associated function is a kernel entry.
617 bool IsKernelEntry = false;
618
619 /// State to track what kernel entries can reach the associated function.
620 BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
621
622 /// State to indicate if we can track parallel level of the associated
623 /// function. We will give up tracking if we encounter an unknown caller or the
624 /// caller is __kmpc_parallel_51.
625 BooleanStateWithSetVector<uint8_t> ParallelLevels;
626
627 /// Flag that indicates if the kernel has nested Parallelism
628 bool NestedParallelism = false;
629
630 /// Abstract State interface
631 ///{
632
633 KernelInfoState() = default;
634 KernelInfoState(bool BestState) {
635 if (!BestState)
636 indicatePessimisticFixpoint();
637 }
638
639 /// See AbstractState::isValidState(...)
640 bool isValidState() const override { return true; }
641
642 /// See AbstractState::isAtFixpoint(...)
643 bool isAtFixpoint() const override { return IsAtFixpoint; }
644
645 /// See AbstractState::indicatePessimisticFixpoint(...)
646 ChangeStatus indicatePessimisticFixpoint() override {
647 IsAtFixpoint = true;
648 ParallelLevels.indicatePessimisticFixpoint();
649 ReachingKernelEntries.indicatePessimisticFixpoint();
650 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
651 ReachedKnownParallelRegions.indicatePessimisticFixpoint();
652 ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
653 return ChangeStatus::CHANGED;
654 }
655
656 /// See AbstractState::indicateOptimisticFixpoint(...)
657 ChangeStatus indicateOptimisticFixpoint() override {
658 IsAtFixpoint = true;
659 ParallelLevels.indicateOptimisticFixpoint();
660 ReachingKernelEntries.indicateOptimisticFixpoint();
661 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
662 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
663 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
664 return ChangeStatus::UNCHANGED;
665 }
666
667 /// Return the assumed state
668 KernelInfoState &getAssumed() { return *this; }
669 const KernelInfoState &getAssumed() const { return *this; }
670
671 bool operator==(const KernelInfoState &RHS) const {
672 if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
673 return false;
674 if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
675 return false;
676 if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
677 return false;
678 if (ReachingKernelEntries != RHS.ReachingKernelEntries)
679 return false;
680 if (ParallelLevels != RHS.ParallelLevels)
681 return false;
682 return true;
683 }
684
685 /// Returns true if this kernel contains any OpenMP parallel regions.
686 bool mayContainParallelRegion() {
687 return !ReachedKnownParallelRegions.empty() ||
688 !ReachedUnknownParallelRegions.empty();
689 }
690
691 /// Return empty set as the best state of potential values.
692 static KernelInfoState getBestState() { return KernelInfoState(true); }
693
694 static KernelInfoState getBestState(KernelInfoState &KIS) {
695 return getBestState();
696 }
697
698 /// Return full set as the worst state of potential values.
699 static KernelInfoState getWorstState() { return KernelInfoState(false); }
700
701 /// "Clamp" this state with \p KIS.
702 KernelInfoState operator^=(const KernelInfoState &KIS) {
703 // Do not merge two different _init and _deinit call sites.
704 if (KIS.KernelInitCB) {
705 if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
706 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
707 "assumptions.");
708 KernelInitCB = KIS.KernelInitCB;
709 }
710 if (KIS.KernelDeinitCB) {
711 if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
712 llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
713 "assumptions.");
714 KernelDeinitCB = KIS.KernelDeinitCB;
715 }
716 SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
717 ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
718 ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
719 NestedParallelism |= KIS.NestedParallelism;
720 return *this;
721 }
722
723 KernelInfoState operator&=(const KernelInfoState &KIS) {
724 return (*this ^= KIS);
725 }
726
727 ///}
728};
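// Sketch of how this state is clamped during the Attributor fixpoint
// iteration (variable names are illustrative): a callee's state is merged
// into the kernel's state, unioning reached parallel regions, SPMD
// compatibility, and nested parallelism.
//
//   KernelInfoState KernelState = KernelInfoState::getBestState();
//   KernelState ^= CalleeState; // see operator^= above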
729
730/// Used to map the values physically (in the IR) stored in an offload
731/// array, to a vector in memory.
732struct OffloadArray {
733 /// Physical array (in the IR).
734 AllocaInst *Array = nullptr;
735 /// Mapped values.
736 SmallVector<Value *, 8> StoredValues;
737 /// Last stores made in the offload array.
738 SmallVector<StoreInst *, 8> LastAccesses;
739
740 OffloadArray() = default;
741
742 /// Initializes the OffloadArray with the values stored in \p Array before
743 /// instruction \p Before is reached. Returns false if the initialization
744 /// fails.
745 /// This MUST be used immediately after the construction of the object.
746 bool initialize(AllocaInst &Array, Instruction &Before) {
747 if (!Array.getAllocatedType()->isArrayTy())
748 return false;
749
750 if (!getValues(Array, Before))
751 return false;
752
753 this->Array = &Array;
754 return true;
755 }
756
757 static const unsigned DeviceIDArgNum = 1;
758 static const unsigned BasePtrsArgNum = 3;
759 static const unsigned PtrsArgNum = 4;
760 static const unsigned SizesArgNum = 5;
761
762private:
763 /// Traverses the BasicBlock where \p Array is, collecting the stores made to
764 /// \p Array, leaving StoredValues with the values stored before the
765 /// instruction \p Before is reached.
766 bool getValues(AllocaInst &Array, Instruction &Before) {
767 // Initialize container.
768 const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
769 StoredValues.assign(NumValues, nullptr);
770 LastAccesses.assign(NumValues, nullptr);
771
772 // TODO: This assumes the instruction \p Before is in the same
773 // BasicBlock as Array. Make it general, for any control flow graph.
774 BasicBlock *BB = Array.getParent();
775 if (BB != Before.getParent())
776 return false;
777
778 const DataLayout &DL = Array.getModule()->getDataLayout();
779 const unsigned int PointerSize = DL.getPointerSize();
780
781 for (Instruction &I : *BB) {
782 if (&I == &Before)
783 break;
784
785 if (!isa<StoreInst>(&I))
786 continue;
787
788 auto *S = cast<StoreInst>(&I);
789 int64_t Offset = -1;
790 auto *Dst =
791 GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
792 if (Dst == &Array) {
793 int64_t Idx = Offset / PointerSize;
794 StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
795 LastAccesses[Idx] = S;
796 }
797 }
798
799 return isFilled();
800 }
801
802 /// Returns true if all values in StoredValues and
803 /// LastAccesses are not nullptrs.
804 bool isFilled() {
805 const unsigned NumValues = StoredValues.size();
806 for (unsigned I = 0; I < NumValues; ++I) {
807 if (!StoredValues[I] || !LastAccesses[I])
808 return false;
809 }
810
811 return true;
812 }
813};
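// Sketch of the IR pattern OffloadArray::initialize expects (names and the
// array size are illustrative): stores into an alloca'd array that precede
// the runtime call within the same basic block.
//
//   %offload_baseptrs = alloca [2 x i8*]
//   ...
//   store i8* %base0, i8** %gep0   ; offset 0 -> StoredValues[0]
//   store i8* %base1, i8** %gep1   ; offset 8 -> StoredValues[1]
//   call void @__tgt_target_data_begin_mapper(...)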
814
815struct OpenMPOpt {
816
817 using OptimizationRemarkGetter =
818 function_ref<OptimizationRemarkEmitter &(Function *)>;
819
820 OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
821 OptimizationRemarkGetter OREGetter,
822 OMPInformationCache &OMPInfoCache, Attributor &A)
823 : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
824 OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
825
826 /// Check if any remarks are enabled for openmp-opt
827 bool remarksEnabled() {
828 auto &Ctx = M.getContext();
829 return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
830 }
831
832 /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
833 bool run(bool IsModulePass) {
834 if (SCC.empty())
835 return false;
836
837 bool Changed = false;
838
839 LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
840 << " functions in a slice with "
841 << OMPInfoCache.ModuleSlice.size() << " functions\n");
842
843 if (IsModulePass) {
844 Changed |= runAttributor(IsModulePass);
845
846 // Recollect uses, in case Attributor deleted any.
847 OMPInfoCache.recollectUses();
848
849 // TODO: This should be folded into buildCustomStateMachine.
850 Changed |= rewriteDeviceCodeStateMachine();
851
852 if (remarksEnabled())
853 analysisGlobalization();
854 } else {
855 if (PrintICVValues)
856 printICVs();
857 if (PrintOpenMPKernels)
858 printKernels();
859
860 Changed |= runAttributor(IsModulePass);
861
862 // Recollect uses, in case Attributor deleted any.
863 OMPInfoCache.recollectUses();
864
865 Changed |= deleteParallelRegions();
866
867 if (HideMemoryTransferLatency)
868 Changed |= hideMemTransfersLatency();
869 Changed |= deduplicateRuntimeCalls();
870 if (EnableParallelRegionMerging) {
871 if (mergeParallelRegions()) {
872 deduplicateRuntimeCalls();
873 Changed = true;
874 }
875 }
876 }
877
878 return Changed;
879 }
880
881 /// Print initial ICV values for testing.
882 /// FIXME: This should be done from the Attributor once it is added.
883 void printICVs() const {
884 InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
885 ICV_proc_bind};
886
887 for (Function *F : SCC) {
888 for (auto ICV : ICVs) {
889 auto ICVInfo = OMPInfoCache.ICVs[ICV];
890 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
891 return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
892 << " Value: "
893 << (ICVInfo.InitValue
894 ? toString(ICVInfo.InitValue->getValue(), 10, true)
895 : "IMPLEMENTATION_DEFINED");
896 };
897
898 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
899 }
900 }
901 }
902
903 /// Print OpenMP GPU kernels for testing.
904 void printKernels() const {
905 for (Function *F : SCC) {
906 if (!OMPInfoCache.Kernels.count(F))
907 continue;
908
909 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
910 return ORA << "OpenMP GPU kernel "
911 << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
912 };
913
914 emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
915 }
916 }
917
918 /// Return the call if \p U is a callee use in a regular call. If \p RFI is
919 /// given it has to be the callee or a nullptr is returned.
920 static CallInst *getCallIfRegularCall(
921 Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
922 CallInst *CI = dyn_cast<CallInst>(U.getUser());
923 if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
924 (!RFI ||
925 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
926 return CI;
927 return nullptr;
928 }
929
930 /// Return the call if \p V is a regular call. If \p RFI is given it has to be
931 /// the callee or a nullptr is returned.
932 static CallInst *getCallIfRegularCall(
933 Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
934 CallInst *CI = dyn_cast<CallInst>(&V);
935 if (CI && !CI->hasOperandBundles() &&
936 (!RFI ||
937 (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
938 return CI;
939 return nullptr;
940 }
941
942private:
943 /// Merge parallel regions when it is safe.
944 bool mergeParallelRegions() {
945 const unsigned CallbackCalleeOperand = 2;
946 const unsigned CallbackFirstArgOperand = 3;
947 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
948
949 // Check if there are any __kmpc_fork_call calls to merge.
950 OMPInformationCache::RuntimeFunctionInfo &RFI =
951 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
952
953 if (!RFI.Declaration)
954 return false;
955
956 // Unmergable calls that prevent merging a parallel region.
957 OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
958 OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
959 OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
960 };
961
962 bool Changed = false;
963 LoopInfo *LI = nullptr;
964 DominatorTree *DT = nullptr;
965
965
966 SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
967
968 BasicBlock *StartBB = nullptr, *EndBB = nullptr;
969 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
970 BasicBlock *CGStartBB = CodeGenIP.getBlock();
971 BasicBlock *CGEndBB =
972 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
973 assert(StartBB != nullptr && "StartBB should not be null");
974 CGStartBB->getTerminator()->setSuccessor(0, StartBB);
975 assert(EndBB != nullptr && "EndBB should not be null");
976 EndBB->getTerminator()->setSuccessor(0, CGEndBB);
977 };
978
979 auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
980 Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
981 ReplacementValue = &Inner;
982 return CodeGenIP;
983 };
984
985 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
986
987 /// Create a sequential execution region within a merged parallel region,
988 /// encapsulated in a master construct with a barrier for synchronization.
989 auto CreateSequentialRegion = [&](Function *OuterFn,
990 BasicBlock *OuterPredBB,
991 Instruction *SeqStartI,
992 Instruction *SeqEndI) {
993 // Isolate the instructions of the sequential region to a separate
994 // block.
995 BasicBlock *ParentBB = SeqStartI->getParent();
996 BasicBlock *SeqEndBB =
997 SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
998 BasicBlock *SeqAfterBB =
999 SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
1000 BasicBlock *SeqStartBB =
1001 SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
1002
1003 assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
1004 "Expected a different CFG");
1005 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
1006 ParentBB->getTerminator()->eraseFromParent();
1007
1008 auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
1009 BasicBlock *CGStartBB = CodeGenIP.getBlock();
1010 BasicBlock *CGEndBB =
1011 SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
1012 assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
1013 CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
1014 assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
1015 SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
1016 };
1017 auto FiniCB = [&](InsertPointTy CodeGenIP) {};
1018
1019 // Find outputs from the sequential region to outside users and
1020 // broadcast their values to them.
1021 for (Instruction &I : *SeqStartBB) {
1022 SmallPtrSet<Instruction *, 4> OutsideUsers;
1023 for (User *Usr : I.users()) {
1024 Instruction &UsrI = *cast<Instruction>(Usr);
1026 // Ignore outputs to lifetime intrinsics; code extraction for the merged
1026 // parallel region will fix them.
1027 if (UsrI.isLifetimeStartOrEnd())
1028 continue;
1029
1030 if (UsrI.getParent() != SeqStartBB)
1031 OutsideUsers.insert(&UsrI);
1032 }
1033
1034 if (OutsideUsers.empty())
1035 continue;
1036
1037 // Emit an alloca in the outer region to store the broadcasted
1038 // value.
1039 const DataLayout &DL = M.getDataLayout();
1040 AllocaInst *AllocaI = new AllocaInst(
1041 I.getType(), DL.getAllocaAddrSpace(), nullptr,
1042 I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1043
1044 // Emit a store instruction in the sequential BB to update the
1045 // value.
1046 new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1047
1048 // Emit a load instruction and replace the use of the output value
1049 // with it.
1050 for (Instruction *UsrI : OutsideUsers) {
1051 LoadInst *LoadI = new LoadInst(
1052 I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1053 UsrI->replaceUsesOfWith(&I, LoadI);
1054 }
1055 }
1056
1057 OpenMPIRBuilder::LocationDescription Loc(
1058 InsertPointTy(ParentBB, ParentBB->end()), DL);
1059 InsertPointTy SeqAfterIP =
1060 OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1061
1062 OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1063
1064 BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1065
1066 LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1067 << "\n");
1068 };
1069
1070 // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1071 // contained in BB and only separated by instructions that can be
1072 // redundantly executed in parallel. The block BB is split before the first
1073 // call (in MergableCIs) and after the last so the entire region we merge
1074 // into a single parallel region is contained in a single basic block
1075 // without any other instructions. We use the OpenMPIRBuilder to outline
1076 // that block and call the resulting function via __kmpc_fork_call.
1077 auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1078 BasicBlock *BB) {
1079 // TODO: Change the interface to allow single CIs expanded, e.g., to
1080 // include an outer loop.
1081 assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1082
1083 auto Remark = [&](OptimizationRemark OR) {
1084 OR << "Parallel region merged with parallel region"
1085 << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1086 for (auto *CI : llvm::drop_begin(MergableCIs)) {
1087 OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1088 if (CI != MergableCIs.back())
1089 OR << ", ";
1090 }
1091 return OR << ".";
1092 };
1093
1094 emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1095
1096 Function *OriginalFn = BB->getParent();
1097 LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1098 << " parallel regions in " << OriginalFn->getName()
1099 << "\n");
1100
1101 // Isolate the calls to merge in a separate block.
1102 EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1103 BasicBlock *AfterBB =
1104 SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1105 StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1106 "omp.par.merged");
1107
1108 assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1109 const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1110 BB->getTerminator()->eraseFromParent();
1111
1112 // Create sequential regions for sequential instructions that are
1113 // in-between mergable parallel regions.
1114 for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1115 It != End; ++It) {
1116 Instruction *ForkCI = *It;
1117 Instruction *NextForkCI = *(It + 1);
1118
1119 // Continue if there are not in-between instructions.
1120 if (ForkCI->getNextNode() == NextForkCI)
1121 continue;
1122
1123 CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1124 NextForkCI->getPrevNode());
1125 }
1126
1127 OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1128 DL);
1129 IRBuilder<>::InsertPoint AllocaIP(
1130 &OriginalFn->getEntryBlock(),
1131 OriginalFn->getEntryBlock().getFirstInsertionPt());
1132 // Create the merged parallel region with default proc binding, to
1133 // avoid overriding binding settings, and without explicit cancellation.
1134 InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1135 Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1136 OMP_PROC_BIND_default, /* IsCancellable */ false);
1137 BranchInst::Create(AfterBB, AfterIP.getBlock());
1138
1139 // Perform the actual outlining.
1140 OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1141
1142 Function *OutlinedFn = MergableCIs.front()->getCaller();
1143
1144 // Replace the __kmpc_fork_call calls with direct calls to the outlined
1145 // callbacks.
1146 SmallVector<Value *, 8> Args;
1147 for (auto *CI : MergableCIs) {
1148 Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1149 FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1150 Args.clear();
1151 Args.push_back(OutlinedFn->getArg(0));
1152 Args.push_back(OutlinedFn->getArg(1));
1153 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1154 ++U)
1155 Args.push_back(CI->getArgOperand(U));
1156
1157 CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1158 if (CI->getDebugLoc())
1159 NewCI->setDebugLoc(CI->getDebugLoc());
1160
1161 // Forward parameter attributes from the callback to the callee.
1162 for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1163 ++U)
1164 for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1165 NewCI->addParamAttr(
1166 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1167
1168 // Emit an explicit barrier to replace the implicit fork-join barrier.
1169 if (CI != MergableCIs.back()) {
1170 // TODO: Remove barrier if the merged parallel region includes the
1171 // 'nowait' clause.
1172 OMPInfoCache.OMPBuilder.createBarrier(
1173 InsertPointTy(NewCI->getParent(),
1174 NewCI->getNextNode()->getIterator()),
1175 OMPD_parallel);
1176 }
1177
1178 CI->eraseFromParent();
1179 }
1180
1181 assert(OutlinedFn != OriginalFn && "Outlining failed");
1182 CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1183 CGUpdater.reanalyzeFunction(*OriginalFn);
1184
1185 NumOpenMPParallelRegionsMerged += MergableCIs.size();
1186
1187 return true;
1188 };
1189
1190 // Helper function that identifies sequences of
1191 // __kmpc_fork_call uses in a basic block.
1192 auto DetectPRsCB = [&](Use &U, Function &F) {
1193 CallInst *CI = getCallIfRegularCall(U, &RFI);
1194 BB2PRMap[CI->getParent()].insert(CI);
1195
1196 return false;
1197 };
1198
1199 BB2PRMap.clear();
1200 RFI.foreachUse(SCC, DetectPRsCB);
1201 SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1202 // Find mergable parallel regions within a basic block that are
1203 // safe to merge, that is any in-between instructions can safely
1204 // execute in parallel after merging.
1205 // TODO: support merging across basic-blocks.
1206 for (auto &It : BB2PRMap) {
1207 auto &CIs = It.getSecond();
1208 if (CIs.size() < 2)
1209 continue;
1210
1211 BasicBlock *BB = It.getFirst();
1212 SmallVector<CallInst *, 4> MergableCIs;
1213
1214 /// Returns true if the instruction is mergable, false otherwise.
1215 /// A terminator instruction is unmergable by definition since merging
1216 /// works within a BB. Instructions before the mergable region are
1217 /// mergable if they are not calls to OpenMP runtime functions that may
1218 /// set different execution parameters for subsequent parallel regions.
1219 /// Instructions in-between parallel regions are mergable if they are not
1220 /// calls to any non-intrinsic function since that may call a non-mergable
1221 /// OpenMP runtime function.
1222 auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1223 // We do not merge across BBs, hence return false (unmergable) if the
1224 // instruction is a terminator.
1225 if (I.isTerminator())
1226 return false;
1227
1228 if (!isa<CallInst>(&I))
1229 return true;
1230
1231 CallInst *CI = cast<CallInst>(&I);
1232 if (IsBeforeMergableRegion) {
1233 Function *CalledFunction = CI->getCalledFunction();
1234 if (!CalledFunction)
1235 return false;
1236 // Return false (unmergable) if the call before the parallel
1237 // region calls an explicit affinity (proc_bind) or number of
1238 // threads (num_threads) compiler-generated function. Those settings
1239 // may be incompatible with following parallel regions.
1240 // TODO: ICV tracking to detect compatibility.
1241 for (const auto &RFI : UnmergableCallsInfo) {
1242 if (CalledFunction == RFI.Declaration)
1243 return false;
1244 }
1245 } else {
1246 // Return false (unmergable) if there is a call instruction
1247 // in-between parallel regions when it is not an intrinsic. It
1248 // may call an unmergable OpenMP runtime function in its callpath.
1249 // TODO: Keep track of possible OpenMP calls in the callpath.
1250 if (!isa<IntrinsicInst>(CI))
1251 return false;
1252 }
1253
1254 return true;
1255 };
1256 // Find maximal number of parallel region CIs that are safe to merge.
1257 for (auto It = BB->begin(), End = BB->end(); It != End;) {
1258 Instruction &I = *It;
1259 ++It;
1260
1261 if (CIs.count(&I)) {
1262 MergableCIs.push_back(cast<CallInst>(&I));
1263 continue;
1264 }
1265
1266 // Continue expanding if the instruction is mergable.
1267 if (IsMergable(I, MergableCIs.empty()))
1268 continue;
1269
1270 // Forward the instruction iterator to skip the next parallel region
1271 // since there is an unmergable instruction which can affect it.
1272 for (; It != End; ++It) {
1273 Instruction &SkipI = *It;
1274 if (CIs.count(&SkipI)) {
1275 LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1276 << " due to " << I << "\n");
1277 ++It;
1278 break;
1279 }
1280 }
1281
1282 // Store mergable regions found.
1283 if (MergableCIs.size() > 1) {
1284 MergableCIsVector.push_back(MergableCIs);
1285 LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1286 << " parallel regions in block " << BB->getName()
1287 << " of function " << BB->getParent()->getName()
1288 << "\n";);
1289 }
1290
1291 MergableCIs.clear();
1292 }
1293
1294 if (!MergableCIsVector.empty()) {
1295 Changed = true;
1296
1297 for (auto &MergableCIs : MergableCIsVector)
1298 Merge(MergableCIs, BB);
1299 MergableCIsVector.clear();
1300 }
1301 }
1302
1303 if (Changed) {
1304 /// Re-collect uses for fork calls, emitted barrier calls, and
1305 /// any emitted master/end_master calls.
1306 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1307 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1308 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1309 OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1310 }
1311
1312 return Changed;
1313 }
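// Rough before/after sketch of the merging above (callee and value names are
// illustrative):
//
//   call void (...) @__kmpc_fork_call(%ident, 1, @outlined.0, i8* %a)
//   %x = add i32 %y, 1   ; safe to execute redundantly in parallel
//   call void (...) @__kmpc_fork_call(%ident, 1, @outlined.1, i8* %b)
//
// becomes a single __kmpc_fork_call of a merged outlined function that calls
// @outlined.0 and @outlined.1 directly, with the in-between code wrapped in a
// master region followed by a barrier.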
1314
1315 /// Try to delete parallel regions if possible.
1316 bool deleteParallelRegions() {
1317 const unsigned CallbackCalleeOperand = 2;
1318
1319 OMPInformationCache::RuntimeFunctionInfo &RFI =
1320 OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1321
1322 if (!RFI.Declaration)
1323 return false;
1324
1325 bool Changed = false;
1326 auto DeleteCallCB = [&](Use &U, Function &) {
1327 CallInst *CI = getCallIfRegularCall(U);
1328 if (!CI)
1329 return false;
1330 auto *Fn = dyn_cast<Function>(
1331 CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1332 if (!Fn)
1333 return false;
1334 if (!Fn->onlyReadsMemory())
1335 return false;
1336 if (!Fn->hasFnAttribute(Attribute::WillReturn))
1337 return false;
1338
1339 LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1340 << CI->getCaller()->getName() << "\n");
1341
1342 auto Remark = [&](OptimizationRemark OR) {
1343 return OR << "Removing parallel region with no side-effects.";
1344 };
1345 emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1346
1347 CGUpdater.removeCallSite(*CI);
1348 CI->eraseFromParent();
1349 Changed = true;
1350 ++NumOpenMPParallelRegionsDeleted;
1351 return true;
1352 };
1353
1354 RFI.foreachUse(SCC, DeleteCallCB);
1355
1356 return Changed;
1357 }
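// The callback deleted above must satisfy Fn->onlyReadsMemory() and carry the
// willreturn attribute; a removable call thus looks like (sketch, names
// illustrative):
//
//   ; @outlined only reads memory and is guaranteed to return
//   call void (...) @__kmpc_fork_call(%ident, 0, @outlined)  ; deletable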
1358
1359 /// Try to eliminate runtime calls by reusing existing ones.
1360 bool deduplicateRuntimeCalls() {
1361 bool Changed = false;
1362
1363 RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1364 OMPRTL_omp_get_num_threads,
1365 OMPRTL_omp_in_parallel,
1366 OMPRTL_omp_get_cancellation,
1367 OMPRTL_omp_get_thread_limit,
1368 OMPRTL_omp_get_supported_active_levels,
1369 OMPRTL_omp_get_level,
1370 OMPRTL_omp_get_ancestor_thread_num,
1371 OMPRTL_omp_get_team_size,
1372 OMPRTL_omp_get_active_level,
1373 OMPRTL_omp_in_final,
1374 OMPRTL_omp_get_proc_bind,
1375 OMPRTL_omp_get_num_places,
1376 OMPRTL_omp_get_num_procs,
1377 OMPRTL_omp_get_place_num,
1378 OMPRTL_omp_get_partition_num_places,
1379 OMPRTL_omp_get_partition_place_nums};
1380
1381 // Global-tid is handled separately.
1382 SmallSetVector<Value *, 16> GTIdArgs;
1383 collectGlobalThreadIdArguments(GTIdArgs);
1384 LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1385 << " global thread ID arguments\n");
1386
1387 for (Function *F : SCC) {
1388 for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1389 Changed |= deduplicateRuntimeCalls(
1390 *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1391
1392 // __kmpc_global_thread_num is special as we can replace it with an
1393 // argument in enough cases to make it worth trying.
1394 Value *GTIdArg = nullptr;
1395 for (Argument &Arg : F->args())
1396 if (GTIdArgs.count(&Arg)) {
1397 GTIdArg = &Arg;
1398 break;
1399 }
1400 Changed |= deduplicateRuntimeCalls(
1401 *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1402 }
1403
1404 return Changed;
1405 }
1406
1407 /// Tries to hide the latency of runtime calls that involve host to
1408 /// device memory transfers by splitting them into their "issue" and "wait"
1409 /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1410 /// moved downwards as much as possible. The "issue" issues the memory transfer
1411 /// asynchronously, returning a handle. The "wait" waits on the returned
1412 /// handle for the memory transfer to finish.
1413 bool hideMemTransfersLatency() {
1414 auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1415 bool Changed = false;
1416 auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1417 auto *RTCall = getCallIfRegularCall(U, &RFI);
1418 if (!RTCall)
1419 return false;
1420
1421 OffloadArray OffloadArrays[3];
1422 if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1423 return false;
1424
1425 LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1426
1427 // TODO: Check if can be moved upwards.
1428 bool WasSplit = false;
1429 Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1430 if (WaitMovementPoint)
1431 WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1432
1433 Changed |= WasSplit;
1434 return WasSplit;
1435 };
1436 if (OMPInfoCache.runtimeFnsAvailable(
1437 {OMPRTL___tgt_target_data_begin_mapper_issue,
1438 OMPRTL___tgt_target_data_begin_mapper_wait}))
1439 RFI.foreachUse(SCC, SplitMemTransfers);
1440
1441 return Changed;
1442 }
1443
1444 void analysisGlobalization() {
1445 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1446
1447 auto CheckGlobalization = [&](Use &U, Function &Decl) {
1448 if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1449 auto Remark = [&](OptimizationRemarkMissed ORM) {
1450 return ORM
1451 << "Found thread data sharing on the GPU. "
1452 << "Expect degraded performance due to data globalization.";
1453 };
1454 emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1455 }
1456
1457 return false;
1458 };
1459
1460 RFI.foreachUse(SCC, CheckGlobalization);
1461 }
1462
1463 /// Maps the values stored in the offload arrays passed as arguments to
1464 /// \p RuntimeCall into the offload arrays in \p OAs.
1465 bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1466 MutableArrayRef<OffloadArray> OAs) {
1467 assert(OAs.size() == 3 && "Need space for three offload arrays!");
1468
1469 // A runtime call that involves memory offloading looks something like:
1470 // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1471 // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1472 // ...)
1473 // So, the idea is to access the allocas that allocate space for these
1474 // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1475 // Therefore:
1476 // i8** %offload_baseptrs.
1477 Value *BasePtrsArg =
1478 RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1479 // i8** %offload_ptrs.
1480 Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1481 // i8** %offload_sizes.
1482 Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1483
1484 // Get values stored in **offload_baseptrs.
1485 auto *V = getUnderlyingObject(BasePtrsArg);
1486 if (!isa<AllocaInst>(V))
1487 return false;
1488 auto *BasePtrsArray = cast<AllocaInst>(V);
1489 if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1490 return false;
1491
1492 // Get values stored in **offload_ptrs.
1493 V = getUnderlyingObject(PtrsArg);
1494 if (!isa<AllocaInst>(V))
1495 return false;
1496 auto *PtrsArray = cast<AllocaInst>(V);
1497 if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1498 return false;
1499
1500 // Get values stored in **offload_sizes.
1501 V = getUnderlyingObject(SizesArg);
1502 // If it's a [constant] global array don't analyze it.
1503 if (isa<GlobalValue>(V))
1504 return isa<Constant>(V);
1505 if (!isa<AllocaInst>(V))
1506 return false;
1507
1508 auto *SizesArray = cast<AllocaInst>(V);
1509 if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1510 return false;
1511
1512 return true;
1513 }
1514
1515 /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1516 /// For now this is a way to test that the function getValuesInOffloadArrays
1517 /// is working properly.
1518 /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1519 void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1520 assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1521
1522 LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1523 std::string ValuesStr;
1524 raw_string_ostream Printer(ValuesStr);
1525 std::string Separator = " --- ";
1526
1527 for (auto *BP : OAs[0].StoredValues) {
1528 BP->print(Printer);
1529 Printer << Separator;
1530 }
1531 LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1532 ValuesStr.clear();
1533
1534 for (auto *P : OAs[1].StoredValues) {
1535 P->print(Printer);
1536 Printer << Separator;
1537 }
1538 LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1539 ValuesStr.clear();
1540
1541 for (auto *S : OAs[2].StoredValues) {
1542 S->print(Printer);
1543 Printer << Separator;
1544 }
1545 LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1546 }
1547
1548 /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1549 /// moved. Returns nullptr if the movement is not possible, or not worth it.
1550 Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1551 // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1552 // Make it traverse the CFG.
1553
1554 Instruction *CurrentI = &RuntimeCall;
1555 bool IsWorthIt = false;
1556 while ((CurrentI = CurrentI->getNextNode())) {
1557
1558 // TODO: Once we detect the regions to be offloaded we should use the
1559 // alias analysis manager to check if CurrentI may modify one of
1560 // the offloaded regions.
1561 if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1562 if (IsWorthIt)
1563 return CurrentI;
1564
1565 return nullptr;
1566 }
1567
1568 // FIXME: For now, moving it over anything without side effects is
1569 // considered worth it.
1570 IsWorthIt = true;
1571 }
1572
1573 // Return end of BasicBlock.
1574 return RuntimeCall.getParent()->getTerminator();
1575 }
1576
1577 /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1578 bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1579 Instruction &WaitMovementPoint) {
1580 // Create stack allocated handle (__tgt_async_info) at the beginning of the
1581 // function. Used for storing information of the async transfer, allowing to
1582 // wait on it later.
1583 auto &IRBuilder = OMPInfoCache.OMPBuilder;
1584 Function *F = RuntimeCall.getCaller();
1585 BasicBlock &Entry = F->getEntryBlock();
1586 IRBuilder.Builder.SetInsertPoint(&Entry,
1587 Entry.getFirstNonPHIOrDbgOrAlloca());
1588 Value *Handle = IRBuilder.Builder.CreateAlloca(
1589 IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
1590 Handle =
1591 IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
1592
1593 // Add "issue" runtime call declaration:
1594 // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1595 // i8**, i8**, i64*, i64*)
1596 FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1597 M, OMPRTL___tgt_target_data_begin_mapper_issue);
1598
1599 // Change RuntimeCall call site for its asynchronous version.
1600 SmallVector<Value *, 16> Args;
1601 for (auto &Arg : RuntimeCall.args())
1602 Args.push_back(Arg.get());
1603 Args.push_back(Handle);
1604
1605 CallInst *IssueCallsite =
1606 CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1607 OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1608 RuntimeCall.eraseFromParent();
1609
1610 // Add "wait" runtime call declaration:
1611 // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1612 FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1613 M, OMPRTL___tgt_target_data_begin_mapper_wait);
1614
1615 Value *WaitParams[2] = {
1616 IssueCallsite->getArgOperand(
1617 OffloadArray::DeviceIDArgNum), // device_id.
1618 Handle // handle to wait on.
1619 };
1620 CallInst *WaitCallsite = CallInst::Create(
1621 WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1622 OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1623
1624 return true;
1625 }
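// Net effect of the split above (sketch; operand lists abbreviated, value
// names illustrative):
//
//   call void @__tgt_target_data_begin_mapper(%ident, %device_id, ...)
//   ...                     ; code independent of the transfer
//
// becomes
//
//   %handle = alloca %struct.__tgt_async_info
//   call void @__tgt_target_data_begin_mapper_issue(%ident, %device_id, ..., %handle)
//   ...                     ; now overlaps with the asynchronous transfer
//   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)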
1626
1627 static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1628 bool GlobalOnly, bool &SingleChoice) {
1629 if (CurrentIdent == NextIdent)
1630 return CurrentIdent;
1631
1632 // TODO: Figure out how to actually combine multiple debug locations. For
1633 // now we just keep an existing one if there is a single choice.
1634 if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1635 SingleChoice = !CurrentIdent;
1636 return NextIdent;
1637 }
1638 return nullptr;
1639 }
1640
1641 /// Return a `struct ident_t*` value that represents the ones used in the
1642 /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1643 /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1644 /// return value we create one from scratch. We also do not yet combine
1645 /// information, e.g., the source locations, see combinedIdentStruct.
1646 Value *
1647 getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1648 Function &F, bool GlobalOnly) {
1649 bool SingleChoice = true;
1650 Value *Ident = nullptr;
1651 auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1652 CallInst *CI = getCallIfRegularCall(U, &RFI);
1653 if (!CI || &F != &Caller)
1654 return false;
1655 Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1656 /* GlobalOnly */ true, SingleChoice);
1657 return false;
1658 };
1659 RFI.foreachUse(SCC, CombineIdentStruct);
1660
1661 if (!Ident || !SingleChoice) {
 1662       // The IRBuilder uses the insertion block to get to the module; this is
 1663       // unfortunate, but we work around it for now.
1664 if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1665 OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1666 &F.getEntryBlock(), F.getEntryBlock().begin()));
 1667       // Create a fallback location if none was found.
1668 // TODO: Use the debug locations of the calls instead.
1669 uint32_t SrcLocStrSize;
1670 Constant *Loc =
1671 OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1672 Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1673 }
1674 return Ident;
1675 }
1676
1677 /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1678 /// \p ReplVal if given.
1679 bool deduplicateRuntimeCalls(Function &F,
1680 OMPInformationCache::RuntimeFunctionInfo &RFI,
1681 Value *ReplVal = nullptr) {
1682 auto *UV = RFI.getUseVector(F);
1683 if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1684 return false;
1685
1686 LLVM_DEBUG(
1687 dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1688 << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1689
1690 assert((!ReplVal || (isa<Argument>(ReplVal) &&
1691 cast<Argument>(ReplVal)->getParent() == &F)) &&
1692 "Unexpected replacement value!");
1693
1694 // TODO: Use dominance to find a good position instead.
1695 auto CanBeMoved = [this](CallBase &CB) {
1696 unsigned NumArgs = CB.arg_size();
1697 if (NumArgs == 0)
1698 return true;
1699 if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1700 return false;
1701 for (unsigned U = 1; U < NumArgs; ++U)
1702 if (isa<Instruction>(CB.getArgOperand(U)))
1703 return false;
1704 return true;
1705 };
1706
1707 if (!ReplVal) {
1708 for (Use *U : *UV)
1709 if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1710 if (!CanBeMoved(*CI))
1711 continue;
1712
1713 // If the function is a kernel, dedup will move
1714 // the runtime call right after the kernel init callsite. Otherwise,
1715 // it will move it to the beginning of the caller function.
1716 if (isKernel(F)) {
1717 auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1718 auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1719
1720 if (KernelInitUV->empty())
1721 continue;
1722
1723 assert(KernelInitUV->size() == 1 &&
1724 "Expected a single __kmpc_target_init in kernel\n");
1725
1726 CallInst *KernelInitCI =
1727 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1728 assert(KernelInitCI &&
1729 "Expected a call to __kmpc_target_init in kernel\n");
1730
1731 CI->moveAfter(KernelInitCI);
1732 } else
1733 CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1734 ReplVal = CI;
1735 break;
1736 }
1737 if (!ReplVal)
1738 return false;
1739 }
1740
1741 // If we use a call as a replacement value we need to make sure the ident is
1742 // valid at the new location. For now we just pick a global one, either
1743 // existing and used by one of the calls, or created from scratch.
1744 if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1745 if (!CI->arg_empty() &&
1746 CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1747 Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1748 /* GlobalOnly */ true);
1749 CI->setArgOperand(0, Ident);
1750 }
1751 }
1752
1753 bool Changed = false;
1754 auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1755 CallInst *CI = getCallIfRegularCall(U, &RFI);
1756 if (!CI || CI == ReplVal || &F != &Caller)
1757 return false;
1758 assert(CI->getCaller() == &F && "Unexpected call!");
1759
1760 auto Remark = [&](OptimizationRemark OR) {
1761 return OR << "OpenMP runtime call "
1762 << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1763 };
1764 if (CI->getDebugLoc())
1765 emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1766 else
1767 emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1768
1769 CGUpdater.removeCallSite(*CI);
1770 CI->replaceAllUsesWith(ReplVal);
1771 CI->eraseFromParent();
1772 ++NumOpenMPRuntimeCallsDeduplicated;
1773 Changed = true;
1774 return true;
1775 };
1776 RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1777
1778 return Changed;
1779 }
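  // Illustrative example of the effect (assuming __kmpc_global_thread_num as
  // the deduplicated runtime call and that CanBeMoved holds):
  //
  //   %tid.0 = call i32 @__kmpc_global_thread_num(ptr @ident)
  //   ...
  //   %tid.1 = call i32 @__kmpc_global_thread_num(ptr @ident)
  //   use(%tid.1)
  //
  // becomes, after moving one call to the entry block (or right after
  // __kmpc_target_init in kernels) and rewiring the remaining uses:
  //
  //   %tid.0 = call i32 @__kmpc_global_thread_num(ptr @ident)
  //   ...
  //   use(%tid.0)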
1780
1781 /// Collect arguments that represent the global thread id in \p GTIdArgs.
1782 void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1783 // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1784 // initialization. We could define an AbstractAttribute instead and
1785 // run the Attributor here once it can be run as an SCC pass.
1786
1787 // Helper to check the argument \p ArgNo at all call sites of \p F for
1788 // a GTId.
1789 auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1790 if (!F.hasLocalLinkage())
1791 return false;
1792 for (Use &U : F.uses()) {
1793 if (CallInst *CI = getCallIfRegularCall(U)) {
1794 Value *ArgOp = CI->getArgOperand(ArgNo);
1795 if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1796 getCallIfRegularCall(
1797 *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1798 continue;
1799 }
1800 return false;
1801 }
1802 return true;
1803 };
1804
1805 // Helper to identify uses of a GTId as GTId arguments.
1806 auto AddUserArgs = [&](Value &GTId) {
1807 for (Use &U : GTId.uses())
1808 if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1809 if (CI->isArgOperand(&U))
1810 if (Function *Callee = CI->getCalledFunction())
1811 if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1812 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1813 };
1814
1815 // The argument users of __kmpc_global_thread_num calls are GTIds.
1816 OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1817 OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1818
1819 GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1820 if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1821 AddUserArgs(*CI);
1822 return false;
1823 });
1824
1825 // Transitively search for more arguments by looking at the users of the
1826 // ones we know already. During the search the GTIdArgs vector is extended
1827 // so we cannot cache the size nor can we use a range based for.
1828 for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1829 AddUserArgs(*GTIdArgs[U]);
1830 }
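  // Illustrative example of what the fixpoint above identifies (the function
  // name is made up): for an internal function
  //
  //   define internal void @body(i32 %gtid) { ... }
  //
  // whose every call site passes either the result of a
  // __kmpc_global_thread_num call or an argument already known to be a GTId,
  // the parameter %gtid is added to GTIdArgs, which may in turn qualify
  // arguments of @body's callees transitively.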
1831
1832 /// Kernel (=GPU) optimizations and utility functions
1833 ///
1834 ///{{
1835
1836 /// Check if \p F is a kernel, hence entry point for target offloading.
1837 bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1838
1839 /// Cache to remember the unique kernel for a function.
 1840   DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
 1841
1842 /// Find the unique kernel that will execute \p F, if any.
1843 Kernel getUniqueKernelFor(Function &F);
1844
1845 /// Find the unique kernel that will execute \p I, if any.
1846 Kernel getUniqueKernelFor(Instruction &I) {
1847 return getUniqueKernelFor(*I.getFunction());
1848 }
1849
 1850   /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
 1851   /// the cases where we can avoid taking the address of a function.
1852 bool rewriteDeviceCodeStateMachine();
1853
1854 ///
1855 ///}}
1856
1857 /// Emit a remark generically
1858 ///
1859 /// This template function can be used to generically emit a remark. The
1860 /// RemarkKind should be one of the following:
1861 /// - OptimizationRemark to indicate a successful optimization attempt
1862 /// - OptimizationRemarkMissed to report a failed optimization attempt
1863 /// - OptimizationRemarkAnalysis to provide additional information about an
1864 /// optimization attempt
1865 ///
1866 /// The remark is built using a callback function provided by the caller that
1867 /// takes a RemarkKind as input and returns a RemarkKind.
1868 template <typename RemarkKind, typename RemarkCallBack>
1869 void emitRemark(Instruction *I, StringRef RemarkName,
1870 RemarkCallBack &&RemarkCB) const {
1871 Function *F = I->getParent()->getParent();
1872 auto &ORE = OREGetter(F);
1873
1874 if (RemarkName.startswith("OMP"))
1875 ORE.emit([&]() {
1876 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1877 << " [" << RemarkName << "]";
1878 });
1879 else
1880 ORE.emit(
1881 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1882 }
1883
1884 /// Emit a remark on a function.
1885 template <typename RemarkKind, typename RemarkCallBack>
1886 void emitRemark(Function *F, StringRef RemarkName,
1887 RemarkCallBack &&RemarkCB) const {
1888 auto &ORE = OREGetter(F);
1889
1890 if (RemarkName.startswith("OMP"))
1891 ORE.emit([&]() {
1892 return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1893 << " [" << RemarkName << "]";
1894 });
1895 else
1896 ORE.emit(
1897 [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1898 }
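  // Typical usage, mirroring the deduplication remark emitted earlier in this
  // file (OMP170); remark names starting with "OMP" get "[<name>]" appended:
  //
  //   auto Remark = [&](OptimizationRemark OR) {
  //     return OR << "OpenMP runtime call "
  //               << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
  //   };
  //   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);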
1899
1900 /// The underlying module.
1901 Module &M;
1902
1903 /// The SCC we are operating on.
 1904   SmallVectorImpl<Function *> &SCC;
 1905
1906 /// Callback to update the call graph, the first argument is a removed call,
1907 /// the second an optional replacement call.
1908 CallGraphUpdater &CGUpdater;
1909
1910 /// Callback to get an OptimizationRemarkEmitter from a Function *
1911 OptimizationRemarkGetter OREGetter;
1912
 1913   /// OpenMP-specific information cache. Also used for Attributor runs.
1914 OMPInformationCache &OMPInfoCache;
1915
1916 /// Attributor instance.
1917 Attributor &A;
1918
1919 /// Helper function to run Attributor on SCC.
1920 bool runAttributor(bool IsModulePass) {
1921 if (SCC.empty())
1922 return false;
1923
1924 registerAAs(IsModulePass);
1925
1926 ChangeStatus Changed = A.run();
1927
1928 LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1929 << " functions, result: " << Changed << ".\n");
1930
1931 return Changed == ChangeStatus::CHANGED;
1932 }
1933
1934 void registerFoldRuntimeCall(RuntimeFunction RF);
1935
1936 /// Populate the Attributor with abstract attribute opportunities in the
1937 /// functions.
1938 void registerAAs(bool IsModulePass);
1939
1940public:
1941 /// Callback to register AAs for live functions, including internal functions
1942 /// marked live during the traversal.
1943 static void registerAAsForFunction(Attributor &A, const Function &F);
1944};
1945
1946Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1947 if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))
1948 return nullptr;
1949
1950 // Use a scope to keep the lifetime of the CachedKernel short.
1951 {
1952 std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1953 if (CachedKernel)
1954 return *CachedKernel;
1955
1956 // TODO: We should use an AA to create an (optimistic and callback
1957 // call-aware) call graph. For now we stick to simple patterns that
1958 // are less powerful, basically the worst fixpoint.
1959 if (isKernel(F)) {
1960 CachedKernel = Kernel(&F);
1961 return *CachedKernel;
1962 }
1963
1964 CachedKernel = nullptr;
1965 if (!F.hasLocalLinkage()) {
1966
1967 // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1968 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1969 return ORA << "Potentially unknown OpenMP target region caller.";
1970 };
1971 emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1972
1973 return nullptr;
1974 }
1975 }
1976
1977 auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1978 if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1979 // Allow use in equality comparisons.
1980 if (Cmp->isEquality())
1981 return getUniqueKernelFor(*Cmp);
1982 return nullptr;
1983 }
1984 if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1985 // Allow direct calls.
1986 if (CB->isCallee(&U))
1987 return getUniqueKernelFor(*CB);
1988
1989 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1990 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1991 // Allow the use in __kmpc_parallel_51 calls.
1992 if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1993 return getUniqueKernelFor(*CB);
1994 return nullptr;
1995 }
1996 // Disallow every other use.
1997 return nullptr;
1998 };
1999
2000 // TODO: In the future we want to track more than just a unique kernel.
2001 SmallPtrSet<Kernel, 2> PotentialKernels;
2002 OMPInformationCache::foreachUse(F, [&](const Use &U) {
2003 PotentialKernels.insert(GetUniqueKernelForUse(U));
2004 });
2005
2006 Kernel K = nullptr;
2007 if (PotentialKernels.size() == 1)
2008 K = *PotentialKernels.begin();
2009
2010 // Cache the result.
2011 UniqueKernelMap[&F] = K;
2012
2013 return K;
2014}
2015
2016bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2017 OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2018 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2019
2020 bool Changed = false;
2021 if (!KernelParallelRFI)
2022 return Changed;
2023
2024 // If we have disabled state machine changes, exit
 2025   if (DisableOpenMPOptStateMachineRewrite)
 2026     return Changed;
2027
2028 for (Function *F : SCC) {
2029
2030 // Check if the function is a use in a __kmpc_parallel_51 call at
2031 // all.
2032 bool UnknownUse = false;
2033 bool KernelParallelUse = false;
2034 unsigned NumDirectCalls = 0;
2035
2036 SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2037 OMPInformationCache::foreachUse(*F, [&](Use &U) {
2038 if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2039 if (CB->isCallee(&U)) {
2040 ++NumDirectCalls;
2041 return;
2042 }
2043
2044 if (isa<ICmpInst>(U.getUser())) {
2045 ToBeReplacedStateMachineUses.push_back(&U);
2046 return;
2047 }
2048
2049 // Find wrapper functions that represent parallel kernels.
2050 CallInst *CI =
2051 OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2052 const unsigned int WrapperFunctionArgNo = 6;
2053 if (!KernelParallelUse && CI &&
2054 CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2055 KernelParallelUse = true;
2056 ToBeReplacedStateMachineUses.push_back(&U);
2057 return;
2058 }
2059 UnknownUse = true;
2060 });
2061
2062 // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2063 // use.
2064 if (!KernelParallelUse)
2065 continue;
2066
2067 // If this ever hits, we should investigate.
2068 // TODO: Checking the number of uses is not a necessary restriction and
2069 // should be lifted.
2070 if (UnknownUse || NumDirectCalls != 1 ||
2071 ToBeReplacedStateMachineUses.size() > 2) {
2072 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2073 return ORA << "Parallel region is used in "
2074 << (UnknownUse ? "unknown" : "unexpected")
2075 << " ways. Will not attempt to rewrite the state machine.";
2076 };
2077 emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2078 continue;
2079 }
2080
2081 // Even if we have __kmpc_parallel_51 calls, we (for now) give
2082 // up if the function is not called from a unique kernel.
2083 Kernel K = getUniqueKernelFor(*F);
2084 if (!K) {
2085 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2086 return ORA << "Parallel region is not called from a unique kernel. "
2087 "Will not attempt to rewrite the state machine.";
2088 };
2089 emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2090 continue;
2091 }
2092
2093 // We now know F is a parallel body function called only from the kernel K.
2094 // We also identified the state machine uses in which we replace the
2095 // function pointer by a new global symbol for identification purposes. This
2096 // ensures only direct calls to the function are left.
2097
2098 Module &M = *F->getParent();
2099 Type *Int8Ty = Type::getInt8Ty(M.getContext());
2100
2101 auto *ID = new GlobalVariable(
2102 M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2103 UndefValue::get(Int8Ty), F->getName() + ".ID");
2104
 2105     for (Use *U : ToBeReplacedStateMachineUses)
 2106       A.changeUseAfterManifest(*U, *ConstantExpr::getPointerBitCastOrAddrSpaceCast(
 2107                                        ID, U->get()->getType()));
2108
2109 ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2110
2111 Changed = true;
2112 }
2113
2114 return Changed;
2115}
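// Illustrative sketch of the rewrite above (pseudo-IR; the wrapper function
// name is made up and other arguments are elided): the use of the parallel
// region wrapper as the function-pointer argument of __kmpc_parallel_51,
//
//   call void @__kmpc_parallel_51(..., ptr @parallel_wrapper, ...)
//
// is replaced by a use of the new private global,
//
//   @parallel_wrapper.ID = private constant i8 undef
//   call void @__kmpc_parallel_51(..., ptr @parallel_wrapper.ID, ...)
//
// so the only remaining uses of the wrapper are direct calls and the kernel's
// state machine does not have to handle an unknown function pointer.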
2116
2117/// Abstract Attribute for tracking ICV values.
 2118 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
 2119   using Base = StateWrapper<BooleanState, AbstractAttribute>;
 2120   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2121
2122 void initialize(Attributor &A) override {
2123 Function *F = getAnchorScope();
2124 if (!F || !A.isFunctionIPOAmendable(*F))
2125 indicatePessimisticFixpoint();
2126 }
2127
2128 /// Returns true if value is assumed to be tracked.
2129 bool isAssumedTracked() const { return getAssumed(); }
2130
2131 /// Returns true if value is known to be tracked.
2132 bool isKnownTracked() const { return getAssumed(); }
2133
 2134   /// Create an abstract attribute view for the position \p IRP.
2135 static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2136
2137 /// Return the value with which \p I can be replaced for specific \p ICV.
2138 virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2139 const Instruction *I,
2140 Attributor &A) const {
2141 return std::nullopt;
2142 }
2143
2144 /// Return an assumed unique ICV value if a single candidate is found. If
 2145   /// there cannot be one, return nullptr. If it is not clear yet, return
2146 /// std::nullopt.
2147 virtual std::optional<Value *>
2148 getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2149
 2150   // Currently only nthreads is being tracked.
 2151   // This array will only grow with time.
2152 InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2153
2154 /// See AbstractAttribute::getName()
2155 const std::string getName() const override { return "AAICVTracker"; }
2156
2157 /// See AbstractAttribute::getIdAddr()
2158 const char *getIdAddr() const override { return &ID; }
2159
2160 /// This function should return true if the type of the \p AA is AAICVTracker
2161 static bool classof(const AbstractAttribute *AA) {
2162 return (AA->getIdAddr() == &ID);
2163 }
2164
2165 static const char ID;
2166};
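// Note on the std::optional<Value *> convention used by the trackers below
// (see getUniqueReplacementValue above):
//   std::nullopt -> not known yet, keep iterating optimistically,
//   nullptr      -> no unique value exists (the ICV may change),
//   a Value *    -> the unique value the ICV is assumed to have.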
2167
2168struct AAICVTrackerFunction : public AAICVTracker {
2169 AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2170 : AAICVTracker(IRP, A) {}
2171
2172 // FIXME: come up with better string.
2173 const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2174
2175 // FIXME: come up with some stats.
2176 void trackStatistics() const override {}
2177
2178 /// We don't manifest anything for this AA.
2179 ChangeStatus manifest(Attributor &A) override {
2180 return ChangeStatus::UNCHANGED;
2181 }
2182
2183 // Map of ICV to their values at specific program point.
2185 InternalControlVar::ICV___last>
2186 ICVReplacementValuesMap;
2187
2188 ChangeStatus updateImpl(Attributor &A) override {
2189 ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2190
2191 Function *F = getAnchorScope();
2192
2193 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2194
2195 for (InternalControlVar ICV : TrackableICVs) {
2196 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2197
2198 auto &ValuesMap = ICVReplacementValuesMap[ICV];
2199 auto TrackValues = [&](Use &U, Function &) {
2200 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2201 if (!CI)
2202 return false;
2203
 2204         // FIXME: handle setters with more than one argument.
2205 /// Track new value.
2206 if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2207 HasChanged = ChangeStatus::CHANGED;
2208
2209 return false;
2210 };
2211
2212 auto CallCheck = [&](Instruction &I) {
2213 std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2214 if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2215 HasChanged = ChangeStatus::CHANGED;
2216
2217 return true;
2218 };
2219
2220 // Track all changes of an ICV.
2221 SetterRFI.foreachUse(TrackValues, F);
2222
2223 bool UsedAssumedInformation = false;
2224 A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2225 UsedAssumedInformation,
2226 /* CheckBBLivenessOnly */ true);
2227
2228 /// TODO: Figure out a way to avoid adding entry in
2229 /// ICVReplacementValuesMap
2230 Instruction *Entry = &F->getEntryBlock().front();
2231 if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2232 ValuesMap.insert(std::make_pair(Entry, nullptr));
2233 }
2234
2235 return HasChanged;
2236 }
2237
2238 /// Helper to check if \p I is a call and get the value for it if it is
2239 /// unique.
2240 std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2241 InternalControlVar &ICV) const {
2242
2243 const auto *CB = dyn_cast<CallBase>(&I);
2244 if (!CB || CB->hasFnAttr("no_openmp") ||
2245 CB->hasFnAttr("no_openmp_routines"))
2246 return std::nullopt;
2247
2248 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2249 auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2250 auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2251 Function *CalledFunction = CB->getCalledFunction();
2252
2253 // Indirect call, assume ICV changes.
2254 if (CalledFunction == nullptr)
2255 return nullptr;
2256 if (CalledFunction == GetterRFI.Declaration)
2257 return std::nullopt;
2258 if (CalledFunction == SetterRFI.Declaration) {
2259 if (ICVReplacementValuesMap[ICV].count(&I))
2260 return ICVReplacementValuesMap[ICV].lookup(&I);
2261
2262 return nullptr;
2263 }
2264
2265 // Since we don't know, assume it changes the ICV.
2266 if (CalledFunction->isDeclaration())
2267 return nullptr;
2268
2269 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2270 *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2271
2272 if (ICVTrackingAA.isAssumedTracked()) {
2273 std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2274 if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
2275 OMPInfoCache)))
2276 return URV;
2277 }
2278
2279 // If we don't know, assume it changes.
2280 return nullptr;
2281 }
2282
2283 // We don't check unique value for a function, so return std::nullopt.
2284 std::optional<Value *>
2285 getUniqueReplacementValue(InternalControlVar ICV) const override {
2286 return std::nullopt;
2287 }
2288
2289 /// Return the value with which \p I can be replaced for specific \p ICV.
2290 std::optional<Value *> getReplacementValue(InternalControlVar ICV,
2291 const Instruction *I,
2292 Attributor &A) const override {
2293 const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2294 if (ValuesMap.count(I))
2295 return ValuesMap.lookup(I);
2296
 2297     SmallVector<const Instruction *, 16> Worklist;
 2298     SmallPtrSet<const Instruction *, 16> Visited;
 2299     Worklist.push_back(I);
2300
2301 std::optional<Value *> ReplVal;
2302
2303 while (!Worklist.empty()) {
2304 const Instruction *CurrInst = Worklist.pop_back_val();
2305 if (!Visited.insert(CurrInst).second)
2306 continue;
2307
2308 const BasicBlock *CurrBB = CurrInst->getParent();
2309
2310 // Go up and look for all potential setters/calls that might change the
2311 // ICV.
2312 while ((CurrInst = CurrInst->getPrevNode())) {
2313 if (ValuesMap.count(CurrInst)) {
2314 std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2315 // Unknown value, track new.
2316 if (!ReplVal) {
2317 ReplVal = NewReplVal;
2318 break;
2319 }
2320
 2321           // If we found a new value, we can't know the ICV value anymore.
2322 if (NewReplVal)
2323 if (ReplVal != NewReplVal)
2324 return nullptr;
2325
2326 break;
2327 }
2328
2329 std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2330 if (!NewReplVal)
2331 continue;
2332
2333 // Unknown value, track new.
2334 if (!ReplVal) {
2335 ReplVal = NewReplVal;
2336 break;
2337 }
2338
 2339         // If we found a new value, we can't know the ICV value anymore.
2341 if (ReplVal != NewReplVal)
2342 return nullptr;
2343 }
2344
2345 // If we are in the same BB and we have a value, we are done.
2346 if (CurrBB == I->getParent() && ReplVal)
2347 return ReplVal;
2348
2349 // Go through all predecessors and add terminators for analysis.
2350 for (const BasicBlock *Pred : predecessors(CurrBB))
2351 if (const Instruction *Terminator = Pred->getTerminator())
2352 Worklist.push_back(Terminator);
2353 }
2354
2355 return ReplVal;
2356 }
2357};
2358
2359struct AAICVTrackerFunctionReturned : AAICVTracker {
2360 AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2361 : AAICVTracker(IRP, A) {}
2362
2363 // FIXME: come up with better string.
2364 const std::string getAsStr() const override {
2365 return "ICVTrackerFunctionReturned";
2366 }
2367
2368 // FIXME: come up with some stats.
2369 void trackStatistics() const override {}
2370
2371 /// We don't manifest anything for this AA.
2372 ChangeStatus manifest(Attributor &A) override {
2373 return ChangeStatus::UNCHANGED;
2374 }
2375
2376 // Map of ICV to their values at specific program point.
2378 InternalControlVar::ICV___last>
2379 ICVReplacementValuesMap;
2380
 2381   /// Return the unique value with which \p ICV can be replaced, if any.
2382 std::optional<Value *>
2383 getUniqueReplacementValue(InternalControlVar ICV) const override {
2384 return ICVReplacementValuesMap[ICV];
2385 }
2386
2387 ChangeStatus updateImpl(Attributor &A) override {
2388 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2389 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2390 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2391
2392 if (!ICVTrackingAA.isAssumedTracked())
2393 return indicatePessimisticFixpoint();
2394
2395 for (InternalControlVar ICV : TrackableICVs) {
2396 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2397 std::optional<Value *> UniqueICVValue;
2398
2399 auto CheckReturnInst = [&](Instruction &I) {
2400 std::optional<Value *> NewReplVal =
2401 ICVTrackingAA.getReplacementValue(ICV, &I, A);
2402
2403 // If we found a second ICV value there is no unique returned value.
2404 if (UniqueICVValue && UniqueICVValue != NewReplVal)
2405 return false;
2406
2407 UniqueICVValue = NewReplVal;
2408
2409 return true;
2410 };
2411
2412 bool UsedAssumedInformation = false;
2413 if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2414 UsedAssumedInformation,
2415 /* CheckBBLivenessOnly */ true))
2416 UniqueICVValue = nullptr;
2417
2418 if (UniqueICVValue == ReplVal)
2419 continue;
2420
2421 ReplVal = UniqueICVValue;
2422 Changed = ChangeStatus::CHANGED;
2423 }
2424
2425 return Changed;
2426 }
2427};
2428
2429struct AAICVTrackerCallSite : AAICVTracker {
2430 AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2431 : AAICVTracker(IRP, A) {}
2432
2433 void initialize(Attributor &A) override {
2434 Function *F = getAnchorScope();
2435 if (!F || !A.isFunctionIPOAmendable(*F))
2436 indicatePessimisticFixpoint();
2437
2438 // We only initialize this AA for getters, so we need to know which ICV it
2439 // gets.
2440 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2441 for (InternalControlVar ICV : TrackableICVs) {
2442 auto ICVInfo = OMPInfoCache.ICVs[ICV];
2443 auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2444 if (Getter.Declaration == getAssociatedFunction()) {
2445 AssociatedICV = ICVInfo.Kind;
2446 return;
2447 }
2448 }
2449
2450 /// Unknown ICV.
2451 indicatePessimisticFixpoint();
2452 }
2453
2454 ChangeStatus manifest(Attributor &A) override {
2455 if (!ReplVal || !*ReplVal)
2456 return ChangeStatus::UNCHANGED;
2457
2458 A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
2459 A.deleteAfterManifest(*getCtxI());
2460
2461 return ChangeStatus::CHANGED;
2462 }
2463
2464 // FIXME: come up with better string.
2465 const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2466
2467 // FIXME: come up with some stats.
2468 void trackStatistics() const override {}
2469
2470 InternalControlVar AssociatedICV;
2471 std::optional<Value *> ReplVal;
2472
2473 ChangeStatus updateImpl(Attributor &A) override {
2474 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2475 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2476
2477 // We don't have any information, so we assume it changes the ICV.
2478 if (!ICVTrackingAA.isAssumedTracked())
2479 return indicatePessimisticFixpoint();
2480
2481 std::optional<Value *> NewReplVal =
2482 ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2483
2484 if (ReplVal == NewReplVal)
2485 return ChangeStatus::UNCHANGED;
2486
2487 ReplVal = NewReplVal;
2488 return ChangeStatus::CHANGED;
2489 }
2490
2491 // Return the value with which associated value can be replaced for specific
2492 // \p ICV.
2493 std::optional<Value *>
2494 getUniqueReplacementValue(InternalControlVar ICV) const override {
2495 return ReplVal;
2496 }
2497};
2498
2499struct AAICVTrackerCallSiteReturned : AAICVTracker {
2500 AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2501 : AAICVTracker(IRP, A) {}
2502
2503 // FIXME: come up with better string.
2504 const std::string getAsStr() const override {
2505 return "ICVTrackerCallSiteReturned";
2506 }
2507
2508 // FIXME: come up with some stats.
2509 void trackStatistics() const override {}
2510
2511 /// We don't manifest anything for this AA.
2512 ChangeStatus manifest(Attributor &A) override {
2513 return ChangeStatus::UNCHANGED;
2514 }
2515
2516 // Map of ICV to their values at specific program point.
2518 InternalControlVar::ICV___last>
2519 ICVReplacementValuesMap;
2520
2521 /// Return the value with which associated value can be replaced for specific
2522 /// \p ICV.
2523 std::optional<Value *>
2524 getUniqueReplacementValue(InternalControlVar ICV) const override {
2525 return ICVReplacementValuesMap[ICV];
2526 }
2527
2528 ChangeStatus updateImpl(Attributor &A) override {
2529 ChangeStatus Changed = ChangeStatus::UNCHANGED;
2530 const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2531 *this, IRPosition::returned(*getAssociatedFunction()),
2532 DepClassTy::REQUIRED);
2533
2534 // We don't have any information, so we assume it changes the ICV.
2535 if (!ICVTrackingAA.isAssumedTracked())
2536 return indicatePessimisticFixpoint();
2537
2538 for (InternalControlVar ICV : TrackableICVs) {
2539 std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2540 std::optional<Value *> NewReplVal =
2541 ICVTrackingAA.getUniqueReplacementValue(ICV);
2542
2543 if (ReplVal == NewReplVal)
2544 continue;
2545
2546 ReplVal = NewReplVal;
2547 Changed = ChangeStatus::CHANGED;
2548 }
2549 return Changed;
2550 }
2551};
2552
2553struct AAExecutionDomainFunction : public AAExecutionDomain {
2554 AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2555 : AAExecutionDomain(IRP, A) {}
2556
2557 ~AAExecutionDomainFunction() { delete RPOT; }
2558
2559 void initialize(Attributor &A) override {
2560 if (getAnchorScope()->isDeclaration()) {
 2561       indicatePessimisticFixpoint();
 2562       return;
 2563     }
 2564     RPOT = new ReversePostOrderTraversal<Function *>(getAnchorScope());
 2565   }
2566
2567 const std::string getAsStr() const override {
2568 unsigned TotalBlocks = 0, InitialThreadBlocks = 0;
2569 for (auto &It : BEDMap) {
2570 TotalBlocks++;
2571 InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
2572 }
2573 return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
2574 std::to_string(TotalBlocks) + " executed by initial thread only";
2575 }
2576
2577 /// See AbstractAttribute::trackStatistics().
2578 void trackStatistics() const override {}
2579
2580 ChangeStatus manifest(Attributor &A) override {
2581 LLVM_DEBUG({
2582 for (const BasicBlock &BB : *getAnchorScope()) {
 2583         if (!isExecutedByInitialThreadOnly(BB))
 2584           continue;
2585 dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2586 << BB.getName() << " is executed by a single thread.\n";
2587 }
2588 });
2589
 2590     ChangeStatus Changed = ChangeStatus::UNCHANGED;
 2591
 2592     if (DisableOpenMPOptBarrierElimination)
 2593       return Changed;
2594
2595 SmallPtrSet<CallBase *, 16> DeletedBarriers;
2596 auto HandleAlignedBarrier = [&](CallBase *CB) {
2597 const ExecutionDomainTy &ED = CEDMap[CB];
2598 if (!ED.IsReachedFromAlignedBarrierOnly ||
2599 ED.EncounteredNonLocalSideEffect)
2600 return;
2601
 2602       // We can remove this barrier, if it is one, or, if we are dealing with
 2603       // the kernel end here, all aligned barriers reaching the kernel end. In
 2604       // the latter case we can transitively work our way back until we find a
 2605       // barrier that guards a side-effect.
2606 if (CB) {
2607 DeletedBarriers.insert(CB);
2608 A.deleteAfterManifest(*CB);
2609 ++NumBarriersEliminated;
2610 Changed = ChangeStatus::CHANGED;
2611 } else if (!ED.AlignedBarriers.empty()) {
2612 NumBarriersEliminated += ED.AlignedBarriers.size();
2613 Changed = ChangeStatus::CHANGED;
2614 SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
2615 ED.AlignedBarriers.end());
 2616         SmallSetVector<CallBase *, 16> Visited;
 2617         while (!Worklist.empty()) {
2618 CallBase *LastCB = Worklist.pop_back_val();
2619 if (!Visited.insert(LastCB))
2620 continue;
2621 if (!DeletedBarriers.count(LastCB)) {
2622 A.deleteAfterManifest(*LastCB);
2623 continue;
2624 }
2625 // The final aligned barrier (LastCB) reaching the kernel end was
2626 // removed already. This means we can go one step further and remove
 2627           // the barriers encountered last before (LastCB).
2628 const ExecutionDomainTy &LastED = CEDMap[LastCB];
2629 Worklist.append(LastED.AlignedBarriers.begin(),
2630 LastED.AlignedBarriers.end());
2631 }
2632 }
2633
2634 // If we actually eliminated a barrier we need to eliminate the associated
2635 // llvm.assumes as well to avoid creating UB.
2636 if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
2637 for (auto *AssumeCB : ED.EncounteredAssumes)
2638 A.deleteAfterManifest(*AssumeCB);
2639 };
2640
2641 for (auto *CB : AlignedBarriers)
2642 HandleAlignedBarrier(CB);
2643
2644 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2645 // Handle the "kernel end barrier" for kernels too.
2646 if (OMPInfoCache.Kernels.count(getAnchorScope()))
2647 HandleAlignedBarrier(nullptr);
2648
2649 return Changed;
2650 }
2651
2652 /// Merge barrier and assumption information from \p PredED into the successor
2653 /// \p ED.
2654 void
2655 mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
2656 const ExecutionDomainTy &PredED);
2657
2658 /// Merge all information from \p PredED into the successor \p ED. If
2659 /// \p InitialEdgeOnly is set, only the initial edge will enter the block
2660 /// represented by \p ED from this predecessor.
2661 void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
2662 const ExecutionDomainTy &PredED,
2663 bool InitialEdgeOnly = false);
2664
2665 /// Accumulate information for the entry block in \p EntryBBED.
2666 void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED);
2667
2668 /// See AbstractAttribute::updateImpl.
 2669   ChangeStatus updateImpl(Attributor &A) override;
 2670
2671 /// Query interface, see AAExecutionDomain
2672 ///{
2673 bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2674 if (!isValidState())
2675 return false;
2676 return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
2677 }
2678
 2679   bool isExecutedInAlignedRegion(Attributor &A,
 2680                                  const Instruction &I) const override {
2681 assert(I.getFunction() == getAnchorScope() &&
2682 "Instruction is out of scope!");
2683 if (!isValidState())
2684 return false;
2685
2686 const Instruction *CurI;
2687
2688 // Check forward until a call or the block end is reached.
2689 CurI = &I;
2690 do {
2691 auto *CB = dyn_cast<CallBase>(CurI);
2692 if (!CB)
2693 continue;
2694 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
2695 break;
2696 }
2697 const auto &It = CEDMap.find(CB);
2698 if (It == CEDMap.end())
2699 continue;
2700 if (!It->getSecond().IsReachingAlignedBarrierOnly)
2701 return false;
2702 break;
2703 } while ((CurI = CurI->getNextNonDebugInstruction()));
2704
2705 if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
2706 return false;
2707
2708 // Check backward until a call or the block beginning is reached.
2709 CurI = &I;
2710 do {
2711 auto *CB = dyn_cast<CallBase>(CurI);
2712 if (!CB)
2713 continue;
2714 if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
2715 break;
2716 }
2717 const auto &It = CEDMap.find(CB);
2718 if (It == CEDMap.end())
2719 continue;
2720 if (!AA::isNoSyncInst(A, *CB, *this)) {
2721 if (It->getSecond().IsReachedFromAlignedBarrierOnly) {
2722 break;
2723 }
2724 return false;
2725 }
2726
2727 Function *Callee = CB->getCalledFunction();
2728 if (!Callee || Callee->isDeclaration())
2729 return false;
2730 const auto &EDAA = A.getAAFor<AAExecutionDomain>(
 2731           *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
 2732       if (!EDAA.getState().isValidState())
2733 return false;
2734 if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly)
2735 return false;
2736 break;
2737 } while ((CurI = CurI->getPrevNonDebugInstruction()));
2738
2739 if (!CurI &&
2740 !llvm::all_of(
2741 predecessors(I.getParent()), [&](const BasicBlock *PredBB) {
2742 return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
2743 })) {
2744 return false;
2745 }
2746
 2747     // On neither traversal did we find anything but aligned barriers.
2748 return true;
2749 }
2750
2751 ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
2752 assert(isValidState() &&
2753 "No request should be made against an invalid state!");
2754 return BEDMap.lookup(&BB);
2755 }
2756 ExecutionDomainTy getExecutionDomain(const CallBase &CB) const override {
2757 assert(isValidState() &&
2758 "No request should be made against an invalid state!");
2759 return CEDMap.lookup(&CB);
2760 }
2761 ExecutionDomainTy getFunctionExecutionDomain() const override {
2762 assert(isValidState() &&
2763 "No request should be made against an invalid state!");
2764 return BEDMap.lookup(nullptr);
2765 }
2766 ///}
2767
2768 // Check if the edge into the successor block contains a condition that only
2769 // lets the main thread execute it.
2770 static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
2771 BasicBlock &SuccessorBB) {
2772 if (!Edge || !Edge->isConditional())
2773 return false;
2774 if (Edge->getSuccessor(0) != &SuccessorBB)
2775 return false;
2776
2777 auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2778 if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2779 return false;
2780
2781 ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2782 if (!C)
2783 return false;
2784
2785 // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2786 if (C->isAllOnesValue()) {
2787 auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2788 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2789 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2790 CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2791 if (!CB)
2792 return false;
2793 const int InitModeArgNo = 1;
2794 auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2795 return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2796 }
2797
2798 if (C->isZero()) {
2799 // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2800 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2801 if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2802 return true;
2803
2804 // Match: 0 == llvm.amdgcn.workitem.id.x()
2805 if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2806 if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2807 return true;
2808 }
2809
2810 return false;
2811 };
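  // Examples of edges the check above accepts (illustrative pseudo-IR with
  // most arguments elided):
  //
  //   %exec = call i32 @__kmpc_target_init(...)
  //   %c = icmp eq i32 %exec, -1          ; generic-mode kernels only
  //   br i1 %c, label %init.thread, label %other
  //
  //   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  //   %c = icmp eq i32 %tid, 0            ; or llvm.amdgcn.workitem.id.x()
  //   br i1 %c, label %init.thread, label %other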
2812
2813 /// Mapping containing information per block.
 2814   DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
 2815   DenseMap<const CallBase *, ExecutionDomainTy> CEDMap;
 2816   SmallSetVector<CallBase *, 16> AlignedBarriers;
 2817
 2818   ReversePostOrderTraversal<Function *> *RPOT = nullptr;
 2819 };
2820
2821void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
2822 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
2823 for (auto *EA : PredED.EncounteredAssumes)
2824 ED.addAssumeInst(A, *EA);
2825
2826 for (auto *AB : PredED.AlignedBarriers)
2827 ED.addAlignedBarrier(A, *AB);
2828}
2829
2830void AAExecutionDomainFunction::mergeInPredecessor(
2831 Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
2832 bool InitialEdgeOnly) {
2833 ED.IsExecutedByInitialThreadOnly =
2834 InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
2835 ED.IsExecutedByInitialThreadOnly);
2836
2837 ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly &&
2838 PredED.IsReachedFromAlignedBarrierOnly;
2839 ED.EncounteredNonLocalSideEffect =
2840 ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect;
2841 if (ED.IsReachedFromAlignedBarrierOnly)
2842 mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
2843 else
2844 ED.clearAssumeInstAndAlignedBarriers();
2845}
2846
2847void AAExecutionDomainFunction::handleEntryBB(Attributor &A,
2848 ExecutionDomainTy &EntryBBED) {
2849 SmallVector<ExecutionDomainTy> PredExecDomains;
2850 auto PredForCallSite = [&](AbstractCallSite ACS) {
2851 const auto &EDAA = A.getAAFor<AAExecutionDomain>(
2852 *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2853 DepClassTy::OPTIONAL);
2854 if (!EDAA.getState().isValidState())
2855 return false;
2856 PredExecDomains.emplace_back(
2857 EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
2858 return true;
2859 };
2860
2861 bool AllCallSitesKnown;
2862 if (A.checkForAllCallSites(PredForCallSite, *this,
2863 /* RequiresAllCallSites */ true,
2864 AllCallSitesKnown)) {
2865 for (const auto &PredED : PredExecDomains)
2866 mergeInPredecessor(A, EntryBBED, PredED);
2867
2868 } else {
2869 // We could not find all predecessors, so this is either a kernel or a
2870 // function with external linkage (or with some other weird uses).
2871 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2872 if (OMPInfoCache.Kernels.count(getAnchorScope())) {
2873 EntryBBED.IsExecutedByInitialThreadOnly = false;
2874 EntryBBED.IsReachedFromAlignedBarrierOnly = true;
2875 EntryBBED.EncounteredNonLocalSideEffect = false;
2876 } else {
2877 EntryBBED.IsExecutedByInitialThreadOnly = false;
2878 EntryBBED.IsReachedFromAlignedBarrierOnly = false;
2879 EntryBBED.EncounteredNonLocalSideEffect = true;
2880 }
2881 }
2882
2883 auto &FnED = BEDMap[nullptr];
2884 FnED.IsReachingAlignedBarrierOnly &=
2885 EntryBBED.IsReachedFromAlignedBarrierOnly;
2886}
2887
2888ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2889
2890 bool Changed = false;
2891
2892 // Helper to deal with an aligned barrier encountered during the forward
2893 // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
2894 // it was encountered.
2895 auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) {
2896 if (CB)
2897 Changed |= AlignedBarriers.insert(CB);
2898 // First, update the barrier ED kept in the separate CEDMap.
2899 auto &CallED = CEDMap[CB];
2900 mergeInPredecessor(A, CallED, ED);
2901 // Next adjust the ED we use for the traversal.
2902 ED.EncounteredNonLocalSideEffect = false;
2903 ED.IsReachedFromAlignedBarrierOnly = true;
2904 // Aligned barrier collection has to come last.
2905 ED.clearAssumeInstAndAlignedBarriers();
2906 if (CB)
2907 ED.addAlignedBarrier(A, *CB);
2908 };
2909
2910 auto &LivenessAA =
2911 A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
2912
 2913   // Set \p R to \p V and report true if that changed \p R.
2914 auto SetAndRecord = [&](bool &R, bool V) {
2915 bool Eq = (R == V);
2916 R = V;
2917 return !Eq;
2918 };
2919
2920 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2921
2922 Function *F = getAnchorScope();
2923 BasicBlock &EntryBB = F->getEntryBlock();
2924 bool IsKernel = OMPInfoCache.Kernels.count(F);
2925
2926 SmallVector<Instruction *> SyncInstWorklist;
2927 for (auto &RIt : *RPOT) {
2928 BasicBlock &BB = *RIt;
2929
2930 bool IsEntryBB = &BB == &EntryBB;
2931 // TODO: We use local reasoning since we don't have a divergence analysis
2932 // running as well. We could basically allow uniform branches here.
2933 bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
2934 ExecutionDomainTy ED;
2935 // Propagate "incoming edges" into information about this block.
2936 if (IsEntryBB) {
2937 handleEntryBB(A, ED);
2938 } else {
2939 // For live non-entry blocks we only propagate
2940 // information via live edges.
2941 if (LivenessAA.isAssumedDead(&BB))
2942 continue;
2943
2944 for (auto *PredBB : predecessors(&BB)) {
2945 if (LivenessAA.isEdgeDead(PredBB, &BB))
2946 continue;
2947 bool InitialEdgeOnly = isInitialThreadOnlyEdge(
2948 A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
2949 mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
2950 }
2951 }
2952
2953 // Now we traverse the block, accumulate effects in ED and attach
2954 // information to calls.
2955 for (Instruction &I : BB) {
2956 bool UsedAssumedInformation;
2957 if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation,
2958 /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
2959 /* CheckForDeadStore */ true))
2960 continue;
2961
 2962       // Assumes and "assume-like" intrinsics (dbg, lifetime, ...) are handled
 2963       // first; the former are collected, the latter are ignored.
2964 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
2965 if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
2966 ED.addAssumeInst(A, *AI);
2967 continue;
2968 }
2969 // TODO: Should we also collect and delete lifetime markers?
2970 if (II->isAssumeLikeIntrinsic())
2971 continue;
2972 }
2973
2974 auto *CB = dyn_cast<CallBase>(&I);
2975 bool IsNoSync = AA::isNoSyncInst(A, I, *this);
2976 bool IsAlignedBarrier =
2977 !IsNoSync && CB &&
2978 AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
2979
2980 AlignedBarrierLastInBlock &= IsNoSync;
2981
2982 // Next we check for calls. Aligned barriers are handled
2983 // explicitly, everything else is kept for the backward traversal and will
2984 // also affect our state.
2985 if (CB) {
2986 if (IsAlignedBarrier) {
2987 HandleAlignedBarrier(CB, ED);
2988 AlignedBarrierLastInBlock = true;
2989 continue;
2990 }
2991
2992 // Check the pointer(s) of a memory intrinsic explicitly.
2993 if (isa<MemIntrinsic>(&I)) {
2994 if (!ED.EncounteredNonLocalSideEffect &&
 2995               AA::isPotentiallyAffectedByBarrier(A, I, *this))
 2996             ED.EncounteredNonLocalSideEffect = true;
2997 if (!IsNoSync) {
2998 ED.IsReachedFromAlignedBarrierOnly = false;
2999 SyncInstWorklist.push_back(&I);
3000 }
3001 continue;
3002 }
3003
3004 // Record how we entered the call, then accumulate the effect of the
3005 // call in ED for potential use by the callee.
3006 auto &CallED = CEDMap[CB];
3007 mergeInPredecessor(A, CallED, ED);
3008
3009 // If we have a sync-definition we can check if it starts/ends in an
3010 // aligned barrier. If we are unsure we assume any sync breaks
3011 // alignment.
 3012         Function *Callee = CB->getCalledFunction();
 3013         if (!IsNoSync && Callee && !Callee->isDeclaration()) {
3014 const auto &EDAA = A.getAAFor<AAExecutionDomain>(
3015 *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
3016 if (EDAA.getState().isValidState()) {
3017 const auto &CalleeED = EDAA.getFunctionExecutionDomain();
 3018             ED.IsReachedFromAlignedBarrierOnly =
 3019                 CallED.IsReachedFromAlignedBarrierOnly =
3020 CalleeED.IsReachedFromAlignedBarrierOnly;
3021 AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
3022 if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
3023 ED.EncounteredNonLocalSideEffect |=
3024 CalleeED.EncounteredNonLocalSideEffect;
3025 else
3026 ED.EncounteredNonLocalSideEffect =
3027 CalleeED.EncounteredNonLocalSideEffect;
3028 if (!CalleeED.IsReachingAlignedBarrierOnly)
3029 SyncInstWorklist.push_back(&I);
3030 if (CalleeED.IsReachedFromAlignedBarrierOnly)
3031 mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
3032 continue;
3033 }
3034 }
3035 if (!IsNoSync)
3036 ED.IsReachedFromAlignedBarrierOnly =
3037 CallED.IsReachedFromAlignedBarrierOnly = false;
3038 AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
3039 ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
3040 if (!IsNoSync)
3041 SyncInstWorklist.push_back(&I);
3042 }
3043
3044 if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
3045 continue;
3046
3047 // If we have a callee we try to use fine-grained information to
3048 // determine local side-effects.
3049 if (CB) {
3050 const auto &MemAA = A.getAAFor<AAMemoryLocation>(
3051 *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
3052
3053 auto AccessPred = [&](const Instruction *I, const Value *Ptr,
 3054                               AAMemoryLocation::AccessKind,
 3055                               AAMemoryLocation::MemoryLocationsKind) {
 3056           return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
3057 };
3058 if (MemAA.getState().isValidState() &&
3059 MemAA.checkForAllAccessesToMemoryKind(
 3060                 AccessPred, AAMemoryLocation::ALL_LOCATIONS))
 3061           continue;
3062 }
3063
3064 if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I))
3065 continue;
3066
3067 if (auto *LI = dyn_cast<LoadInst>(&I))
3068 if (LI->hasMetadata(LLVMContext::MD_invariant_load))
3069 continue;
3070
3071 if (!ED.EncounteredNonLocalSideEffect &&
 3072           AA::isPotentiallyAffectedByBarrier(A, I, *this))
 3073         ED.EncounteredNonLocalSideEffect = true;
3074 }
3075
3076 if (!isa<UnreachableInst>(BB.getTerminator()) &&
3077 !BB.getTerminator()->getNumSuccessors()) {
3078
3079 auto &FnED = BEDMap[nullptr];
3080 mergeInPredecessor(A, FnED, ED);
3081
3082 if (IsKernel)
3083 HandleAlignedBarrier(nullptr, ED);
3084 }
3085
3086 ExecutionDomainTy &StoredED = BEDMap[&BB];
3087 ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly;
3088
3089 // Check if we computed anything different as part of the forward
3090 // traversal. We do not take assumptions and aligned barriers into account
3091 // as they do not influence the state we iterate. Backward traversal values
3092 // are handled later on.
3093 if (ED.IsExecutedByInitialThreadOnly !=
3094 StoredED.IsExecutedByInitialThreadOnly ||
3095 ED.IsReachedFromAlignedBarrierOnly !=
3096 StoredED.IsReachedFromAlignedBarrierOnly ||
3097 ED.EncounteredNonLocalSideEffect !=
3098 StoredED.EncounteredNonLocalSideEffect)
3099 Changed = true;
3100
3101 // Update the state with the new value.
3102 StoredED = std::move(ED);
3103 }
3104
3105 // Propagate (non-aligned) sync instruction effects backwards until the
3106 // entry is hit or an aligned barrier.
 3107   SmallSetVector<BasicBlock *, 16> Visited;
 3108   while (!SyncInstWorklist.empty()) {
3109 Instruction *SyncInst = SyncInstWorklist.pop_back_val();
3110 Instruction *CurInst = SyncInst;
3111 bool HitAlignedBarrier = false;
3112 while ((CurInst = CurInst->getPrevNode())) {
3113 auto *CB = dyn_cast<CallBase>(CurInst);
3114 if (!CB)
3115 continue;
3116 auto &CallED = CEDMap[CB];
3117 if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false))
3118 Changed = true;
3119 HitAlignedBarrier = AlignedBarriers.count(CB);
3120 if (HitAlignedBarrier)
3121 break;
3122 }
3123 if (HitAlignedBarrier)
3124 continue;
3125 BasicBlock *SyncBB = SyncInst->getParent();
3126 for (auto *PredBB : predecessors(SyncBB)) {
3127 if (LivenessAA.isEdgeDead(PredBB, SyncBB))
3128 continue;
3129 if (!Visited.insert(PredBB))
3130 continue;
3131 SyncInstWorklist.push_back(PredBB->getTerminator());
3132 auto &PredED = BEDMap[PredBB];
3133 if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false))
3134 Changed = true;
3135 }
3136 if (SyncBB != &EntryBB)
3137 continue;
3138 auto &FnED = BEDMap[nullptr];
3139 if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false))
3140 Changed = true;
3141 }
3142
3143 return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
3144}
3145
3146/// Try to replace memory allocation calls called by a single thread with a
3147/// static buffer of shared memory.
 3148 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
 3149   using Base = StateWrapper<BooleanState, AbstractAttribute>;
 3150   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3151
3152 /// Create an abstract attribute view for the position \p IRP.
3153 static AAHeapToShared &createForPosition(const IRPosition &IRP,
3154 Attributor &A);
3155
3156 /// Returns true if HeapToShared conversion is assumed to be possible.
3157 virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
3158
3159 /// Returns true if HeapToShared conversion is assumed and the CB is a
3160 /// callsite to a free operation to be removed.
3161 virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
3162
3163 /// See AbstractAttribute::getName().
3164 const std::string getName() const override { return "AAHeapToShared"; }
3165
3166 /// See AbstractAttribute::getIdAddr().
3167 const char *getIdAddr() const override { return &ID; }
3168
3169 /// This function should return true if the type of the \p AA is
3170 /// AAHeapToShared.
3171 static bool classof(const AbstractAttribute *AA) {
3172 return (AA->getIdAddr() == &ID);
3173 }
3174
3175 /// Unique ID (due to the unique address)
3176 static const char ID;
3177};
3178
3179struct AAHeapToSharedFunction : public AAHeapToShared {
3180 AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
3181 : AAHeapToShared(IRP, A) {}
3182
3183 const std::string getAsStr() const override {
3184 return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
3185 " malloc calls eligible.";
3186 }
3187
3188 /// See AbstractAttribute::trackStatistics().
3189 void trackStatistics() const override {}
3190
 3191   /// This function finds free calls that will be removed by the
3192 /// HeapToShared transformation.
3193 void findPotentialRemovedFreeCalls(Attributor &A) {
3194 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3195 auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3196
3197 PotentialRemovedFreeCalls.clear();
3198 // Update free call users of found malloc calls.
3199 for (CallBase *CB : MallocCalls) {
 3200       SmallVector<CallBase *, 4> FreeCalls;
 3201       for (auto *U : CB->users()) {
3202 CallBase *C = dyn_cast<CallBase>(U);
3203 if (C && C->getCalledFunction() == FreeRFI.Declaration)
3204 FreeCalls.push_back(C);
3205 }
3206
3207 if (FreeCalls.size() != 1)
3208 continue;
3209
3210 PotentialRemovedFreeCalls.insert(FreeCalls.front());
3211 }
3212 }
3213
3214 void initialize(Attributor &A) override {
 3215     if (DisableOpenMPOptDeglobalization) {
 3216       indicatePessimisticFixpoint();
3217 return;
3218 }
3219
3220 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3221 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3222 if (!RFI.Declaration)
3223 return;
3224
3226 [](const IRPosition &, const AbstractAttribute *,
3227 bool &) -> std::optional<Value *> { return nullptr; };
3228
3229 Function *F = getAnchorScope();
3230 for (User *U : RFI.Declaration->users())
3231 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3232 if (CB->getFunction() != F)
3233 continue;
3234 MallocCalls.insert(CB);
3235 A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
3236 SCB);
3237 }
3238
3239 findPotentialRemovedFreeCalls(A);
3240 }
3241
3242 bool isAssumedHeapToShared(CallBase &CB) const override {
3243 return isValidState() && MallocCalls.count(&CB);
3244 }
3245
3246 bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
3247 return isValidState() && PotentialRemovedFreeCalls.count(&CB);
3248 }
3249
3250 ChangeStatus manifest(Attributor &A) override {
3251 if (MallocCalls.empty())
3252 return ChangeStatus::UNCHANGED;
3253
3254 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3255 auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
3256
3257 Function *F = getAnchorScope();
3258 auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
3259 DepClassTy::OPTIONAL);
3260
3261 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3262 for (CallBase *CB : MallocCalls) {
3263 // Skip replacing this if HeapToStack has already claimed it.
3264 if (HS && HS->isAssumedHeapToStack(*CB))
3265 continue;
3266
3267 // Find the unique free call to remove it.
 3268       SmallVector<CallBase *, 4> FreeCalls;
 3269       for (auto *U : CB->users()) {
3270 CallBase *C = dyn_cast<CallBase>(U);
3271 if (C && C->getCalledFunction() == FreeCall.Declaration)
3272 FreeCalls.push_back(C);
3273 }
3274 if (FreeCalls.size() != 1)
3275 continue;
3276
3277 auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3278
3279 if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3280 LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3281 << " with shared memory."
3282 << " Shared memory usage is limited to "
3283 << SharedMemoryLimit << " bytes\n");
3284 continue;
3285 }
3286
3287 LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3288 << " with " << AllocSize->getZExtValue()
3289 << " bytes of shared memory\n");
3290
3291 // Create a new shared memory buffer of the same size as the allocation
3292 // and replace all the uses of the original allocation with it.
3293 Module *M = CB->getModule();
3294 Type *Int8Ty = Type::getInt8Ty(M->getContext());
3295 Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3296 auto *SharedMem = new GlobalVariable(
3297 *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3298 UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
 3299           GlobalValue::NotThreadLocal,
 3300           static_cast<unsigned>(AddressSpace::Shared));
3301 auto *NewBuffer =
3302 ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3303
3304 auto Remark = [&](OptimizationRemark OR) {
3305 return OR << "Replaced globalized variable with "
3306 << ore::NV("SharedMemory", AllocSize->getZExtValue())
3307 << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
3308 << "of shared memory.";
3309 };
3310 A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3311
3312 MaybeAlign Alignment = CB->getRetAlign();
3313 assert(Alignment &&
3314 "HeapToShared on allocation without alignment attribute");
3315 SharedMem->setAlignment(*Alignment);
3316
3317 A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
3318 A.deleteAfterManifest(*CB);
3319 A.deleteAfterManifest(*FreeCalls.front());
3320
3321 SharedMemoryUsed += AllocSize->getZExtValue();
3322 NumBytesMovedToSharedMemory = SharedMemoryUsed;
3323 Changed = ChangeStatus::CHANGED;
3324 }
3325
3326 return Changed;
3327 }
3328
3329 ChangeStatus updateImpl(Attributor &A) override {
3330 if (MallocCalls.empty())
3331 return indicatePessimisticFixpoint();
3332 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3333 auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3334 if (!RFI.Declaration)
3335 return ChangeStatus::UNCHANGED;
3336
3337 Function *F = getAnchorScope();
3338
3339 auto NumMallocCalls = MallocCalls.size();
3340
3341 // Only consider malloc calls executed by a single thread with a constant.
3342 for (User *U : RFI.Declaration->users()) {
3343 if (CallBase *CB = dyn_cast<CallBase>(U)) {
3344 if (CB->getCaller() != F)
3345 continue;
3346 if (!MallocCalls.count(CB))
3347 continue;
3348 if (!isa<ConstantInt>(CB->getArgOperand(0))) {
3349 MallocCalls.remove(CB);
3350 continue;
3351 }
3352 const auto &ED = A.getAAFor<AAExecutionDomain>(
3353 *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3354 if (!ED.isExecutedByInitialThreadOnly(*CB))
3355 MallocCalls.remove(CB);
3356 }
3357 }
3358
3359 findPotentialRemovedFreeCalls(A);
3360
3361 if (NumMallocCalls != MallocCalls.size())
3362 return ChangeStatus::CHANGED;
3363
3364 return ChangeStatus::UNCHANGED;
3365 }
3366
3367 /// Collection of all malloc calls in a function.
3368  SmallSetVector<CallBase *, 4> MallocCalls;
3369  /// Collection of potentially removed free calls in a function.
3370 SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3371 /// The total amount of shared memory that has been used for HeapToShared.
3372 unsigned SharedMemoryUsed = 0;
3373};
3374
3375struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3376  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3377  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3378
3379 /// Statistics are tracked as part of manifest for now.
3380 void trackStatistics() const override {}
3381
3382 /// See AbstractAttribute::getAsStr()
3383 const std::string getAsStr() const override {
3384 if (!isValidState())
3385 return "<invalid>";
3386 return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3387 : "generic") +
3388 std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3389 : "") +
3390 std::string(" #PRs: ") +
3391 (ReachedKnownParallelRegions.isValidState()
3392 ? std::to_string(ReachedKnownParallelRegions.size())
3393 : "<invalid>") +
3394 ", #Unknown PRs: " +
3395 (ReachedUnknownParallelRegions.isValidState()
3396 ? std::to_string(ReachedUnknownParallelRegions.size())
3397 : "<invalid>") +
3398 ", #Reaching Kernels: " +
3399 (ReachingKernelEntries.isValidState()
3400 ? std::to_string(ReachingKernelEntries.size())
3401 : "<invalid>") +
3402 ", #ParLevels: " +
3403 (ParallelLevels.isValidState()
3404 ? std::to_string(ParallelLevels.size())
3405 : "<invalid>");
3406 }
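  // A valid state might render, e.g., as
  //   "SPMD [FIX] #PRs: 2, #Unknown PRs: 0, #Reaching Kernels: 1, #ParLevels: 1"
  // (illustrative values only).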
3407
3408  /// Create an abstract attribute view for the position \p IRP.
3409 static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3410
3411 /// See AbstractAttribute::getName()
3412 const std::string getName() const override { return "AAKernelInfo"; }
3413
3414 /// See AbstractAttribute::getIdAddr()
3415 const char *getIdAddr() const override { return &ID; }
3416
3417 /// This function should return true if the type of the \p AA is AAKernelInfo
3418 static bool classof(const AbstractAttribute *AA) {
3419 return (AA->getIdAddr() == &ID);
3420 }
3421
3422 static const char ID;
3423};
3424
3425/// The function kernel info abstract attribute, basically, what can we say
3426/// about a function with regards to the KernelInfoState.
3427struct AAKernelInfoFunction : AAKernelInfo {
3428 AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3429 : AAKernelInfo(IRP, A) {}
3430
3431 SmallPtrSet<Instruction *, 4> GuardedInstructions;
3432
3433 SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3434 return GuardedInstructions;
3435 }
3436
3437 /// See AbstractAttribute::initialize(...).
3438 void initialize(Attributor &A) override {
3439 // This is a high-level transform that might change the constant arguments
3440    // of the init and deinit calls. We need to tell the Attributor about this
3441    // to avoid other parts using the current constant value for simplification.
3442 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3443
3444 Function *Fn = getAnchorScope();
3445
3446 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3447 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3448 OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3449 OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3450
3451 // For kernels we perform more initialization work, first we find the init
3452 // and deinit calls.
3453 auto StoreCallBase = [](Use &U,
3454 OMPInformationCache::RuntimeFunctionInfo &RFI,
3455 CallBase *&Storage) {
3456 CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3457 assert(CB &&
3458 "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3459 assert(!Storage &&
3460 "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3461 Storage = CB;
3462 return false;
3463 };
3464 InitRFI.foreachUse(
3465 [&](Use &U, Function &) {
3466 StoreCallBase(U, InitRFI, KernelInitCB);
3467 return false;
3468 },
3469 Fn);
3470 DeinitRFI.foreachUse(
3471 [&](Use &U, Function &) {
3472 StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3473 return false;
3474 },
3475 Fn);
3476
3477 // Ignore kernels without initializers such as global constructors.
3478 if (!KernelInitCB || !KernelDeinitCB)
3479 return;
3480
3481 // Add itself to the reaching kernel and set IsKernelEntry.
3482 ReachingKernelEntries.insert(Fn);
3483 IsKernelEntry = true;
3484
3485 // For kernels we might need to initialize/finalize the IsSPMD state and
3486 // we need to register a simplification callback so that the Attributor
3487 // knows the constant arguments to __kmpc_target_init and
3488 // __kmpc_target_deinit might actually change.
3489
3490 Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
3491 [&](const IRPosition &IRP, const AbstractAttribute *AA,
3492 bool &UsedAssumedInformation) -> std::optional<Value *> {
3493 // IRP represents the "use generic state machine" argument of an
3494 // __kmpc_target_init call. We will answer this one with the internal
3495 // state. As long as we are not in an invalid state, we will create a
3496      // custom state machine, so the value should be an `i1 false`. If we are
3497 // in an invalid state, we won't change the value that is in the IR.
3498 if (!ReachedKnownParallelRegions.isValidState())
3499 return nullptr;
3500 // If we have disabled state machine rewrites, don't make a custom one.
3501      if (DisableOpenMPOptStateMachineRewrite)
3502        return nullptr;
3503 if (AA)
3504 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3505 UsedAssumedInformation = !isAtFixpoint();
3506 auto *FalseVal =
3507          ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
3508      return FalseVal;
3509 };
3510
3511    Attributor::SimplifictionCallbackTy ModeSimplifyCB =
3512        [&](const IRPosition &IRP, const AbstractAttribute *AA,
3513 bool &UsedAssumedInformation) -> std::optional<Value *> {
3514 // IRP represents the "SPMDCompatibilityTracker" argument of an
3515 // __kmpc_target_init or
3516 // __kmpc_target_deinit call. We will answer this one with the internal
3517 // state.
3518 if (!SPMDCompatibilityTracker.isValidState())
3519 return nullptr;
3520 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3521 if (AA)
3522 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3523 UsedAssumedInformation = true;
3524 } else {
3525 UsedAssumedInformation = false;
3526 }
3527 auto *Val = ConstantInt::getSigned(
3528 IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
3529 SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
3530                                               : OMP_TGT_EXEC_MODE_GENERIC);
3531      return Val;
3532 };
3533
3534 constexpr const int InitModeArgNo = 1;
3535 constexpr const int DeinitModeArgNo = 1;
3536 constexpr const int InitUseStateMachineArgNo = 2;
3537 A.registerSimplificationCallback(
3538 IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3539 StateMachineSimplifyCB);
3540 A.registerSimplificationCallback(
3541 IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3542 ModeSimplifyCB);
3543 A.registerSimplificationCallback(
3544 IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3545 ModeSimplifyCB);
3546
3547 // Check if we know we are in SPMD-mode already.
3548 ConstantInt *ModeArg =
3549 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3550 if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3551 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3552 // This is a generic region but SPMDization is disabled so stop tracking.
3553    else if (DisableOpenMPOptSPMDization)
3554      SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3555
3556 // Register virtual uses of functions we might need to preserve.
3557 auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
3558                                  Attributor::VirtualUseCallbackTy &CB) {
3559      if (!OMPInfoCache.RFIs[RFKind].Declaration)
3560 return;
3561 A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
3562 };
3563
3564 // Add a dependence to ensure updates if the state changes.
3565 auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
3566 const AbstractAttribute *QueryingAA) {
3567 if (QueryingAA) {
3568 A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
3569 }
3570 return true;
3571 };
3572
3573 Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
3574 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3575 // Whenever we create a custom state machine we will insert calls to
3576 // __kmpc_get_hardware_num_threads_in_block,
3577 // __kmpc_get_warp_size,
3578 // __kmpc_barrier_simple_generic,
3579 // __kmpc_kernel_parallel, and
3580 // __kmpc_kernel_end_parallel.
3581 // Not needed if we are on track for SPMDzation.
3582 if (SPMDCompatibilityTracker.isValidState())
3583 return AddDependence(A, this, QueryingAA);
3584 // Not needed if we can't rewrite due to an invalid state.
3585 if (!ReachedKnownParallelRegions.isValidState())
3586 return AddDependence(A, this, QueryingAA);
3587 return false;
3588 };
3589
3590 // Not needed if we are pre-runtime merge.
3591 if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
3592 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
3593 CustomStateMachineUseCB);
3594 RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
3595 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
3596 CustomStateMachineUseCB);
3597 RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
3598 CustomStateMachineUseCB);
3599 RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
3600 CustomStateMachineUseCB);
3601 }
3602
3603 // If we do not perform SPMDzation we do not need the virtual uses below.
3604 if (SPMDCompatibilityTracker.isAtFixpoint())
3605 return;
3606
3607 Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
3608 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3609 // Whenever we perform SPMDzation we will insert
3610 // __kmpc_get_hardware_thread_id_in_block calls.
3611 if (!SPMDCompatibilityTracker.isValidState())
3612 return AddDependence(A, this, QueryingAA);
3613 return false;
3614 };
3615 RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
3616 HWThreadIdUseCB);
3617
3618 Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
3619 [&](Attributor &A, const AbstractAttribute *QueryingAA) {
3620 // Whenever we perform SPMDzation with guarding we will insert
3621 // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
3622 // nothing to guard, or there are no parallel regions, we don't need
3623 // the calls.
3624 if (!SPMDCompatibilityTracker.isValidState())
3625 return AddDependence(A, this, QueryingAA);
3626 if (SPMDCompatibilityTracker.empty())
3627 return AddDependence(A, this, QueryingAA);
3628 if (!mayContainParallelRegion())
3629 return AddDependence(A, this, QueryingAA);
3630 return false;
3631 };
3632 RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
3633 }
3634
3635 /// Sanitize the string \p S such that it is a suitable global symbol name.
3636 static std::string sanitizeForGlobalName(std::string S) {
3637 std::replace_if(
3638 S.begin(), S.end(),
3639 [](const char C) {
3640 return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3641 (C >= '0' && C <= '9') || C == '_');
3642 },
3643 '.');
3644 return S;
3645 }
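  // E.g., a hypothetical value name "x:y$z 0" is mapped to "x.y.z.0", which is
  // safe to splice into the shared-memory global names created during guarding.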
3646
3647 /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3648 /// finished now.
3649 ChangeStatus manifest(Attributor &A) override {
3650 // If we are not looking at a kernel with __kmpc_target_init and
3651 // __kmpc_target_deinit call we cannot actually manifest the information.
3652 if (!KernelInitCB || !KernelDeinitCB)
3653 return ChangeStatus::UNCHANGED;
3654
3655    /// Insert the nested parallelism global variable.
3656 Function *Kernel = getAnchorScope();
3657 Module &M = *Kernel->getParent();
3658 Type *Int8Ty = Type::getInt8Ty(M.getContext());
3659 new GlobalVariable(M, Int8Ty, /* isConstant */ true,
3660                       GlobalValue::WeakAnyLinkage,
3661                       ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
3662 Kernel->getName() + "_nested_parallelism");
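    // Illustratively, for a kernel @foo without nested parallelism this emits
    // roughly
    //   @foo_nested_parallelism = weak constant i8 0
    // (a sketch; the exact linkage is whatever is passed above), presumably for
    // the device runtime to inspect.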
3663
3664 // If we can we change the execution mode to SPMD-mode otherwise we build a
3665 // custom state machine.
3666 ChangeStatus Changed = ChangeStatus::UNCHANGED;
3667 if (!changeToSPMDMode(A, Changed)) {
3668 if (!KernelInitCB->getCalledFunction()->isDeclaration())
3669 return buildCustomStateMachine(A);
3670 }
3671
3672 return Changed;
3673 }
3674
3675 void insertInstructionGuardsHelper(Attributor &A) {
3676 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3677
3678 auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3679 Instruction *RegionEndI) {
3680 LoopInfo *LI = nullptr;
3681 DominatorTree *DT = nullptr;
3682 MemorySSAUpdater *MSU = nullptr;
3683 using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3684
3685 BasicBlock *ParentBB = RegionStartI->getParent();
3686 Function *Fn = ParentBB->getParent();
3687 Module &M = *Fn->getParent();
3688
3689 // Create all the blocks and logic.
3690 // ParentBB:
3691 // goto RegionCheckTidBB
3692 // RegionCheckTidBB:
3693 // Tid = __kmpc_hardware_thread_id()
3694 // if (Tid != 0)
3695 // goto RegionBarrierBB
3696 // RegionStartBB:
3697 // <execute instructions guarded>
3698 // goto RegionEndBB
3699 // RegionEndBB:
3700 // <store escaping values to shared mem>
3701 // goto RegionBarrierBB
3702 // RegionBarrierBB:
3703 // __kmpc_simple_barrier_spmd()
3704 // // second barrier is omitted if lacking escaping values.
3705 // <load escaping values from shared mem>
3706 // __kmpc_simple_barrier_spmd()
3707 // goto RegionExitBB
3708 // RegionExitBB:
3709 // <execute rest of instructions>
3710
3711 BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3712 DT, LI, MSU, "region.guarded.end");
3713 BasicBlock *RegionBarrierBB =
3714 SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3715 MSU, "region.barrier");
3716 BasicBlock *RegionExitBB =
3717 SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3718 DT, LI, MSU, "region.exit");
3719 BasicBlock *RegionStartBB =
3720 SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3721
3722 assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3723 "Expected a different CFG");
3724
3725 BasicBlock *RegionCheckTidBB = SplitBlock(
3726 ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3727
3728 // Register basic blocks with the Attributor.
3729 A.registerManifestAddedBasicBlock(*RegionEndBB);
3730 A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3731 A.registerManifestAddedBasicBlock(*RegionExitBB);
3732 A.registerManifestAddedBasicBlock(*RegionStartBB);
3733 A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3734
3735 bool HasBroadcastValues = false;
3736 // Find escaping outputs from the guarded region to outside users and
3737 // broadcast their values to them.
3738 for (Instruction &I : *RegionStartBB) {
3739 SmallPtrSet<Instruction *, 4> OutsideUsers;
3740 for (User *Usr : I.users()) {
3741 Instruction &UsrI = *cast<Instruction>(Usr);
3742 if (UsrI.getParent() != RegionStartBB)
3743 OutsideUsers.insert(&UsrI);
3744 }
3745
3746 if (OutsideUsers.empty())
3747 continue;
3748
3749 HasBroadcastValues = true;
3750
3751 // Emit a global variable in shared memory to store the broadcasted
3752 // value.
3753 auto *SharedMem = new GlobalVariable(
3754 M, I.getType(), /* IsConstant */ false,
3755            GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3756            sanitizeForGlobalName(
3757                (I.getName() + ".guarded.output.alloc").str()),
3758            nullptr, GlobalValue::NotThreadLocal,
3759            static_cast<unsigned>(AddressSpace::Shared));
3760
3761 // Emit a store instruction to update the value.
3762 new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3763
3764 LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3765 I.getName() + ".guarded.output.load",
3766 RegionBarrierBB->getTerminator());
3767
3768 // Emit a load instruction and replace uses of the output value.
3769 for (Instruction *UsrI : OutsideUsers)
3770 UsrI->replaceUsesOfWith(&I, LoadI);
3771 }
3772
3773 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3774
3775 // Go to tid check BB in ParentBB.
3776 const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3777 ParentBB->getTerminator()->eraseFromParent();
3778      OpenMPIRBuilder::LocationDescription Loc(
3779          InsertPointTy(ParentBB, ParentBB->end()), DL);
3780 OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3781 uint32_t SrcLocStrSize;
3782 auto *SrcLocStr =
3783 OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3784 Value *Ident =
3785 OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3786 BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3787
3788 // Add check for Tid in RegionCheckTidBB
3789 RegionCheckTidBB->getTerminator()->eraseFromParent();
3790 OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3791 InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3792 OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3793 FunctionCallee HardwareTidFn =
3794 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3795 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3796 CallInst *Tid =
3797 OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3798 Tid->setDebugLoc(DL);
3799 OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3800 Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3801 OMPInfoCache.OMPBuilder.Builder
3802 .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3803 ->setDebugLoc(DL);
3804
3805 // First barrier for synchronization, ensures main thread has updated
3806 // values.
3807 FunctionCallee BarrierFn =
3808 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3809 M, OMPRTL___kmpc_barrier_simple_spmd);
3810 OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3811 RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3812 CallInst *Barrier =
3813 OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3814 Barrier->setDebugLoc(DL);
3815 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3816
3817 // Second barrier ensures workers have read broadcast values.
3818 if (HasBroadcastValues) {
3819 CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3820 RegionBarrierBB->getTerminator());
3821 Barrier->setDebugLoc(DL);
3822 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3823 }
3824 };
3825
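    // Best-effort clustering before guarding: guarded instructions without
    // users are sunk right before the next instruction with effects so that
    // instructions needing guards end up adjacent and fewer guarded regions,
    // and thus fewer barriers, are emitted per block.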
3826 auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3827    SmallPtrSet<BasicBlock *, 8> Visited;
3828    for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3829 BasicBlock *BB = GuardedI->getParent();
3830 if (!Visited.insert(BB).second)
3831 continue;
3832
3833      SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3834      Instruction *LastEffect = nullptr;
3835 BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3836 while (++IP != IPEnd) {
3837 if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3838 continue;
3839 Instruction *I = &*IP;
3840 if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3841 continue;
3842 if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3843 LastEffect = nullptr;
3844 continue;
3845 }
3846 if (LastEffect)
3847 Reorders.push_back({I, LastEffect});
3848 LastEffect = &*IP;
3849 }
3850 for (auto &Reorder : Reorders)
3851 Reorder.first->moveBefore(Reorder.second);
3852 }
3853
3854    SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3855
3856 for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3857 BasicBlock *BB = GuardedI->getParent();
3858 auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3859 IRPosition::function(*GuardedI->getFunction()), nullptr,
3860 DepClassTy::NONE);
3861 assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3862 auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3863 // Continue if instruction is already guarded.
3864 if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3865 continue;
3866
3867 Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3868 for (Instruction &I : *BB) {
3869 // If instruction I needs to be guarded update the guarded region
3870 // bounds.
3871 if (SPMDCompatibilityTracker.contains(&I)) {
3872 CalleeAAFunction.getGuardedInstructions().insert(&I);
3873 if (GuardedRegionStart)
3874 GuardedRegionEnd = &I;
3875 else
3876 GuardedRegionStart = GuardedRegionEnd = &I;
3877
3878 continue;
3879 }
3880
3881 // Instruction I does not need guarding, store
3882 // any region found and reset bounds.
3883 if (GuardedRegionStart) {
3884 GuardedRegions.push_back(
3885 std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3886 GuardedRegionStart = nullptr;
3887 GuardedRegionEnd = nullptr;
3888 }
3889 }
3890 }
3891
3892 for (auto &GR : GuardedRegions)
3893 CreateGuardedRegion(GR.first, GR.second);
3894 }
3895
3896 void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
3897 // Only allow 1 thread per workgroup to continue executing the user code.
3898 //
3899 // InitCB = __kmpc_target_init(...)
3900 // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
3901 // if (ThreadIdInBlock != 0) return;
3902 // UserCode:
3903 // // user code
3904 //
3905 auto &Ctx = getAnchorValue().getContext();
3906 Function *Kernel = getAssociatedFunction();
3907 assert(Kernel && "Expected an associated function!");
3908
3909 // Create block for user code to branch to from initial block.
3910 BasicBlock *InitBB = KernelInitCB->getParent();
3911 BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
3912 KernelInitCB->getNextNode(), "main.thread.user_code");
3913 BasicBlock *ReturnBB =
3914 BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
3915
3916 // Register blocks with attributor:
3917 A.registerManifestAddedBasicBlock(*InitBB);
3918 A.registerManifestAddedBasicBlock(*UserCodeBB);
3919 A.registerManifestAddedBasicBlock(*ReturnBB);
3920
3921 // Debug location:
3922 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3923 ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
3924 InitBB->getTerminator()->eraseFromParent();
3925
3926 // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
3927 Module &M = *Kernel->getParent();
3928 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3929 FunctionCallee ThreadIdInBlockFn =
3930 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3931 M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3932
3933 // Get thread ID in block.
3934 CallInst *ThreadIdInBlock =
3935 CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
3936 OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
3937 ThreadIdInBlock->setDebugLoc(DLoc);
3938
3939 // Eliminate all threads in the block with ID not equal to 0:
3940 Instruction *IsMainThread =
3941 ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
3942 ConstantInt::get(ThreadIdInBlock->getType(), 0),
3943 "thread.is_main", InitBB);
3944 IsMainThread->setDebugLoc(DLoc);
3945 BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
3946 }
3947
3948 bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3949 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3950
3951    // We cannot change to SPMD mode if the runtime functions aren't available.
3952 if (!OMPInfoCache.runtimeFnsAvailable(
3953 {OMPRTL___kmpc_get_hardware_thread_id_in_block,
3954 OMPRTL___kmpc_barrier_simple_spmd}))
3955 return false;
3956
3957 if (!SPMDCompatibilityTracker.isAssumed()) {
3958 for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3959 if (!NonCompatibleI)
3960 continue;
3961
3962 // Skip diagnostics on calls to known OpenMP runtime functions for now.
3963 if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3964 if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3965 continue;
3966
3967 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3968 ORA << "Value has potential side effects preventing SPMD-mode "
3969 "execution";
3970 if (isa<CallBase>(NonCompatibleI)) {
3971 ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3972 "the called function to override";
3973 }
3974 return ORA << ".";
3975 };
3976 A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3977 Remark);
3978
3979 LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3980 << *NonCompatibleI << "\n");
3981 }
3982
3983 return false;
3984 }
3985
3986 // Get the actual kernel, could be the caller of the anchor scope if we have
3987 // a debug wrapper.
3988 Function *Kernel = getAnchorScope();
3989 if (Kernel->hasLocalLinkage()) {
3990 assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
3991 auto *CB = cast<CallBase>(Kernel->user_back());
3992 Kernel = CB->getCaller();
3993 }
3994 assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3995
3996 // Check if the kernel is already in SPMD mode, if so, return success.
3997    GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3998        (Kernel->getName() + "_exec_mode").str());
3999 assert(ExecMode && "Kernel without exec mode?");
4000 assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
4001
4002 // Set the global exec mode flag to indicate SPMD-Generic mode.
4003 assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
4004 "ExecMode is not an integer!");
4005 const int8_t ExecModeVal =
4006 cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
4007 if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
4008 return true;
4009
4010 // We will now unconditionally modify the IR, indicate a change.
4011 Changed = ChangeStatus::CHANGED;
4012
4013    // Do not use instruction guards when no parallel region is present inside
4014 // the target region.
4015 if (mayContainParallelRegion())
4016 insertInstructionGuardsHelper(A);
4017 else
4018 forceSingleThreadPerWorkgroupHelper(A);
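    // Put differently: if parallel regions exist, sequential code between them
    // is wrapped in thread-0 guards; if none exist, it is simpler to let every
    // thread except thread 0 exit right away.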
4019
4020 // Adjust the global exec mode flag that tells the runtime what mode this
4021 // kernel is executed in.
4022 assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
4023 "Initially non-SPMD kernel has SPMD exec mode!");
4024 ExecMode->setInitializer(
4025        ConstantInt::get(ExecMode->getInitializer()->getType(),
4026                         ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
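    // Assuming the usual encoding (GENERIC = 1, SPMD = 2, GENERIC_SPMD = 3),
    // a generic kernel's flag value 1 becomes 1 | 3 == 3 here, i.e., "generic
    // kernel executed in SPMD mode".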
4027
4028 // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
4029 const int InitModeArgNo = 1;
4030 const int DeinitModeArgNo = 1;
4031 const int InitUseStateMachineArgNo = 2;
4032
4033 auto &Ctx = getAnchorValue().getContext();
4034 A.changeUseAfterManifest(
4035 KernelInitCB->getArgOperandUse(InitModeArgNo),
4036 *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
4037                              OMP_TGT_EXEC_MODE_SPMD));
4038    A.changeUseAfterManifest(
4039 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
4040 *ConstantInt::getBool(Ctx, false));
4041 A.changeUseAfterManifest(
4042 KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
4043 *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
4044                              OMP_TGT_EXEC_MODE_SPMD));
4045
4046 ++NumOpenMPTargetRegionKernelsSPMD;
4047
4048 auto Remark = [&](OptimizationRemark OR) {
4049 return OR << "Transformed generic-mode kernel to SPMD-mode.";
4050 };
4051 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
4052 return true;
4053 };
4054
4055 ChangeStatus buildCustomStateMachine(Attributor &A) {
4056 // If we have disabled state machine rewrites, don't make a custom one
4057    if (DisableOpenMPOptStateMachineRewrite)
4058      return ChangeStatus::UNCHANGED;
4059
4060 // Don't rewrite the state machine if we are not in a valid state.
4061 if (!ReachedKnownParallelRegions.isValidState())
4062 return ChangeStatus::UNCHANGED;
4063
4064 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4065 if (!OMPInfoCache.runtimeFnsAvailable(
4066 {OMPRTL___kmpc_get_hardware_num_threads_in_block,
4067 OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
4068 OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
4069 return ChangeStatus::UNCHANGED;
4070
4071 const int InitModeArgNo = 1;
4072 const int InitUseStateMachineArgNo = 2;
4073
4074    // Check if the current configuration is non-SPMD with a generic state machine.
4075 // If we already have SPMD mode or a custom state machine we do not need to
4076 // go any further. If it is anything but a constant something is weird and
4077 // we give up.
4078 ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
4079 KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
4080 ConstantInt *Mode =
4081 dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
4082
4083 // If we are stuck with generic mode, try to create a custom device (=GPU)
4084 // state machine which is specialized for the parallel regions that are
4085 // reachable by the kernel.
4086 if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
4087 (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
4088 return ChangeStatus::UNCHANGED;
4089
4090 // If not SPMD mode, indicate we use a custom state machine now.
4091 auto &Ctx = getAnchorValue().getContext();
4092 auto *FalseVal = ConstantInt::getBool(Ctx, false);
4093 A.changeUseAfterManifest(
4094 KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
4095
4096 // If we don't actually need a state machine we are done here. This can
4097 // happen if there simply are no parallel regions. In the resulting kernel
4098 // all worker threads will simply exit right away, leaving the main thread
4099 // to do the work alone.
4100 if (!mayContainParallelRegion()) {
4101 ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
4102
4103 auto Remark = [&](OptimizationRemark OR) {
4104 return OR << "Removing unused state machine from generic-mode kernel.";
4105 };
4106 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
4107
4108 return ChangeStatus::CHANGED;
4109 }
4110
4111 // Keep track in the statistics of our new shiny custom state machine.
4112 if (ReachedUnknownParallelRegions.empty()) {
4113 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
4114
4115 auto Remark = [&](OptimizationRemark OR) {
4116 return OR << "Rewriting generic-mode kernel with a customized state "
4117 "machine.";
4118 };
4119 A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
4120 } else {
4121 ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
4122
4123      auto Remark = [&](OptimizationRemarkAnalysis OR) {
4124        return OR << "Generic-mode kernel is executed with a customized state "
4125 "machine that requires a fallback.";
4126 };
4127 A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
4128
4129 // Tell the user why we ended up with a fallback.
4130 for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
4131 if (!UnknownParallelRegionCB)
4132 continue;
4133 auto Remark = [&](OptimizationRemarkAnalysis ORA) {
4134 return ORA << "Call may contain unknown parallel regions. Use "
4135 << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
4136 "override.";
4137 };
4138 A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
4139 "OMP133", Remark);
4140 }
4141 }
4142
4143 // Create all the blocks:
4144 //
4145 // InitCB = __kmpc_target_init(...)
4146 // BlockHwSize =
4147 // __kmpc_get_hardware_num_threads_in_block();
4148 // WarpSize = __kmpc_get_warp_size();
4149 // BlockSize = BlockHwSize - WarpSize;
4150 // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
4151 // if (IsWorker) {
4152 // if (InitCB >= BlockSize) return;
4153 // SMBeginBB: __kmpc_barrier_simple_generic(...);
4154 // void *WorkFn;
4155 // bool Active = __kmpc_kernel_parallel(&WorkFn);
4156 // if (!WorkFn) return;
4157 // SMIsActiveCheckBB: if (Active) {
4158 // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
4159 // ParFn0(...);
4160 // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
4161 // ParFn1(...);
4162 // ...
4163 // SMIfCascadeCurrentBB: else
4164 // ((WorkFnTy*)WorkFn)(...);
4165 // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
4166 // }
4167 // SMDoneBB: __kmpc_barrier_simple_generic(...);
4168 // goto SMBeginBB;
4169 // }
4170 // UserCodeEntryBB: // user code
4171 // __kmpc_target_deinit(...)
4172 //
4173 Function *Kernel = getAssociatedFunction();
4174 assert(Kernel && "Expected an associated function!");
4175
4176 BasicBlock *InitBB = KernelInitCB->getParent();
4177 BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
4178 KernelInitCB->getNextNode(), "thread.user_code.check");
4179 BasicBlock *IsWorkerCheckBB =
4180 BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
4181 BasicBlock *StateMachineBeginBB = BasicBlock::Create(
4182 Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
4183 BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
4184 Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
4185 BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
4186 Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
4187 BasicBlock *StateMachineIfCascadeCurrentBB =
4188 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4189 Kernel, UserCodeEntryBB);
4190 BasicBlock *StateMachineEndParallelBB =
4191 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
4192 Kernel, UserCodeEntryBB);
4193 BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
4194 Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
4195 A.registerManifestAddedBasicBlock(*InitBB);
4196 A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
4197 A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
4198 A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
4199 A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
4200 A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
4201 A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
4202 A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
4203 A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
4204
4205 const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
4206 ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
4207 InitBB->getTerminator()->eraseFromParent();
4208
4209 Instruction *IsWorker =
4210 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
4211 ConstantInt::get(KernelInitCB->getType(), -1),
4212 "thread.is_worker", InitBB);
4213 IsWorker->setDebugLoc(DLoc);
4214 BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
4215
4216 Module &M = *Kernel->getParent();
4217 FunctionCallee BlockHwSizeFn =
4218 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4219 M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
4220 FunctionCallee WarpSizeFn =
4221 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4222 M, OMPRTL___kmpc_get_warp_size);
4223 CallInst *BlockHwSize =
4224 CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
4225 OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
4226 BlockHwSize->setDebugLoc(DLoc);
4227 CallInst *WarpSize =
4228 CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
4229 OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
4230 WarpSize->setDebugLoc(DLoc);
4231 Instruction *BlockSize = BinaryOperator::CreateSub(
4232 BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
4233 BlockSize->setDebugLoc(DLoc);
4234 Instruction *IsMainOrWorker = ICmpInst::Create(
4235 ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
4236 "thread.is_main_or_worker", IsWorkerCheckBB);
4237 IsMainOrWorker->setDebugLoc(DLoc);
4238 BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
4239 IsMainOrWorker, IsWorkerCheckBB);
4240
4241 // Create local storage for the work function pointer.
4242 const DataLayout &DL = M.getDataLayout();
4243 Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
4244 Instruction *WorkFnAI =
4245 new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
4246 "worker.work_fn.addr", &Kernel->getEntryBlock().front());
4247 WorkFnAI->setDebugLoc(DLoc);
4248
4249 OMPInfoCache.OMPBuilder.updateToLocation(
4250        OpenMPIRBuilder::LocationDescription(
4251            IRBuilder<>::InsertPoint(StateMachineBeginBB,
4252 StateMachineBeginBB->end()),
4253 DLoc));
4254
4255 Value *Ident = KernelInitCB->getArgOperand(0);
4256 Value *GTid = KernelInitCB;
4257
4258 FunctionCallee BarrierFn =
4259 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4260 M, OMPRTL___kmpc_barrier_simple_generic);
4261 CallInst *Barrier =
4262 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
4263 OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
4264 Barrier->setDebugLoc(DLoc);
4265
4266 if (WorkFnAI->getType()->getPointerAddressSpace() !=
4267 (unsigned int)AddressSpace::Generic) {
4268 WorkFnAI = new AddrSpaceCastInst(
4269 WorkFnAI,
4270 PointerType::getWithSamePointeeType(
4271 cast<PointerType>(WorkFnAI->getType()),
4272 (unsigned int)AddressSpace::Generic),
4273 WorkFnAI->getName() + ".generic", StateMachineBeginBB);
4274 WorkFnAI->setDebugLoc(DLoc);
4275 }
4276
4277 FunctionCallee KernelParallelFn =
4278 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4279 M, OMPRTL___kmpc_kernel_parallel);
4280 CallInst *IsActiveWorker = CallInst::Create(
4281 KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
4282 OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
4283 IsActiveWorker->setDebugLoc(DLoc);
4284 Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
4285 StateMachineBeginBB);
4286 WorkFn->setDebugLoc(DLoc);
4287
4288 FunctionType *ParallelRegionFnTy = FunctionType::get(
4289        Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
4290        false);
4291 Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
4292 WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
4293 StateMachineBeginBB);
4294
4295 Instruction *IsDone =
4296 ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
4297 Constant::getNullValue(VoidPtrTy), "worker.is_done",
4298 StateMachineBeginBB);
4299 IsDone->setDebugLoc(DLoc);
4300 BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
4301 IsDone, StateMachineBeginBB)
4302 ->setDebugLoc(DLoc);
4303
4304 BranchInst::Create(StateMachineIfCascadeCurrentBB,
4305 StateMachineDoneBarrierBB, IsActiveWorker,
4306 StateMachineIsActiveCheckBB)
4307 ->setDebugLoc(DLoc);
4308
4309 Value *ZeroArg =
4310 Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
4311
4312 // Now that we have most of the CFG skeleton it is time for the if-cascade
4313 // that checks the function pointer we got from the runtime against the
4314 // parallel regions we expect, if there are any.
4315 for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
4316 auto *ParallelRegion = ReachedKnownParallelRegions[I];
4317 BasicBlock *PRExecuteBB = BasicBlock::Create(
4318 Ctx, "worker_state_machine.parallel_region.execute", Kernel,
4319 StateMachineEndParallelBB);
4320 CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
4321 ->setDebugLoc(DLoc);
4322 BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
4323 ->setDebugLoc(DLoc);
4324
4325 BasicBlock *PRNextBB =
4326 BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
4327 Kernel, StateMachineEndParallelBB);
4328
4329 // Check if we need to compare the pointer at all or if we can just
4330 // call the parallel region function.
4331 Value *IsPR;
4332 if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
4333 Instruction *CmpI = ICmpInst::Create(
4334 ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
4335 "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
4336 CmpI->setDebugLoc(DLoc);
4337 IsPR = CmpI;
4338 } else {
4339 IsPR = ConstantInt::getTrue(Ctx);
4340 }
4341
4342 BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
4343 StateMachineIfCascadeCurrentBB)
4344 ->setDebugLoc(DLoc);
4345 StateMachineIfCascadeCurrentBB = PRNextBB;
4346 }
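    // For, say, two known regions Par0 and Par1 plus possibly unknown ones,
    // the cascade built above behaves like (illustrative pseudo code):
    //   if (WorkFn == Par0)      Par0(0, GTid);
    //   else if (WorkFn == Par1) Par1(0, GTid);
    //   else ((WorkFnTy)WorkFn)(0, GTid);   // fallback, emitted below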
4347
4348 // At the end of the if-cascade we place the indirect function pointer call
4349 // in case we might need it, that is if there can be parallel regions we
4350 // have not handled in the if-cascade above.
4351 if (!ReachedUnknownParallelRegions.empty()) {
4352 StateMachineIfCascadeCurrentBB->setName(
4353 "worker_state_machine.parallel_region.fallback.execute");
4354 CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
4355 StateMachineIfCascadeCurrentBB)
4356 ->setDebugLoc(DLoc);
4357 }
4358 BranchInst::Create(StateMachineEndParallelBB,
4359 StateMachineIfCascadeCurrentBB)
4360 ->setDebugLoc(DLoc);
4361
4362 FunctionCallee EndParallelFn =
4363 OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
4364 M, OMPRTL___kmpc_kernel_end_parallel);
4365 CallInst *EndParallel =
4366 CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
4367 OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
4368 EndParallel->setDebugLoc(DLoc);
4369 BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
4370 ->setDebugLoc(DLoc);
4371
4372 CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
4373 ->setDebugLoc(DLoc);
4374 BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
4375 ->setDebugLoc(DLoc);
4376
4377 return ChangeStatus::CHANGED;
4378 }
4379
4380 /// Fixpoint iteration update function. Will be called every time a dependence
4381 /// changed its state (and in the beginning).
4382 ChangeStatus updateImpl(Attributor &A) override {
4383 KernelInfoState StateBefore = getState();
4384
4385 // Callback to check a read/write instruction.
4386 auto CheckRWInst = [&](Instruction &I) {
4387 // We handle calls later.
4388 if (isa<CallBase>(I))
4389 return true;
4390 // We only care about write effects.
4391 if (!I.mayWriteToMemory())
4392 return true;
4393 if (auto *SI = dyn_cast<StoreInst>(&I)) {
4394 const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
4395 *this, IRPosition::value(*SI->getPointerOperand()),
4396 DepClassTy::OPTIONAL);
4397 auto &HS = A.getAAFor<AAHeapToStack>(
4398 *this, IRPosition::function(*I.getFunction()),
4399 DepClassTy::OPTIONAL);
4400 if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) {
4401 if (AA::isAssumedThreadLocalObject(A, Obj, *this))
4402 return true;
4403 // Check for AAHeapToStack moved objects which must not be
4404 // guarded.
4405 auto *CB = dyn_cast<CallBase>(&Obj);
4406 return CB && HS.isAssumedHeapToStack(*CB);
4407 }))
4408 return true;
4409 }
4410
4411 // Insert instruction that needs guarding.
4412 SPMDCompatibilityTracker.insert(&I);
4413 return true;
4414 };
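    // For instance, a store whose underlying object is neither assumed
    // thread-local nor a HeapToStack allocation ends up in
    // SPMDCompatibilityTracker and will be placed in a guarded region when the
    // kernel is turned into SPMD mode.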
4415
4416 bool UsedAssumedInformationInCheckRWInst = false;
4417 if (!SPMDCompatibilityTracker.isAtFixpoint())
4418 if (!A.checkForAllReadWriteInstructions(
4419 CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4420 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4421
4422 bool UsedAssumedInformationFromReachingKernels = false;
4423 if (!IsKernelEntry) {
4424 updateParallelLevels(A);
4425
4426 bool AllReachingKernelsKnown = true;
4427 updateReachingKernelEntries(A, AllReachingKernelsKnown);
4428 UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4429
4430 if (!SPMDCompatibilityTracker.empty()) {
4431 if (!ParallelLevels.isValidState())
4432 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4433 else if (!ReachingKernelEntries.isValidState())
4434 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4435 else {
4436 // Check if all reaching kernels agree on the mode as we can otherwise
4437 // not guard instructions. We might not be sure about the mode so we
4438          // cannot fix the internal spmd-zation state either.
4439 int SPMD = 0, Generic = 0;
4440 for (auto *Kernel : ReachingKernelEntries) {
4441 auto &CBAA = A.getAAFor<AAKernelInfo>(
4442 *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4443 if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4444 CBAA.SPMDCompatibilityTracker.isAssumed())
4445 ++SPMD;
4446 else
4447 ++Generic;
4448 if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4449 UsedAssumedInformationFromReachingKernels = true;
4450 }
4451 if (SPMD != 0 && Generic != 0)
4452 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4453 }
4454 }
4455 }
4456
4457 // Callback to check a call instruction.
4458 bool AllParallelRegionStatesWereFixed = true;
4459 bool AllSPMDStatesWereFixed = true;
4460 auto CheckCallInst = [&](Instruction &I) {
4461 auto &CB = cast<CallBase>(I);
4462 auto &CBAA = A.getAAFor<AAKernelInfo>(
4463 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4464 getState() ^= CBAA.getState();
4465 AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4466 AllParallelRegionStatesWereFixed &=
4467 CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4468 AllParallelRegionStatesWereFixed &=
4469 CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4470 return true;
4471 };
4472
4473 bool UsedAssumedInformationInCheckCallInst = false;
4474 if (!A.checkForAllCallLikeInstructions(
4475 CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4476 LLVM_DEBUG(dbgs() << TAG
4477 << "Failed to visit all call-like instructions!\n";);
4478 return indicatePessimisticFixpoint();
4479 }
4480
4481 // If we haven't used any assumed information for the reached parallel
4482 // region states we can fix it.
4483 if (!UsedAssumedInformationInCheckCallInst &&
4484 AllParallelRegionStatesWereFixed) {
4485 ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4486 ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4487 }
4488
4489 // If we haven't used any assumed information for the SPMD state we can fix
4490 // it.
4491 if (!UsedAssumedInformationInCheckRWInst &&
4492 !UsedAssumedInformationInCheckCallInst &&
4493 !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4494 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4495
4496 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4497 : ChangeStatus::CHANGED;
4498 }
4499
4500private:
4501 /// Update info regarding reaching kernels.
4502 void updateReachingKernelEntries(Attributor &A,
4503 bool &AllReachingKernelsKnown) {
4504 auto PredCallSite = [&](AbstractCallSite ACS) {
4505 Function *Caller = ACS.getInstruction()->getFunction();
4506
4507 assert(Caller && "Caller is nullptr");
4508
4509 auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4510 IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4511 if (CAA.ReachingKernelEntries.isValidState()) {
4512 ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4513 return true;
4514 }
4515
4516 // We lost track of the caller of the associated function, any kernel
4517 // could reach now.
4518 ReachingKernelEntries.indicatePessimisticFixpoint();
4519
4520 return true;
4521 };
4522
4523 if (!A.checkForAllCallSites(PredCallSite, *this,
4524 true /* RequireAllCallSites */,
4525 AllReachingKernelsKnown))
4526 ReachingKernelEntries.indicatePessimisticFixpoint();
4527 }
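  // In effect, the ReachingKernelEntries of every caller are merged in; if any
  // call site cannot be tracked, the set collapses to the pessimistic "any
  // kernel may reach this function" state.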
4528
4529 /// Update info regarding parallel levels.
4530 void updateParallelLevels(Attributor &A) {
4531 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4532 OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4533 OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4534
4535 auto PredCallSite = [&](AbstractCallSite ACS) {
4536 Function *Caller = ACS.getInstruction()->getFunction();
4537
4538 assert(Caller && "Caller is nullptr");
4539
4540 auto &CAA =
4541 A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4542 if (CAA.ParallelLevels.isValidState()) {
4543 // Any function that is called by `__kmpc_parallel_51` will not be
4544 // folded as the parallel level in the function is updated. In order to
4545        // get it right, all the analysis would depend on the implementation. That
4546        // said, if the implementation changes in the future, the analysis
4547 // could be wrong. As a consequence, we are just conservative here.
4548 if (Caller == Parallel51RFI.Declaration) {
4549 ParallelLevels.indicatePessimisticFixpoint();
4550 return true;
4551 }
4552
4553 ParallelLevels ^= CAA.ParallelLevels;
4554
4555 return true;
4556 }
4557
4558 // We lost track of the caller of the associated function, any kernel
4559 // could reach now.
4560 ParallelLevels.indicatePessimisticFixpoint();
4561
4562 return true;
4563 };
4564
4565 bool AllCallSitesKnown = true;
4566 if (!A.checkForAllCallSites(PredCallSite, *this,
4567 true /* RequireAllCallSites */,
4568 AllCallSitesKnown))
4569 ParallelLevels.indicatePessimisticFixpoint();
4570 }
4571};
4572
4573/// The call site kernel info abstract attribute, basically, what can we say
4574/// about a call site with regards to the KernelInfoState. For now this simply
4575/// forwards the information from the callee.
4576struct AAKernelInfoCallSite : AAKernelInfo {
4577 AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4578 : AAKernelInfo(IRP, A) {}
4579
4580 /// See AbstractAttribute::initialize(...).
4581 void initialize(Attributor &A) override {
4582 AAKernelInfo::initialize(A);
4583
4584 CallBase &CB = cast<CallBase>(getAssociatedValue());
4585 Function *Callee = getAssociatedFunction();
4586
4587 auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4588 *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4589
4590 // Check for SPMD-mode assumptions.
4591 if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4592 SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4593 indicateOptimisticFixpoint();
4594 }
4595
4596 // First weed out calls we do not care about, that is readonly/readnone
4597 // calls, intrinsics, and "no_openmp" calls. Neither of these can reach a
4598 // parallel region or anything else we are looking for.
4599 if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4600 indicateOptimisticFixpoint();
4601 return;
4602 }
4603
4604 // Next we check if we know the callee. If it is a known OpenMP function
4605 // we will handle them explicitly in the switch below. If it is not, we
4606 // will use an AAKernelInfo object on the callee to gather information and
4607 // merge that into the current state. The latter happens in the updateImpl.
4608 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4609 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4610 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4611 // Unknown caller or declarations are not analyzable, we give up.
4612 if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4613
4614 // Unknown callees might contain parallel regions, except if they have
4615 // an appropriate assumption attached.
4616 if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4617 AssumptionAA.hasAssumption("omp_no_parallelism")))
4618 ReachedUnknownParallelRegions.insert(&CB);
4619
4620 // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4621 // idea we can run something unknown in SPMD-mode.
4622 if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4623 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4624 SPMDCompatibilityTracker.insert(&CB);
4625 }
4626
4627 // We have updated the state for this unknown call properly, there won't
4628 // be any change so we indicate a fixpoint.
4629 indicateOptimisticFixpoint();
4630 }
4631 // If the callee is known and can be used in IPO, we will update the state
4632 // based on the callee state in updateImpl.
4633 return;
4634 }
4635
4636 const unsigned int WrapperFunctionArgNo = 6;
4637 RuntimeFunction RF = It->getSecond();
4638 switch (RF) {
4639 // All the functions we know are compatible with SPMD mode.
4640 case OMPRTL___kmpc_is_spmd_exec_mode:
4641 case OMPRTL___kmpc_distribute_static_fini:
4642 case OMPRTL___kmpc_for_static_fini:
4643 case OMPRTL___kmpc_global_thread_num:
4644 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4645 case OMPRTL___kmpc_get_hardware_num_blocks:
4646 case OMPRTL___kmpc_single:
4647 case OMPRTL___kmpc_end_single:
4648 case OMPRTL___kmpc_master:
4649 case OMPRTL___kmpc_end_master:
4650 case OMPRTL___kmpc_barrier:
4651 case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4652 case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4653 case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4654 break;
4655 case OMPRTL___kmpc_distribute_static_init_4:
4656 case OMPRTL___kmpc_distribute_static_init_4u:
4657 case OMPRTL___kmpc_distribute_static_init_8:
4658 case OMPRTL___kmpc_distribute_static_init_8u:
4659 case OMPRTL___kmpc_for_static_init_4:
4660 case OMPRTL___kmpc_for_static_init_4u:
4661 case OMPRTL___kmpc_for_static_init_8:
4662 case OMPRTL___kmpc_for_static_init_8u: {
4663 // Check the schedule and allow static schedule in SPMD mode.
4664 unsigned ScheduleArgOpNo = 2;
4665 auto *ScheduleTypeCI =
4666 dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4667 unsigned ScheduleTypeVal =
4668 ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4669 switch (OMPScheduleType(ScheduleTypeVal)) {
4670 case OMPScheduleType::UnorderedStatic:
4671 case OMPScheduleType::UnorderedStaticChunked:
4672 case OMPScheduleType::OrderedDistribute:
4673 case OMPScheduleType::OrderedDistributeChunked:
4674 break;
4675 default:
4676 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4677 SPMDCompatibilityTracker.insert(&CB);
4678 break;
4679 };
4680 } break;
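    // E.g., schedule(static) typically lowers to one of the unordered-static
    // schedule kinds above and stays SPMD-compatible; any other schedule kind
    // reaching these calls hits the default case and is treated as
    // SPMD-incompatible (illustrative mapping).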
4681 case OMPRTL___kmpc_target_init:
4682 KernelInitCB = &CB;
4683 break;
4684 case OMPRTL___kmpc_target_deinit:
4685 KernelDeinitCB = &CB;
4686 break;
4687 case OMPRTL___kmpc_parallel_51:
4688 if (auto *ParallelRegion = dyn_cast<Function>(
4689 CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4690 ReachedKnownParallelRegions.insert(ParallelRegion);
4691 /// Check nested parallelism
4692 auto &FnAA = A.getAAFor<AAKernelInfo>(
4693 *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
4694 NestedParallelism |= !FnAA.getState().isValidState() ||
4695 !FnAA.ReachedKnownParallelRegions.empty() ||
4696 !FnAA.ReachedUnknownParallelRegions.empty();
4697 break;
4698 }
4699 // The condition above should usually get the parallel region function
4700 // pointer and record it. In the off chance it doesn't we assume the
4701 // worst.
4702 ReachedUnknownParallelRegions.insert(&CB);
4703 break;
4704 case OMPRTL___kmpc_omp_task:
4705 // We do not look into tasks right now, just give up.
4706 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4707 SPMDCompatibilityTracker.insert(&CB);
4708 ReachedUnknownParallelRegions.insert(&CB);
4709 break;
4710 case OMPRTL___kmpc_alloc_shared:
4711 case OMPRTL___kmpc_free_shared:
4712 // Return without setting a fixpoint, to be resolved in updateImpl.
4713 return;
4714 default:
4715 // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4716 // generally. However, they do not hide parallel regions.
4717 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4718 SPMDCompatibilityTracker.insert(&CB);
4719 break;
4720 }
4721 // All other OpenMP runtime calls will not reach parallel regions so they
4722 // can be safely ignored for now. Since it is a known OpenMP runtime call we
4723 // have now modeled all effects and there is no need for any update.
4724 indicateOptimisticFixpoint();
4725 }
4726
4727 ChangeStatus updateImpl(Attributor &A) override {
4728 // TODO: Once we have call site specific value information we can provide
4729 // call site specific liveness information and then it makes
4730 // sense to specialize attributes for call sites arguments instead of
4731 // redirecting requests to the callee argument.
4732 Function *F = getAssociatedFunction();
4733
4734 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4735 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4736
4737 // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4738 if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4739 const IRPosition &FnPos = IRPosition::function(*F);
4740 auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4741 if (getState() == FnAA.getState())
4742 return ChangeStatus::UNCHANGED;
4743 getState() = FnAA.getState();
4744 return ChangeStatus::CHANGED;
4745 }
4746
4747 // F is a runtime function that allocates or frees memory, check
4748 // AAHeapToStack and AAHeapToShared.
4749 KernelInfoState StateBefore = getState();
4750 assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4751 It->getSecond() == OMPRTL___kmpc_free_shared) &&
4752 "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4753
4754 CallBase &CB = cast<CallBase>(getAssociatedValue());
4755
4756 auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4757 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4758 auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4759 *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4760
4761 RuntimeFunction RF = It->getSecond();
4762
4763 switch (RF) {
4764 // If neither HeapToStack nor HeapToShared assume the call is removed,
4765 // assume SPMD incompatibility.
4766 case OMPRTL___kmpc_alloc_shared:
4767 if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4768 !HeapToSharedAA.isAssumedHeapToShared(CB))
4769 SPMDCompatibilityTracker.insert(&CB);
4770 break;
4771 case OMPRTL___kmpc_free_shared:
4772 if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4773 !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4774 SPMDCompatibilityTracker.insert(&CB);
4775 break;
4776 default:
4777 SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4778 SPMDCompatibilityTracker.insert(&CB);
4779 }
4780
4781 return StateBefore == getState() ? ChangeStatus::UNCHANGED
4782 : ChangeStatus::CHANGED;
4783 }
4784};
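// For example, a team-local allocation pair such as
//
//   void *Buf = __kmpc_alloc_shared(256);
//   ...
//   __kmpc_free_shared(Buf, 256);
//
// only blocks SPMD-ization if neither AAHeapToStack nor AAHeapToShared can
// show that the pair will be rewritten (to a stack allocation or a static
// shared-memory buffer, respectively); otherwise the calls are expected to
// disappear before code generation.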
4785
4786struct AAFoldRuntimeCall
4787 : public StateWrapper<BooleanState, AbstractAttribute> {
4788 using Base = StateWrapper<BooleanState, AbstractAttribute>;
4789
4790 AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4791
4792 /// Statistics are tracked as part of manifest for now.
4793 void trackStatistics() const override {}
4794
4795 /// Create an abstract attribute view for the position \p IRP.
4796 static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4797 Attributor &A);
4798
4799 /// See AbstractAttribute::getName()
4800 const std::string getName() const override { return "AAFoldRuntimeCall"; }
4801
4802 /// See AbstractAttribute::getIdAddr()
4803 const char *getIdAddr() const override { return &ID; }
4804
4805 /// This function should return true if the type of the \p AA is
4806 /// AAFoldRuntimeCall
4807 static bool classof(const AbstractAttribute *AA) {
4808 return (AA->getIdAddr() == &ID);
4809 }
4810
4811 static const char ID;
4812};
4813
4814struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4815 AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4816 : AAFoldRuntimeCall(IRP, A) {}
4817
4818 /// See AbstractAttribute::getAsStr()
4819 const std::string getAsStr() const override {
4820 if (!isValidState())
4821 return "<invalid>";
4822
4823 std::string Str("simplified value: ");
4824
4825 if (!SimplifiedValue)
4826 return Str + std::string("none");
4827
4828 if (!*SimplifiedValue)
4829 return Str + std::string("nullptr");
4830
4831 if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
4832 return Str + std::to_string(CI->getSExtValue());
4833
4834 return Str + std::string("unknown");
4835 }
4836
4837 void initialize(Attributor &A) override {
4838 if (DisableOpenMPOptFolding)
4839 indicatePessimisticFixpoint();
4840
4841 Function *Callee = getAssociatedFunction();
4842
4843 auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4844 const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4845 assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4846 "Expected a known OpenMP runtime function");
4847
4848 RFKind = It->getSecond();
4849
4850 CallBase &CB = cast<CallBase>(getAssociatedValue());
4851 A.registerSimplificationCallback(
4852 IRPosition::callsite_returned(CB),
4853 [&](const IRPosition &IRP, const AbstractAttribute *AA,
4854 bool &UsedAssumedInformation) -> std::optional<Value *> {
4855 assert((isValidState() ||
4856 (SimplifiedValue && *SimplifiedValue == nullptr)) &&
4857 "Unexpected invalid state!");
4858
4859 if (!isAtFixpoint()) {
4860 UsedAssumedInformation = true;
4861 if (AA)
4862 A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4863 }
4864 return SimplifiedValue;
4865 });
4866 }
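// Note that the simplification callback registered above returns the
// currently assumed SimplifiedValue to any attribute querying this call
// site; while the state is not yet at a fixpoint it also sets
// UsedAssumedInformation and records an OPTIONAL dependence so the querying
// attribute is re-evaluated if the folded value changes later.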
4867
4868 ChangeStatus updateImpl(Attributor &A) override {
4869 ChangeStatus Changed = ChangeStatus::UNCHANGED;
4870 switch (RFKind) {
4871 case OMPRTL___kmpc_is_spmd_exec_mode:
4872 Changed |= foldIsSPMDExecMode(A);
4873 break;
4874 case OMPRTL___kmpc_parallel_level:
4875 Changed |= foldParallelLevel(A);
4876 break;
4877 case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4878 Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4879 break;
4880 case OMPRTL___kmpc_get_hardware_num_blocks:
4881 Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4882 break;
4883 default:
4884 llvm_unreachable("Unhandled OpenMP runtime function!");
4885 }
4886
4887 return Changed;
4888 }
4889
4890 ChangeStatus manifest(Attributor &A) override {
4891 ChangeStatus Changed = ChangeStatus::UNCHANGED;
4892
4893 if (SimplifiedValue && *SimplifiedValue) {
4894 Instruction &I = *getCtxI();
4895 A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
4896 A.deleteAfterManifest(I);
4897
4898 CallBase *CB = dyn_cast<CallBase>(&I);
4899 auto Remark = [&](OptimizationRemark OR) {
4900 if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4901 return OR << "Replacing OpenMP runtime call "
4902 << CB->getCalledFunction()->getName() << " with "
4903 << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4904 return OR << "Replacing OpenMP runtime call "
4905 << CB->getCalledFunction()->getName() << ".";
4906 };
4907
4908 if (CB && EnableVerboseRemarks)
4909 A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4910
4911 LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4912 << **SimplifiedValue << "\n");
4913
4914 Changed = ChangeStatus::CHANGED;
4915 }
4916
4917 return Changed;
4918 }
4919
4920 ChangeStatus indicatePessimisticFixpoint() override {
4921 SimplifiedValue = nullptr;
4922 return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4923 }
4924
4925private:
4926 /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4927 ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4928 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4929
4930 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4931 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4932 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4933 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4934
4935 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4936 return indicatePessimisticFixpoint();
4937
4938 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4939 auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4940 DepClassTy::REQUIRED);
4941
4942 if (!AA.isValidState()) {
4943 SimplifiedValue = nullptr;
4944 return indicatePessimisticFixpoint();
4945 }
4946
4947 if (AA.SPMDCompatibilityTracker.isAssumed()) {
4948 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4949 ++KnownSPMDCount;
4950 else
4951 ++AssumedSPMDCount;
4952 } else {
4953 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4954 ++KnownNonSPMDCount;
4955 else
4956 ++AssumedNonSPMDCount;
4957 }
4958 }
4959
4960 if ((AssumedSPMDCount + KnownSPMDCount) &&
4961 (AssumedNonSPMDCount + KnownNonSPMDCount))
4962 return indicatePessimisticFixpoint();
4963
4964 auto &Ctx = getAnchorValue().getContext();
4965 if (KnownSPMDCount || AssumedSPMDCount) {
4966 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4967 "Expected only SPMD kernels!");
4968 // All reaching kernels are in SPMD mode. Update all function calls to
4969 // __kmpc_is_spmd_exec_mode to 1.
4970 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4971 } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4972 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4973 "Expected only non-SPMD kernels!");
4974 // All reaching kernels are in non-SPMD mode. Update all function
4975 // calls to __kmpc_is_spmd_exec_mode to 0.
4976 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4977 } else {
4978 // We have empty reaching kernels, therefore we cannot tell if the
4979 // associated call site can be folded. At this moment, SimplifiedValue
4980 // must be none.
4981 assert(!SimplifiedValue && "SimplifiedValue should be none");
4982 }
4983
4984 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4985 : ChangeStatus::CHANGED;
4986 }
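// For example, if every kernel reaching the caller is (known or assumed to
// be) executed in SPMD mode, a guard such as
//
//   if (__kmpc_is_spmd_exec_mode()) { /* SPMD path */ } else { /* generic */ }
//
// folds to a constant-true condition and the generic-mode branch becomes
// trivially dead; with a mix of SPMD and non-SPMD kernels the call is left
// in place.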
4987
4988 /// Fold __kmpc_parallel_level into a constant if possible.
4989 ChangeStatus foldParallelLevel(Attributor &A) {
4990 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4991
4992 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4993 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4994
4995 if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4996 return indicatePessimisticFixpoint();
4997
4998 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4999 return indicatePessimisticFixpoint();
5000
5001 if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
5002 assert(!SimplifiedValue &&
5003 "SimplifiedValue should keep none at this point");
5004 return ChangeStatus::UNCHANGED;
5005 }
5006
5007 unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
5008 unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
5009 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
5010 auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
5011 DepClassTy::REQUIRED);
5012 if (!AA.SPMDCompatibilityTracker.isValidState())
5013 return indicatePessimisticFixpoint();
5014
5015 if (AA.SPMDCompatibilityTracker.isAssumed()) {
5016 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
5017 ++KnownSPMDCount;
5018 else
5019 ++AssumedSPMDCount;
5020 } else {
5021 if (AA.SPMDCompatibilityTracker.isAtFixpoint())
5022 ++KnownNonSPMDCount;
5023 else
5024 ++AssumedNonSPMDCount;
5025 }
5026 }
5027
5028 if ((AssumedSPMDCount + KnownSPMDCount) &&
5029 (AssumedNonSPMDCount + KnownNonSPMDCount))
5030 return indicatePessimisticFixpoint();
5031
5032 auto &Ctx = getAnchorValue().getContext();
5033 // If the caller can only be reached by SPMD kernel entries, the parallel
5034 // level is 1. Similarly, if the caller can only be reached by non-SPMD
5035 // kernel entries, it is 0.
5036 if (AssumedSPMDCount || KnownSPMDCount) {
5037 assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
5038 "Expected only SPMD kernels!");
5039 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
5040 } else {
5041 assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
5042 "Expected only non-SPMD kernels!");
5043 SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
5044 }
5045 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5046 : ChangeStatus::CHANGED;
5047 }
5048
5049 ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
5050 // Specialize only if all the calls agree with the attribute constant value
5051 int32_t CurrentAttrValue = -1;
5052 std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
5053
5054 auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
5055 *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
5056
5057 if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
5058 return indicatePessimisticFixpoint();
5059
5060 // Iterate over the kernels that reach this function
5061 for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
5062 int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
5063
5064 if (NextAttrVal == -1 ||
5065 (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
5066 return indicatePessimisticFixpoint();
5067 CurrentAttrValue = NextAttrVal;
5068 }
5069
5070 if (CurrentAttrValue != -1) {
5071 auto &Ctx = getAnchorValue().getContext();
5072 SimplifiedValue =
5073 ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
5074 }
5075 return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
5076 : ChangeStatus::CHANGED;
5077 }
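// For example, a target region compiled with an explicit thread_limit(128)
// clause usually carries the "omp_target_thread_limit"="128" function
// attribute on its kernel; if all kernels reaching this call site agree on
// that value, __kmpc_get_hardware_num_threads_in_block() folds to the i32
// constant 128. A missing attribute or a disagreement between kernels hits
// the pessimistic fixpoint above instead.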
5078
5079 /// An optional value the associated value is assumed to fold to. That is, we
5080 /// assume the associated value (which is a call) can be replaced by this
5081 /// simplified value.
5082 std::optional<Value *> SimplifiedValue;
5083
5084 /// The runtime function kind of the callee of the associated call site.
5085 RuntimeFunction RFKind;
5086};
5087
5088} // namespace
5089
5090/// Register folding callsite
5091void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
5092 auto &RFI = OMPInfoCache.RFIs[RF];
5093 RFI.foreachUse(SCC, [&](Use &U, Function &F) {
5094 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
5095 if (!CI)
5096 return false;
5097 A.getOrCreateAAFor<AAFoldRuntimeCall>(
5098 IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
5099 DepClassTy::NONE, /* ForceUpdate */ false,
5100 /* UpdateAfterInit */ false);
5101 return false;
5102 });
5103}
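// Note that the folding attribute is seeded at the callsite-returned
// position: if it reaches a fixpoint with a constant SimplifiedValue, the
// manifest step above replaces the call's uses with that constant and
// deletes the call.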
5104
5105void OpenMPOpt::registerAAs(bool IsModulePass) {
5106 if (SCC.empty())
5107 return;
5108
5109 if (IsModulePass) {
5110 // Ensure we create the AAKernelInfo AAs first and without triggering an
5111 // update. This will make sure we register all value simplification
5112 // callbacks before any other AA has the chance to create an AAValueSimplify
5113 // or similar.
5114 auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
5115 A.getOrCreateAAFor<AAKernelInfo>(
5116 IRPosition::function(Kernel), /* QueryingAA */ nullptr,
5117 DepClassTy::NONE, /* ForceUpdate */ false,
5118 /* UpdateAfterInit */ false);
5119 return false;
5120 };
5121 OMPInformationCache::RuntimeFunctionInfo &InitRFI =
5122 OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
5123 InitRFI.foreachUse(SCC, CreateKernelInfoCB);
5124
5125 registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
5126 registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
5127 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
5128 registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
5129 }
5130
5131 // Create CallSite AA for all Getters.
5132 if (DeduceICVValues) {
5133 for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
5134 auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
5135
5136 auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
5137
5138 auto CreateAA = [&](Use &U, Function &Caller) {
5139 CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
5140 if (!CI)
5141 return false;
5142
5143 auto &CB = cast<CallBase>(*CI);
5144
5146 A.getOrCreateAAFor<AAICVTracker>(CBPos);
5147 return false;
5148 };
5149
5150 GetterRFI.foreachUse(SCC, CreateAA);
5151 }
5152 }
5153
5154 // Create an ExecutionDomain AA for every function and a HeapToStack AA for
5155 // every function if there is a device kernel.
5156 if (!isOpenMPDevice(M))
5157 return;
5158
5159 for (auto *F : SCC) {
5160 if (F->isDeclaration())
5161 continue;
5162
5163 // We look at internal functions only on-demand, but if any use is not a
5164 // direct call or is outside the current set of analyzed functions, we have
5165 // to do it eagerly.
5166 if (F->hasLocalLinkage()) {
5167 if (llvm::all_of(F->uses(), [this](const Use &U) {
5168 const auto *CB = dyn_cast<CallBase>(U.getUser());
5169 return CB && CB->isCallee(&U) &&
5170 A.isRunOn(const_cast<Function *>(CB->getCaller()));
5171 }))
5172 continue;
5173 }
5174 registerAAsForFunction(A, *F);
5175 }
5176}
5177
5178void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
5179 if (!DisableOpenMPOptDeglobalization)
5180 A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
5181 A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
5182 if (!DisableOpenMPOptDeglobalization)
5183 A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
5184
5185 for (auto &I : instructions(F)) {
5186 if (auto *LI = dyn_cast<LoadInst>(&I)) {
5187 bool UsedAssumedInformation = false;
5188 A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
5189 UsedAssumedInformation, AA::Interprocedural);
5190 continue;
5191 }
5192 if (auto *SI = dyn_cast<StoreInst>(&I)) {
5193 A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
5194 continue;
5195 }
5196 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
5197 if (II->getIntrinsicID() == Intrinsic::assume) {
5198 A.getOrCreateAAFor<AAPotentialValues>(
5199 IRPosition::value(*II->getArgOperand(0)));
5200 continue;
5201 }
5202 }
5203 }
5204}
5205
5206const char AAICVTracker::ID = 0;
5207const char AAKernelInfo::ID = 0;
5208const char AAExecutionDomain::ID = 0;
5209const char AAHeapToShared::ID = 0;
5210const char AAFoldRuntimeCall::ID = 0;
5211
5212AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
5213 Attributor &A) {
5214 AAICVTracker *AA = nullptr;
5215 switch (IRP.getPositionKind()) {
5216 case IRPosition::IRP_INVALID:
5217 case IRPosition::IRP_FLOAT:
5218 case IRPosition::IRP_ARGUMENT:
5219 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5220 llvm_unreachable("ICVTracker can only be created for function position!");
5221 case IRPosition::IRP_RETURNED:
5222 AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
5223 break;
5224 case IRPosition::IRP_CALL_SITE_RETURNED:
5225 AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
5226 break;
5227 case IRPosition::IRP_CALL_SITE:
5228 AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
5229 break;
5230 case IRPosition::IRP_FUNCTION:
5231 AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
5232 break;
5233 }
5234
5235 return *AA;
5236}
5237
5238 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
5239 Attributor &A) {
5240 AAExecutionDomainFunction *AA = nullptr;
5241 switch (IRP.getPositionKind()) {
5250 "AAExecutionDomain can only be created for function position!");
5251 case IRPosition::IRP_FUNCTION:
5252 AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
5253 break;
5254 }
5255
5256 return *AA;
5257}
5258
5259AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
5260 Attributor &A) {
5261 AAHeapToSharedFunction *AA = nullptr;
5262 switch (IRP.getPositionKind()) {
5271 "AAHeapToShared can only be created for function position!");
5272 case IRPosition::IRP_FUNCTION:
5273 AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
5274 break;
5275 }
5276
5277 return *AA;
5278}
5279
5280AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
5281 Attributor &A) {
5282 AAKernelInfo *AA = nullptr;
5283 switch (IRP.getPositionKind()) {
5284 case IRPosition::IRP_INVALID:
5285 case IRPosition::IRP_FLOAT:
5286 case IRPosition::IRP_ARGUMENT:
5287 case IRPosition::IRP_RETURNED:
5288 case IRPosition::IRP_CALL_SITE_RETURNED:
5289 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5290 llvm_unreachable("KernelInfo can only be created for function position!");
5291 case IRPosition::IRP_CALL_SITE:
5292 AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
5293 break;
5294 case IRPosition::IRP_FUNCTION:
5295 AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
5296 break;
5297 }
5298
5299 return *AA;
5300}
5301
5302AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
5303 Attributor &A) {
5304 AAFoldRuntimeCall *AA = nullptr;
5305 switch (IRP.getPositionKind()) {
5306 case IRPosition::IRP_INVALID:
5307 case IRPosition::IRP_FLOAT:
5308 case IRPosition::IRP_ARGUMENT:
5309 case IRPosition::IRP_RETURNED:
5310 case IRPosition::IRP_FUNCTION:
5311 case IRPosition::IRP_CALL_SITE:
5312 case IRPosition::IRP_CALL_SITE_ARGUMENT:
5313 llvm_unreachable("AAFoldRuntimeCall can only be created for call site position!");
5314 case IRPosition::IRP_CALL_SITE_RETURNED:
5315 AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
5316 break;
5317 }
5318
5319 return *AA;
5320}
5321
5322 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
5323 if (!containsOpenMP(M))
5324 return PreservedAnalyses::all();
5325 if (DisableOpenMPOptimizations)
5326 return PreservedAnalyses::all();
5327
5328 FunctionAnalysisManager &FAM =
5329 AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
5330 KernelSet Kernels = getDeviceKernels(M);
5331
5332 if (PrintModuleBeforeOptimizations)
5333 LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
5334
5335 auto IsCalled = [&](Function &F) {
5336 if (Kernels.contains(&F))
5337 return true;
5338 for (const User *U : F.users())
5339 if (!isa<BlockAddress>(U))
5340 return true;
5341 return false;
5342 };
5343
5344 auto EmitRemark = [&](Function &F) {
5345 auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
5346 ORE.emit([&]() {
5347 OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
5348 return ORA << "Could not internalize function. "
5349 << "Some optimizations may not be possible. [OMP140]";
5350 });
5351 };
5352
5353 // Create internal copies of each function if this is a kernel Module. This
5354 // allows interprocedural passes to see every call edge.
5355 DenseMap<Function *, Function *> InternalizedMap;
5356 if (isOpenMPDevice(M)) {
5357 SmallPtrSet<Function *, 16> InternalizeFns;
5358 for (Function &F : M)
5359 if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
5360 !DisableInternalization) {
5361 if (Attributor::isInternalizable(F)) {
5362 InternalizeFns.insert(&F);
5363 } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
5364 EmitRemark(F);
5365 }
5366 }
5367
5368 Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
5369 }
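// Note that internalization leaves the original definitions in place for
// external callers and maps them to internal copies; the loop below then
// skips the internalized originals so the Attributor analyzes the copies
// instead.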
5370
5371 // Look at every function in the Module unless it was internalized.
5372 SetVector<Function *> Functions;
5373 SmallVector<Function *, 16> SCC;
5374 for (Function &F : M)
5375 if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
5376 SCC.push_back(&F);
5377 Functions.insert(&F);
5378 }
5379
5380 if (SCC.empty())
5381 return PreservedAnalyses::all();
5382
5383 AnalysisGetter AG(FAM);
5384
5385 auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5386 return FAM.getResult<