OpenMPOpt.cpp
1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/ADT/Statistic.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Analysis/CallGraph.h"
28 #include "llvm/Analysis/CallGraphSCCPass.h"
29 #include "llvm/Analysis/MemoryLocation.h"
30 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
31 #include "llvm/Analysis/ValueTracking.h"
32 #include "llvm/Frontend/OpenMP/OMPConstants.h"
33 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
34 #include "llvm/IR/Assumptions.h"
35 #include "llvm/IR/Constants.h"
36 #include "llvm/IR/DiagnosticInfo.h"
37 #include "llvm/IR/GlobalValue.h"
38 #include "llvm/IR/GlobalVariable.h"
39 #include "llvm/IR/Instruction.h"
40 #include "llvm/IR/Instructions.h"
41 #include "llvm/IR/IntrinsicInst.h"
42 #include "llvm/IR/IntrinsicsAMDGPU.h"
43 #include "llvm/IR/IntrinsicsNVPTX.h"
44 #include "llvm/IR/LLVMContext.h"
45 #include "llvm/InitializePasses.h"
46 #include "llvm/Support/CommandLine.h"
47 #include "llvm/Support/Debug.h"
48 #include "llvm/Transforms/IPO.h"
52 
53 #include <algorithm>
54 
55 using namespace llvm;
56 using namespace omp;
57 
58 #define DEBUG_TYPE "openmp-opt"
59 
60 static cl::opt<bool> DisableOpenMPOptimizations(
61  "openmp-opt-disable", cl::ZeroOrMore,
62  cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
63  cl::init(false));
64 
65 static cl::opt<bool> EnableParallelRegionMerging(
66  "openmp-opt-enable-merging", cl::ZeroOrMore,
67  cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
68  cl::init(false));
69 
70 static cl::opt<bool>
71  DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
72  cl::desc("Disable function internalization."),
73  cl::Hidden, cl::init(false));
74 
75 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
76  cl::Hidden);
77 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
78  cl::init(false), cl::Hidden);
79 
80 static cl::opt<bool> HideMemoryTransferLatency(
81  "openmp-hide-memory-transfer-latency",
82  cl::desc("[WIP] Tries to hide the latency of host to device memory"
83  " transfers"),
84  cl::Hidden, cl::init(false));
85 
86 static cl::opt<bool> DisableOpenMPOptDeglobalization(
87  "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
88  cl::desc("Disable OpenMP optimizations involving deglobalization."),
89  cl::Hidden, cl::init(false));
90 
91 static cl::opt<bool> DisableOpenMPOptSPMDization(
92  "openmp-opt-disable-spmdization", cl::ZeroOrMore,
93  cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
94  cl::Hidden, cl::init(false));
95 
96 static cl::opt<bool> DisableOpenMPOptFolding(
97  "openmp-opt-disable-folding", cl::ZeroOrMore,
98  cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
99  cl::init(false));
100 
101 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
102  "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
103  cl::desc("Disable OpenMP optimizations that replace the state machine."),
104  cl::Hidden, cl::init(false));
105 
106 static cl::opt<bool> DisableOpenMPOptBarrierElimination(
107  "openmp-opt-disable-barrier-elimination", cl::ZeroOrMore,
108  cl::desc("Disable OpenMP optimizations that eliminate barriers."),
109  cl::Hidden, cl::init(false));
110 
111 static cl::opt<bool> PrintModuleAfterOptimizations(
112  "openmp-opt-print-module-after", cl::ZeroOrMore,
113  cl::desc("Print the current module after OpenMP optimizations."),
114  cl::Hidden, cl::init(false));
115 
116 static cl::opt<bool> PrintModuleBeforeOptimizations(
117  "openmp-opt-print-module-before", cl::ZeroOrMore,
118  cl::desc("Print the current module before OpenMP optimizations."),
119  cl::Hidden, cl::init(false));
120 
121 static cl::opt<bool> AlwaysInlineDeviceFunctions(
122  "openmp-opt-inline-device", cl::ZeroOrMore,
123  cl::desc("Inline all applicable functions on the device."), cl::Hidden,
124  cl::init(false));
125 
126 static cl::opt<bool>
127  EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
128  cl::desc("Enables more verbose remarks."), cl::Hidden,
129  cl::init(false));
130 
131 static cl::opt<unsigned>
132  SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
133  cl::desc("Maximal number of attributor iterations."),
134  cl::init(256));
135 
136 static cl::opt<unsigned>
137  SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
138  cl::desc("Maximum amount of shared memory to use."),
139  cl::init(std::numeric_limits<unsigned>::max()));
140 
141 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
142  "Number of OpenMP runtime calls deduplicated");
143 STATISTIC(NumOpenMPParallelRegionsDeleted,
144  "Number of OpenMP parallel regions deleted");
145 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
146  "Number of OpenMP runtime functions identified");
147 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
148  "Number of OpenMP runtime function uses identified");
149 STATISTIC(NumOpenMPTargetRegionKernels,
150  "Number of OpenMP target region entry points (=kernels) identified");
151 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
152  "Number of OpenMP target region entry points (=kernels) executed in "
153  "SPMD-mode instead of generic-mode");
154 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
155  "Number of OpenMP target region entry points (=kernels) executed in "
156  "generic-mode without a state machine");
157 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
158  "Number of OpenMP target region entry points (=kernels) executed in "
159  "generic-mode with customized state machines with fallback");
160 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
161  "Number of OpenMP target region entry points (=kernels) executed in "
162  "generic-mode with customized state machines without fallback");
163 STATISTIC(
164  NumOpenMPParallelRegionsReplacedInGPUStateMachine,
165  "Number of OpenMP parallel regions replaced with ID in GPU state machines");
166 STATISTIC(NumOpenMPParallelRegionsMerged,
167  "Number of OpenMP parallel regions merged");
168 STATISTIC(NumBytesMovedToSharedMemory,
169  "Amount of memory pushed to shared memory");
170 STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
171 
172 #if !defined(NDEBUG)
173 static constexpr auto TAG = "[" DEBUG_TYPE "]";
174 #endif
175 
176 namespace {
177 
178 struct AAHeapToShared;
179 
180 struct AAICVTracker;
181 
182 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
183 /// Attributor runs.
184 struct OMPInformationCache : public InformationCache {
185  OMPInformationCache(Module &M, AnalysisGetter &AG,
186  BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
187  KernelSet &Kernels)
188  : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
189  Kernels(Kernels) {
190 
191  OMPBuilder.initialize();
192  initializeRuntimeFunctions();
193  initializeInternalControlVars();
194  }
195 
196  /// Generic information that describes an internal control variable.
197  struct InternalControlVarInfo {
198  /// The kind, as described by InternalControlVar enum.
199  InternalControlVar Kind;
200 
201  /// The name of the ICV.
202  StringRef Name;
203 
204  /// Environment variable associated with this ICV.
205  StringRef EnvVarName;
206 
207  /// Initial value kind.
208  ICVInitValue InitKind;
209 
210  /// Initial value.
211  ConstantInt *InitValue;
212 
213  /// Setter RTL function associated with this ICV.
214  RuntimeFunction Setter;
215 
216  /// Getter RTL function associated with this ICV.
217  RuntimeFunction Getter;
218 
219  /// RTL Function corresponding to the override clause of this ICV
220  RuntimeFunction Clause;
221  };
222 
223  /// Generic information that describes a runtime function
224  struct RuntimeFunctionInfo {
225 
226  /// The kind, as described by the RuntimeFunction enum.
227  RuntimeFunction Kind;
228 
229  /// The name of the function.
230  StringRef Name;
231 
232  /// Flag to indicate a variadic function.
233  bool IsVarArg;
234 
235  /// The return type of the function.
236  Type *ReturnType;
237 
238  /// The argument types of the function.
239  SmallVector<Type *, 8> ArgumentTypes;
240 
241  /// The declaration if available.
242  Function *Declaration = nullptr;
243 
244  /// Uses of this runtime function per function containing the use.
245  using UseVector = SmallVector<Use *, 16>;
246 
247  /// Clear UsesMap for runtime function.
248  void clearUsesMap() { UsesMap.clear(); }
249 
250  /// Boolean conversion that is true if the runtime function was found.
251  operator bool() const { return Declaration; }
252 
253  /// Return the vector of uses in function \p F.
254  UseVector &getOrCreateUseVector(Function *F) {
255  std::shared_ptr<UseVector> &UV = UsesMap[F];
256  if (!UV)
257  UV = std::make_shared<UseVector>();
258  return *UV;
259  }
260 
261  /// Return the vector of uses in function \p F or `nullptr` if there are
262  /// none.
263  const UseVector *getUseVector(Function &F) const {
264  auto I = UsesMap.find(&F);
265  if (I != UsesMap.end())
266  return I->second.get();
267  return nullptr;
268  }
269 
270  /// Return how many functions contain uses of this runtime function.
271  size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
272 
273  /// Return the number of arguments (or the minimal number for variadic
274  /// functions).
275  size_t getNumArgs() const { return ArgumentTypes.size(); }
276 
277  /// Run the callback \p CB on each use and forget the use if the result is
278  /// true. The callback will be fed the function in which the use was
279  /// encountered as second argument.
280  void foreachUse(SmallVectorImpl<Function *> &SCC,
281  function_ref<bool(Use &, Function &)> CB) {
282  for (Function *F : SCC)
283  foreachUse(CB, F);
284  }
285 
286  /// Run the callback \p CB on each use within the function \p F and forget
287  /// the use if the result is true.
288  void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
289  SmallVector<unsigned, 8> ToBeDeleted;
290  ToBeDeleted.clear();
291 
292  unsigned Idx = 0;
293  UseVector &UV = getOrCreateUseVector(F);
294 
295  for (Use *U : UV) {
296  if (CB(*U, *F))
297  ToBeDeleted.push_back(Idx);
298  ++Idx;
299  }
300 
301  // Remove the to-be-deleted indices in reverse order as prior
302  // modifications will not modify the smaller indices.
303  while (!ToBeDeleted.empty()) {
304  unsigned Idx = ToBeDeleted.pop_back_val();
305  UV[Idx] = UV.back();
306  UV.pop_back();
307  }
308  }
309 
310  private:
311  /// Map from functions to all uses of this runtime function contained in
312  /// them.
313  DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
314 
315  public:
316  /// Iterators for the uses of this runtime function.
317  decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
318  decltype(UsesMap)::iterator end() { return UsesMap.end(); }
319  };
320 
321  /// An OpenMP-IR-Builder instance
322  OpenMPIRBuilder OMPBuilder;
323 
324  /// Map from runtime function kind to the runtime function description.
325  EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
326  RuntimeFunction::OMPRTL___last>
327  RFIs;
328 
329  /// Map from function declarations/definitions to their runtime enum type.
330  DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
331 
332  /// Map from ICV kind to the ICV description.
333  EnumeratedArray<InternalControlVarInfo, InternalControlVar,
334  InternalControlVar::ICV___last>
335  ICVs;
336 
337  /// Helper to initialize all internal control variable information for those
338  /// defined in OMPKinds.def.
339  void initializeInternalControlVars() {
340 #define ICV_RT_SET(_Name, RTL) \
341  { \
342  auto &ICV = ICVs[_Name]; \
343  ICV.Setter = RTL; \
344  }
345 #define ICV_RT_GET(Name, RTL) \
346  { \
347  auto &ICV = ICVs[Name]; \
348  ICV.Getter = RTL; \
349  }
350 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \
351  { \
352  auto &ICV = ICVs[Enum]; \
353  ICV.Name = _Name; \
354  ICV.Kind = Enum; \
355  ICV.InitKind = Init; \
356  ICV.EnvVarName = _EnvVarName; \
357  switch (ICV.InitKind) { \
358  case ICV_IMPLEMENTATION_DEFINED: \
359  ICV.InitValue = nullptr; \
360  break; \
361  case ICV_ZERO: \
362  ICV.InitValue = ConstantInt::get( \
363  Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \
364  break; \
365  case ICV_FALSE: \
366  ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \
367  break; \
368  case ICV_LAST: \
369  break; \
370  } \
371  }
372 #include "llvm/Frontend/OpenMP/OMPKinds.def"
373  }
374 
375  /// Returns true if the function declaration \p F matches the runtime
376  /// function types, that is, return type \p RTFRetType, and argument types
377  /// \p RTFArgTypes.
378  static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
379  SmallVector<Type *, 8> &RTFArgTypes) {
380  // TODO: We should output information to the user (under debug output
381  // and via remarks).
382 
383  if (!F)
384  return false;
385  if (F->getReturnType() != RTFRetType)
386  return false;
387  if (F->arg_size() != RTFArgTypes.size())
388  return false;
389 
390  auto *RTFTyIt = RTFArgTypes.begin();
391  for (Argument &Arg : F->args()) {
392  if (Arg.getType() != *RTFTyIt)
393  return false;
394 
395  ++RTFTyIt;
396  }
397 
398  return true;
399  }
400 
401  // Helper to collect all uses of the declaration in the UsesMap.
402  unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
403  unsigned NumUses = 0;
404  if (!RFI.Declaration)
405  return NumUses;
406  OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
407 
408  if (CollectStats) {
409  NumOpenMPRuntimeFunctionsIdentified += 1;
410  NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
411  }
412 
413  // TODO: We directly convert uses into proper calls and unknown uses.
414  for (Use &U : RFI.Declaration->uses()) {
415  if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
416  if (ModuleSlice.count(UserI->getFunction())) {
417  RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
418  ++NumUses;
419  }
420  } else {
421  RFI.getOrCreateUseVector(nullptr).push_back(&U);
422  ++NumUses;
423  }
424  }
425  return NumUses;
426  }
427 
428  // Helper function to recollect uses of a runtime function.
429  void recollectUsesForFunction(RuntimeFunction RTF) {
430  auto &RFI = RFIs[RTF];
431  RFI.clearUsesMap();
432  collectUses(RFI, /*CollectStats*/ false);
433  }
434 
435  // Helper function to recollect uses of all runtime functions.
436  void recollectUses() {
437  for (int Idx = 0; Idx < RFIs.size(); ++Idx)
438  recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
439  }
440 
441  // Helper function to inherit the calling convention of the function callee.
442  void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
443  if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
444  CI->setCallingConv(Fn->getCallingConv());
445  }
446 
447  /// Helper to initialize all runtime function information for those defined
448  /// in OMPKinds.def.
449  void initializeRuntimeFunctions() {
450  Module &M = *((*ModuleSlice.begin())->getParent());
451 
452  // Helper macros for handling __VA_ARGS__ in OMP_RTL
453 #define OMP_TYPE(VarName, ...) \
454  Type *VarName = OMPBuilder.VarName; \
455  (void)VarName;
456 
457 #define OMP_ARRAY_TYPE(VarName, ...) \
458  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \
459  (void)VarName##Ty; \
460  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \
461  (void)VarName##PtrTy;
462 
463 #define OMP_FUNCTION_TYPE(VarName, ...) \
464  FunctionType *VarName = OMPBuilder.VarName; \
465  (void)VarName; \
466  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
467  (void)VarName##Ptr;
468 
469 #define OMP_STRUCT_TYPE(VarName, ...) \
470  StructType *VarName = OMPBuilder.VarName; \
471  (void)VarName; \
472  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \
473  (void)VarName##Ptr;
474 
475 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \
476  { \
477  SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \
478  Function *F = M.getFunction(_Name); \
479  RTLFunctions.insert(F); \
480  if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \
481  RuntimeFunctionIDMap[F] = _Enum; \
482  auto &RFI = RFIs[_Enum]; \
483  RFI.Kind = _Enum; \
484  RFI.Name = _Name; \
485  RFI.IsVarArg = _IsVarArg; \
486  RFI.ReturnType = OMPBuilder._ReturnType; \
487  RFI.ArgumentTypes = std::move(ArgsTypes); \
488  RFI.Declaration = F; \
489  unsigned NumUses = collectUses(RFI); \
490  (void)NumUses; \
491  LLVM_DEBUG({ \
492  dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not") \
493  << " found\n"; \
494  if (RFI.Declaration) \
495  dbgs() << TAG << "-> got " << NumUses << " uses in " \
496  << RFI.getNumFunctionsWithUses() \
497  << " different functions.\n"; \
498  }); \
499  } \
500  }
501 #include "llvm/Frontend/OpenMP/OMPKinds.def"
502 
503  // Remove the `noinline` attribute from `__kmpc`, `_OMP::` and `omp_`
504  // functions, except if `optnone` is present.
505  for (Function &F : M) {
506  for (StringRef Prefix : {"__kmpc", "_ZN4_OMP", "omp_"})
507  if (F.getName().startswith(Prefix) &&
508  !F.hasFnAttribute(Attribute::OptimizeNone))
509  F.removeFnAttr(Attribute::NoInline);
510  }
511 
512  // TODO: We should attach the attributes defined in OMPKinds.def.
513  }
514 
515  /// Collection of known kernels (\see Kernel) in the module.
516  KernelSet &Kernels;
517 
518  /// Collection of known OpenMP runtime functions.
519  DenseSet<const Function *> RTLFunctions;
520 };
521 
522 template <typename Ty, bool InsertInvalidates = true>
523 struct BooleanStateWithSetVector : public BooleanState {
524  bool contains(const Ty &Elem) const { return Set.contains(Elem); }
525  bool insert(const Ty &Elem) {
526  if (InsertInvalidates)
527  BooleanState::indicatePessimisticFixpoint();
528  return Set.insert(Elem);
529  }
530 
531  const Ty &operator[](int Idx) const { return Set[Idx]; }
532  bool operator==(const BooleanStateWithSetVector &RHS) const {
533  return BooleanState::operator==(RHS) && Set == RHS.Set;
534  }
535  bool operator!=(const BooleanStateWithSetVector &RHS) const {
536  return !(*this == RHS);
537  }
538 
539  bool empty() const { return Set.empty(); }
540  size_t size() const { return Set.size(); }
541 
542  /// "Clamp" this state with \p RHS.
543  BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
544  BooleanState::operator^=(RHS);
545  Set.insert(RHS.Set.begin(), RHS.Set.end());
546  return *this;
547  }
548 
549 private:
550  /// A set to keep track of elements.
551  SetVector<Ty> Set;
552 
553 public:
554  typename decltype(Set)::iterator begin() { return Set.begin(); }
555  typename decltype(Set)::iterator end() { return Set.end(); }
556  typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
557  typename decltype(Set)::const_iterator end() const { return Set.end(); }
558 };
559 
560 template <typename Ty, bool InsertInvalidates = true>
561 using BooleanStateWithPtrSetVector =
562  BooleanStateWithSetVector<Ty *, InsertInvalidates>;
563 
564 struct KernelInfoState : AbstractState {
565  /// Flag to track if we reached a fixpoint.
566  bool IsAtFixpoint = false;
567 
568  /// The parallel regions (identified by the outlined parallel functions) that
569  /// can be reached from the associated function.
570  BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
571  ReachedKnownParallelRegions;
572 
573  /// State to track what parallel region we might reach.
574  BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
575 
576  /// State to track if we are in SPMD-mode, assumed or known, and why we decided
577  /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
578  /// false.
579  BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
580 
581  /// The __kmpc_target_init call in this kernel, if any. If we find more than
582  /// one we abort as the kernel is malformed.
583  CallBase *KernelInitCB = nullptr;
584 
585  /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
586  /// one we abort as the kernel is malformed.
587  CallBase *KernelDeinitCB = nullptr;
588 
589  /// Flag to indicate if the associated function is a kernel entry.
590  bool IsKernelEntry = false;
591 
592  /// State to track what kernel entries can reach the associated function.
593  BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
594 
595  /// State to indicate if we can track the parallel level of the associated
596  /// function. We will give up tracking if we encounter an unknown caller or the
597  /// caller is __kmpc_parallel_51.
598  BooleanStateWithSetVector<uint8_t> ParallelLevels;
599 
600  /// Abstract State interface
601  ///{
602 
603  KernelInfoState() = default;
604  KernelInfoState(bool BestState) {
605  if (!BestState)
606  indicatePessimisticFixpoint();
607  }
608 
609  /// See AbstractState::isValidState(...)
610  bool isValidState() const override { return true; }
611 
612  /// See AbstractState::isAtFixpoint(...)
613  bool isAtFixpoint() const override { return IsAtFixpoint; }
614 
615  /// See AbstractState::indicatePessimisticFixpoint(...)
616  ChangeStatus indicatePessimisticFixpoint() override {
617  IsAtFixpoint = true;
618  ReachingKernelEntries.indicatePessimisticFixpoint();
619  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
620  ReachedKnownParallelRegions.indicatePessimisticFixpoint();
621  ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
622  return ChangeStatus::CHANGED;
623  }
624 
625  /// See AbstractState::indicateOptimisticFixpoint(...)
626  ChangeStatus indicateOptimisticFixpoint() override {
627  IsAtFixpoint = true;
628  ReachingKernelEntries.indicateOptimisticFixpoint();
629  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
630  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
631  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
632  return ChangeStatus::UNCHANGED;
633  }
634 
635  /// Return the assumed state
636  KernelInfoState &getAssumed() { return *this; }
637  const KernelInfoState &getAssumed() const { return *this; }
638 
639  bool operator==(const KernelInfoState &RHS) const {
640  if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
641  return false;
642  if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
643  return false;
644  if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
645  return false;
646  if (ReachingKernelEntries != RHS.ReachingKernelEntries)
647  return false;
648  return true;
649  }
650 
651  /// Returns true if this kernel might contain any OpenMP parallel regions.
652  bool mayContainParallelRegion() {
653  return !ReachedKnownParallelRegions.empty() ||
654  !ReachedUnknownParallelRegions.empty();
655  }
656 
657  /// Return empty set as the best state of potential values.
658  static KernelInfoState getBestState() { return KernelInfoState(true); }
659 
660  static KernelInfoState getBestState(KernelInfoState &KIS) {
661  return getBestState();
662  }
663 
664  /// Return full set as the worst state of potential values.
665  static KernelInfoState getWorstState() { return KernelInfoState(false); }
666 
667  /// "Clamp" this state with \p KIS.
668  KernelInfoState operator^=(const KernelInfoState &KIS) {
669  // Do not merge two different _init and _deinit call sites.
670  if (KIS.KernelInitCB) {
671  if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
672  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
673  "assumptions.");
674  KernelInitCB = KIS.KernelInitCB;
675  }
676  if (KIS.KernelDeinitCB) {
677  if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
678  llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
679  "assumptions.");
680  KernelDeinitCB = KIS.KernelDeinitCB;
681  }
682  SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
683  ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
684  ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
685  return *this;
686  }
687 
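  /// "Clamp" this state with \p KIS; identical to operator^=.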
688  KernelInfoState operator&=(const KernelInfoState &KIS) {
689  return (*this ^= KIS);
690  }
691 
692  ///}
693 };
694 
695 /// Used to map the values physically (in the IR) stored in an offload
696 /// array, to a vector in memory.
697 struct OffloadArray {
698  /// Physical array (in the IR).
699  AllocaInst *Array = nullptr;
700  /// Mapped values.
701  SmallVector<Value *, 8> StoredValues;
702  /// Last stores made in the offload array.
703  SmallVector<StoreInst *, 8> LastAccesses;
704 
705  OffloadArray() = default;
706 
707  /// Initializes the OffloadArray with the values stored in \p Array before
708  /// instruction \p Before is reached. Returns false if the initialization
709  /// fails.
710  /// This MUST be used immediately after the construction of the object.
711  bool initialize(AllocaInst &Array, Instruction &Before) {
712  if (!Array.getAllocatedType()->isArrayTy())
713  return false;
714 
715  if (!getValues(Array, Before))
716  return false;
717 
718  this->Array = &Array;
719  return true;
720  }
721 
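  /// Argument indices of the device ID and offload arrays in a
  /// __tgt_target_data_begin_mapper call (see getValuesInOffloadArrays).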
722  static const unsigned DeviceIDArgNum = 1;
723  static const unsigned BasePtrsArgNum = 3;
724  static const unsigned PtrsArgNum = 4;
725  static const unsigned SizesArgNum = 5;
726 
727 private:
728  /// Traverses the BasicBlock where \p Array is, collecting the stores made to
729  /// \p Array, leaving StoredValues with the values stored before the
730  /// instruction \p Before is reached.
731  bool getValues(AllocaInst &Array, Instruction &Before) {
732  // Initialize container.
733  const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
734  StoredValues.assign(NumValues, nullptr);
735  LastAccesses.assign(NumValues, nullptr);
736 
737  // TODO: This assumes the instruction \p Before is in the same
738  // BasicBlock as Array. Make it general, for any control flow graph.
739  BasicBlock *BB = Array.getParent();
740  if (BB != Before.getParent())
741  return false;
742 
743  const DataLayout &DL = Array.getModule()->getDataLayout();
744  const unsigned int PointerSize = DL.getPointerSize();
745 
746  for (Instruction &I : *BB) {
747  if (&I == &Before)
748  break;
749 
750  if (!isa<StoreInst>(&I))
751  continue;
752 
753  auto *S = cast<StoreInst>(&I);
754  int64_t Offset = -1;
755  auto *Dst =
756  GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
757  if (Dst == &Array) {
758  int64_t Idx = Offset / PointerSize;
759  StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
760  LastAccesses[Idx] = S;
761  }
762  }
763 
764  return isFilled();
765  }
766 
767  /// Returns true if all values in StoredValues and
768  /// LastAccesses are not nullptrs.
769  bool isFilled() {
770  const unsigned NumValues = StoredValues.size();
771  for (unsigned I = 0; I < NumValues; ++I) {
772  if (!StoredValues[I] || !LastAccesses[I])
773  return false;
774  }
775 
776  return true;
777  }
778 };
779 
780 struct OpenMPOpt {
781 
782  using OptimizationRemarkGetter =
783  function_ref<OptimizationRemarkEmitter &(Function *)>;
784 
785  OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
786  OptimizationRemarkGetter OREGetter,
787  OMPInformationCache &OMPInfoCache, Attributor &A)
788  : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
789  OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
790 
791  /// Check if any remarks are enabled for openmp-opt
792  bool remarksEnabled() {
793  auto &Ctx = M.getContext();
794  return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
795  }
796 
797  /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
798  bool run(bool IsModulePass) {
799  if (SCC.empty())
800  return false;
801 
802  bool Changed = false;
803 
804  LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
805  << " functions in a slice with "
806  << OMPInfoCache.ModuleSlice.size() << " functions\n");
807 
808  if (IsModulePass) {
809  Changed |= runAttributor(IsModulePass);
810 
811  // Recollect uses, in case Attributor deleted any.
812  OMPInfoCache.recollectUses();
813 
814  // TODO: This should be folded into buildCustomStateMachine.
815  Changed |= rewriteDeviceCodeStateMachine();
816 
817  if (remarksEnabled())
818  analysisGlobalization();
819 
820  Changed |= eliminateBarriers();
821  } else {
822  if (PrintICVValues)
823  printICVs();
824  if (PrintOpenMPKernels)
825  printKernels();
826 
827  Changed |= runAttributor(IsModulePass);
828 
829  // Recollect uses, in case Attributor deleted any.
830  OMPInfoCache.recollectUses();
831 
832  Changed |= deleteParallelRegions();
833 
835  Changed |= hideMemTransfersLatency();
836  Changed |= deduplicateRuntimeCalls();
838  if (mergeParallelRegions()) {
839  deduplicateRuntimeCalls();
840  Changed = true;
841  }
842  }
843 
844  Changed |= eliminateBarriers();
845  }
846 
847  return Changed;
848  }
849 
850  /// Print initial ICV values for testing.
851  /// FIXME: This should be done from the Attributor once it is added.
852  void printICVs() const {
853  InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
854  ICV_proc_bind};
855 
856  for (Function *F : OMPInfoCache.ModuleSlice) {
857  for (auto ICV : ICVs) {
858  auto ICVInfo = OMPInfoCache.ICVs[ICV];
859  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
860  return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
861  << " Value: "
862  << (ICVInfo.InitValue
863  ? toString(ICVInfo.InitValue->getValue(), 10, true)
864  : "IMPLEMENTATION_DEFINED");
865  };
866 
867  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
868  }
869  }
870  }
871 
872  /// Print OpenMP GPU kernels for testing.
873  void printKernels() const {
874  for (Function *F : SCC) {
875  if (!OMPInfoCache.Kernels.count(F))
876  continue;
877 
878  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
879  return ORA << "OpenMP GPU kernel "
880  << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
881  };
882 
883  emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
884  }
885  }
886 
887  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
888  /// given it has to be the callee or a nullptr is returned.
889  static CallInst *getCallIfRegularCall(
890  Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
891  CallInst *CI = dyn_cast<CallInst>(U.getUser());
892  if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
893  (!RFI ||
894  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
895  return CI;
896  return nullptr;
897  }
898 
899  /// Return the call if \p V is a regular call. If \p RFI is given it has to be
900  /// the callee or a nullptr is returned.
901  static CallInst *getCallIfRegularCall(
902  Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
903  CallInst *CI = dyn_cast<CallInst>(&V);
904  if (CI && !CI->hasOperandBundles() &&
905  (!RFI ||
906  (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
907  return CI;
908  return nullptr;
909  }
910 
911 private:
912  /// Merge parallel regions when it is safe.
913  bool mergeParallelRegions() {
914  const unsigned CallbackCalleeOperand = 2;
915  const unsigned CallbackFirstArgOperand = 3;
916  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
917 
918  // Check if there are any __kmpc_fork_call calls to merge.
919  OMPInformationCache::RuntimeFunctionInfo &RFI =
920  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
921 
922  if (!RFI.Declaration)
923  return false;
924 
925  // Unmergable calls that prevent merging a parallel region.
926  OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
927  OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
928  OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
929  };
930 
931  bool Changed = false;
932  LoopInfo *LI = nullptr;
933  DominatorTree *DT = nullptr;
934 
935  SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
936 
937  BasicBlock *StartBB = nullptr, *EndBB = nullptr;
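  // Body generation callback for the merged parallel region: splice the
  // previously isolated [StartBB, EndBB] region into the code the
  // OpenMPIRBuilder generates at CodeGenIP.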
938  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
939  BasicBlock *CGStartBB = CodeGenIP.getBlock();
940  BasicBlock *CGEndBB =
941  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
942  assert(StartBB != nullptr && "StartBB should not be null");
943  CGStartBB->getTerminator()->setSuccessor(0, StartBB);
944  assert(EndBB != nullptr && "EndBB should not be null");
945  EndBB->getTerminator()->setSuccessor(0, CGEndBB);
946  };
947 
948  auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
949  Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
950  ReplacementValue = &Inner;
951  return CodeGenIP;
952  };
953 
954  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
955 
956  /// Create a sequential execution region within a merged parallel region,
957  /// encapsulated in a master construct with a barrier for synchronization.
958  auto CreateSequentialRegion = [&](Function *OuterFn,
959  BasicBlock *OuterPredBB,
960  Instruction *SeqStartI,
961  Instruction *SeqEndI) {
962  // Isolate the instructions of the sequential region to a separate
963  // block.
964  BasicBlock *ParentBB = SeqStartI->getParent();
965  BasicBlock *SeqEndBB =
966  SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
967  BasicBlock *SeqAfterBB =
968  SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
969  BasicBlock *SeqStartBB =
970  SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
971 
972  assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
973  "Expected a different CFG");
974  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
975  ParentBB->getTerminator()->eraseFromParent();
976 
977  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
978  BasicBlock *CGStartBB = CodeGenIP.getBlock();
979  BasicBlock *CGEndBB =
980  SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
981  assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
982  CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
983  assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
984  SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
985  };
986  auto FiniCB = [&](InsertPointTy CodeGenIP) {};
987 
988  // Find outputs from the sequential region to outside users and
989  // broadcast their values to them.
990  for (Instruction &I : *SeqStartBB) {
991  SmallPtrSet<Instruction *, 4> OutsideUsers;
992  for (User *Usr : I.users()) {
993  Instruction &UsrI = *cast<Instruction>(Usr);
994  // Ignore outputs to LT intrinsics, code extraction for the merged
995  // parallel region will fix them.
996  if (UsrI.isLifetimeStartOrEnd())
997  continue;
998 
999  if (UsrI.getParent() != SeqStartBB)
1000  OutsideUsers.insert(&UsrI);
1001  }
1002 
1003  if (OutsideUsers.empty())
1004  continue;
1005 
1006  // Emit an alloca in the outer region to store the broadcasted
1007  // value.
1008  const DataLayout &DL = M.getDataLayout();
1009  AllocaInst *AllocaI = new AllocaInst(
1010  I.getType(), DL.getAllocaAddrSpace(), nullptr,
1011  I.getName() + ".seq.output.alloc", &OuterFn->front().front());
1012 
1013  // Emit a store instruction in the sequential BB to update the
1014  // value.
1015  new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
1016 
1017  // Emit a load instruction and replace the use of the output value
1018  // with it.
1019  for (Instruction *UsrI : OutsideUsers) {
1020  LoadInst *LoadI = new LoadInst(
1021  I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
1022  UsrI->replaceUsesOfWith(&I, LoadI);
1023  }
1024  }
1025 
1026  OpenMPIRBuilder::LocationDescription Loc(
1027  InsertPointTy(ParentBB, ParentBB->end()), DL);
1028  InsertPointTy SeqAfterIP =
1029  OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
1030 
1031  OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
1032 
1033  BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
1034 
1035  LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
1036  << "\n");
1037  };
1038 
1039  // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
1040  // contained in BB and only separated by instructions that can be
1041  // redundantly executed in parallel. The block BB is split before the first
1042  // call (in MergableCIs) and after the last so the entire region we merge
1043  // into a single parallel region is contained in a single basic block
1044  // without any other instructions. We use the OpenMPIRBuilder to outline
1045  // that block and call the resulting function via __kmpc_fork_call.
1046  auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
1047  BasicBlock *BB) {
1048  // TODO: Change the interface to allow single CIs expanded, e.g, to
1049  // include an outer loop.
1050  assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1051 
1052  auto Remark = [&](OptimizationRemark OR) {
1053  OR << "Parallel region merged with parallel region"
1054  << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1055  for (auto *CI : llvm::drop_begin(MergableCIs)) {
1056  OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1057  if (CI != MergableCIs.back())
1058  OR << ", ";
1059  }
1060  return OR << ".";
1061  };
1062 
1063  emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1064 
1065  Function *OriginalFn = BB->getParent();
1066  LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1067  << " parallel regions in " << OriginalFn->getName()
1068  << "\n");
1069 
1070  // Isolate the calls to merge in a separate block.
1071  EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1072  BasicBlock *AfterBB =
1073  SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1074  StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1075  "omp.par.merged");
1076 
1077  assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1078  const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1079  BB->getTerminator()->eraseFromParent();
1080 
1081  // Create sequential regions for sequential instructions that are
1082  // in-between mergable parallel regions.
1083  for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1084  It != End; ++It) {
1085  Instruction *ForkCI = *It;
1086  Instruction *NextForkCI = *(It + 1);
1087 
1088  // Continue if there are no in-between instructions.
1089  if (ForkCI->getNextNode() == NextForkCI)
1090  continue;
1091 
1092  CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1093  NextForkCI->getPrevNode());
1094  }
1095 
1096  OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1097  DL);
1098  IRBuilder<>::InsertPoint AllocaIP(
1099  &OriginalFn->getEntryBlock(),
1100  OriginalFn->getEntryBlock().getFirstInsertionPt());
1101  // Create the merged parallel region with default proc binding, to
1102  // avoid overriding binding settings, and without explicit cancellation.
1103  InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1104  Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1105  OMP_PROC_BIND_default, /* IsCancellable */ false);
1106  BranchInst::Create(AfterBB, AfterIP.getBlock());
1107 
1108  // Perform the actual outlining.
1109  OMPInfoCache.OMPBuilder.finalize(OriginalFn);
1110 
1111  Function *OutlinedFn = MergableCIs.front()->getCaller();
1112 
1113  // Replace the __kmpc_fork_call calls with direct calls to the outlined
1114  // callbacks.
1115  SmallVector<Value *, 8> Args;
1116  for (auto *CI : MergableCIs) {
1117  Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
1118  FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
1119  Args.clear();
1120  Args.push_back(OutlinedFn->getArg(0));
1121  Args.push_back(OutlinedFn->getArg(1));
1122  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1123  ++U)
1124  Args.push_back(CI->getArgOperand(U));
1125 
1126  CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1127  if (CI->getDebugLoc())
1128  NewCI->setDebugLoc(CI->getDebugLoc());
1129 
1130  // Forward parameter attributes from the callback to the callee.
1131  for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
1132  ++U)
1133  for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1134  NewCI->addParamAttr(
1135  U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1136 
1137  // Emit an explicit barrier to replace the implicit fork-join barrier.
1138  if (CI != MergableCIs.back()) {
1139  // TODO: Remove barrier if the merged parallel region includes the
1140  // 'nowait' clause.
1141  OMPInfoCache.OMPBuilder.createBarrier(
1142  InsertPointTy(NewCI->getParent(),
1143  NewCI->getNextNode()->getIterator()),
1144  OMPD_parallel);
1145  }
1146 
1147  CI->eraseFromParent();
1148  }
1149 
1150  assert(OutlinedFn != OriginalFn && "Outlining failed");
1151  CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1152  CGUpdater.reanalyzeFunction(*OriginalFn);
1153 
1154  NumOpenMPParallelRegionsMerged += MergableCIs.size();
1155 
1156  return true;
1157  };
1158 
1159  // Helper function that identifies sequences of
1160  // __kmpc_fork_call uses in a basic block.
1161  auto DetectPRsCB = [&](Use &U, Function &F) {
1162  CallInst *CI = getCallIfRegularCall(U, &RFI);
1163  BB2PRMap[CI->getParent()].insert(CI);
1164 
1165  return false;
1166  };
1167 
1168  BB2PRMap.clear();
1169  RFI.foreachUse(SCC, DetectPRsCB);
1170  SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1171  // Find mergable parallel regions within a basic block that are
1172  // safe to merge, that is any in-between instructions can safely
1173  // execute in parallel after merging.
1174  // TODO: support merging across basic-blocks.
1175  for (auto &It : BB2PRMap) {
1176  auto &CIs = It.getSecond();
1177  if (CIs.size() < 2)
1178  continue;
1179 
1180  BasicBlock *BB = It.getFirst();
1181  SmallVector<CallInst *, 4> MergableCIs;
1182 
1183  /// Returns true if the instruction is mergable, false otherwise.
1184  /// A terminator instruction is unmergable by definition since merging
1185  /// works within a BB. Instructions before the mergable region are
1186  /// mergable if they are not calls to OpenMP runtime functions that may
1187  /// set different execution parameters for subsequent parallel regions.
1188  /// Instructions in-between parallel regions are mergable if they are not
1189  /// calls to any non-intrinsic function since that may call a non-mergable
1190  /// OpenMP runtime function.
1191  auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1192  // We do not merge across BBs, hence return false (unmergable) if the
1193  // instruction is a terminator.
1194  if (I.isTerminator())
1195  return false;
1196 
1197  if (!isa<CallInst>(&I))
1198  return true;
1199 
1200  CallInst *CI = cast<CallInst>(&I);
1201  if (IsBeforeMergableRegion) {
1202  Function *CalledFunction = CI->getCalledFunction();
1203  if (!CalledFunction)
1204  return false;
1205  // Return false (unmergable) if the call before the parallel
1206  // region calls an explicit affinity (proc_bind) or number of
1207  // threads (num_threads) compiler-generated function. Those settings
1208  // may be incompatible with following parallel regions.
1209  // TODO: ICV tracking to detect compatibility.
1210  for (const auto &RFI : UnmergableCallsInfo) {
1211  if (CalledFunction == RFI.Declaration)
1212  return false;
1213  }
1214  } else {
1215  // Return false (unmergable) if there is a call instruction
1216  // in-between parallel regions when it is not an intrinsic. It
1217  // may call an unmergable OpenMP runtime function in its callpath.
1218  // TODO: Keep track of possible OpenMP calls in the callpath.
1219  if (!isa<IntrinsicInst>(CI))
1220  return false;
1221  }
1222 
1223  return true;
1224  };
1225  // Find maximal number of parallel region CIs that are safe to merge.
1226  for (auto It = BB->begin(), End = BB->end(); It != End;) {
1227  Instruction &I = *It;
1228  ++It;
1229 
1230  if (CIs.count(&I)) {
1231  MergableCIs.push_back(cast<CallInst>(&I));
1232  continue;
1233  }
1234 
1235  // Continue expanding if the instruction is mergable.
1236  if (IsMergable(I, MergableCIs.empty()))
1237  continue;
1238 
1239  // Forward the instruction iterator to skip the next parallel region
1240  // since there is an unmergable instruction which can affect it.
1241  for (; It != End; ++It) {
1242  Instruction &SkipI = *It;
1243  if (CIs.count(&SkipI)) {
1244  LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1245  << " due to " << I << "\n");
1246  ++It;
1247  break;
1248  }
1249  }
1250 
1251  // Store mergable regions found.
1252  if (MergableCIs.size() > 1) {
1253  MergableCIsVector.push_back(MergableCIs);
1254  LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1255  << " parallel regions in block " << BB->getName()
1256  << " of function " << BB->getParent()->getName()
1257  << "\n";);
1258  }
1259 
1260  MergableCIs.clear();
1261  }
1262 
1263  if (!MergableCIsVector.empty()) {
1264  Changed = true;
1265 
1266  for (auto &MergableCIs : MergableCIsVector)
1267  Merge(MergableCIs, BB);
1268  MergableCIsVector.clear();
1269  }
1270  }
1271 
1272  if (Changed) {
1273  /// Re-collect uses for fork calls, emitted barrier calls, and
1274  /// any emitted master/end_master calls.
1275  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1276  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1277  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1278  OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1279  }
1280 
1281  return Changed;
1282  }
1283 
1284  /// Try to delete parallel regions if possible.
1285  bool deleteParallelRegions() {
1286  const unsigned CallbackCalleeOperand = 2;
1287 
1288  OMPInformationCache::RuntimeFunctionInfo &RFI =
1289  OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1290 
1291  if (!RFI.Declaration)
1292  return false;
1293 
1294  bool Changed = false;
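  // Delete a __kmpc_fork_call if its outlined parallel region only reads
  // memory and is guaranteed to return, i.e., it has no observable effect.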
1295  auto DeleteCallCB = [&](Use &U, Function &) {
1296  CallInst *CI = getCallIfRegularCall(U);
1297  if (!CI)
1298  return false;
1299  auto *Fn = dyn_cast<Function>(
1300  CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1301  if (!Fn)
1302  return false;
1303  if (!Fn->onlyReadsMemory())
1304  return false;
1305  if (!Fn->hasFnAttribute(Attribute::WillReturn))
1306  return false;
1307 
1308  LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1309  << CI->getCaller()->getName() << "\n");
1310 
1311  auto Remark = [&](OptimizationRemark OR) {
1312  return OR << "Removing parallel region with no side-effects.";
1313  };
1314  emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1315 
1316  CGUpdater.removeCallSite(*CI);
1317  CI->eraseFromParent();
1318  Changed = true;
1319  ++NumOpenMPParallelRegionsDeleted;
1320  return true;
1321  };
1322 
1323  RFI.foreachUse(SCC, DeleteCallCB);
1324 
1325  return Changed;
1326  }
1327 
1328  /// Try to eliminate runtime calls by reusing existing ones.
1329  bool deduplicateRuntimeCalls() {
1330  bool Changed = false;
1331 
1332  RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1333  OMPRTL_omp_get_num_threads,
1334  OMPRTL_omp_in_parallel,
1335  OMPRTL_omp_get_cancellation,
1336  OMPRTL_omp_get_thread_limit,
1337  OMPRTL_omp_get_supported_active_levels,
1338  OMPRTL_omp_get_level,
1339  OMPRTL_omp_get_ancestor_thread_num,
1340  OMPRTL_omp_get_team_size,
1341  OMPRTL_omp_get_active_level,
1342  OMPRTL_omp_in_final,
1343  OMPRTL_omp_get_proc_bind,
1344  OMPRTL_omp_get_num_places,
1345  OMPRTL_omp_get_num_procs,
1346  OMPRTL_omp_get_place_num,
1347  OMPRTL_omp_get_partition_num_places,
1348  OMPRTL_omp_get_partition_place_nums};
1349 
1350  // Global-tid is handled separately.
1351  SmallSetVector<Value *, 16> GTIdArgs;
1352  collectGlobalThreadIdArguments(GTIdArgs);
1353  LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1354  << " global thread ID arguments\n");
1355 
1356  for (Function *F : SCC) {
1357  for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1358  Changed |= deduplicateRuntimeCalls(
1359  *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1360 
1361  // __kmpc_global_thread_num is special as we can replace it with an
1362  // argument in enough cases to make it worth trying.
1363  Value *GTIdArg = nullptr;
1364  for (Argument &Arg : F->args())
1365  if (GTIdArgs.count(&Arg)) {
1366  GTIdArg = &Arg;
1367  break;
1368  }
1369  Changed |= deduplicateRuntimeCalls(
1370  *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1371  }
1372 
1373  return Changed;
1374  }
1375 
1376  /// Tries to hide the latency of runtime calls that involve host to
1377  /// device memory transfers by splitting them into their "issue" and "wait"
1378  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1379  /// moved downwards as much as possible. The "issue" issues the memory transfer
1380  /// asynchronously, returning a handle. The "wait" waits in the returned
1381  /// handle for the memory transfer to finish.
1382  bool hideMemTransfersLatency() {
1383  auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1384  bool Changed = false;
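  // Split a __tgt_target_data_begin_mapper call into its "issue" and "wait"
  // counterparts if a legal insertion point for the "wait" is found.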
1385  auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1386  auto *RTCall = getCallIfRegularCall(U, &RFI);
1387  if (!RTCall)
1388  return false;
1389 
1390  OffloadArray OffloadArrays[3];
1391  if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1392  return false;
1393 
1394  LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1395 
1396  // TODO: Check if can be moved upwards.
1397  bool WasSplit = false;
1398  Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1399  if (WaitMovementPoint)
1400  WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1401 
1402  Changed |= WasSplit;
1403  return WasSplit;
1404  };
1405  RFI.foreachUse(SCC, SplitMemTransfers);
1406 
1407  return Changed;
1408  }
1409 
1410  /// Eliminates redundant, aligned barriers in OpenMP offloaded kernels.
1411  /// TODO: Make this an AA and expand it to work across blocks and functions.
1412  bool eliminateBarriers() {
1413  bool Changed = false;
1414 
1415  if (DisableOpenMPOptBarrierElimination)
1416  return /*Changed=*/false;
1417 
1418  if (OMPInfoCache.Kernels.empty())
1419  return /*Changed=*/false;
1420 
1421  enum ImplicitBarrierType { IBT_ENTRY, IBT_EXIT };
1422 
1423  class BarrierInfo {
1424  Instruction *I;
1425  enum ImplicitBarrierType Type;
1426 
1427  public:
1428  BarrierInfo(enum ImplicitBarrierType Type) : I(nullptr), Type(Type) {}
1429  BarrierInfo(Instruction &I) : I(&I) {}
1430 
1431  bool isImplicit() { return !I; }
1432 
1433  bool isImplicitEntry() { return isImplicit() && Type == IBT_ENTRY; }
1434 
1435  bool isImplicitExit() { return isImplicit() && Type == IBT_EXIT; }
1436 
1437  Instruction *getInstruction() { return I; }
1438  };
1439 
1440  for (Function *Kernel : OMPInfoCache.Kernels) {
1441  for (BasicBlock &BB : *Kernel) {
1442  SmallVector<BarrierInfo, 8> BarriersInBlock;
1443  SmallPtrSet<Instruction *, 8> BarriersToBeDeleted;
1444 
1445  // Add the kernel entry implicit barrier.
1446  if (&Kernel->getEntryBlock() == &BB)
1447  BarriersInBlock.push_back(IBT_ENTRY);
1448 
1449  // Find implicit and explicit aligned barriers in the same basic block.
1450  for (Instruction &I : BB) {
1451  if (isa<ReturnInst>(I)) {
1452  // Add the implicit barrier when exiting the kernel.
1453  BarriersInBlock.push_back(IBT_EXIT);
1454  continue;
1455  }
1456  CallBase *CB = dyn_cast<CallBase>(&I);
1457  if (!CB)
1458  continue;
1459 
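  // A call is an aligned barrier if it is a known barrier intrinsic or
  // carries the "ompx_aligned_barrier" assumption.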
1460  auto IsAlignBarrierCB = [&](CallBase &CB) {
1461  switch (CB.getIntrinsicID()) {
1462  case Intrinsic::nvvm_barrier0:
1463  case Intrinsic::nvvm_barrier0_and:
1464  case Intrinsic::nvvm_barrier0_or:
1465  case Intrinsic::nvvm_barrier0_popc:
1466  return true;
1467  default:
1468  break;
1469  }
1470  return hasAssumption(CB,
1471  KnownAssumptionString("ompx_aligned_barrier"));
1472  };
1473 
1474  if (IsAlignBarrierCB(*CB)) {
1475  // Add an explicit aligned barrier.
1476  BarriersInBlock.push_back(I);
1477  }
1478  }
1479 
1480  if (BarriersInBlock.size() <= 1)
1481  continue;
1482 
1483  // A barrier in a barrier pair is removable if all instructions
1484  // between the barriers in the pair are side-effect free modulo the
1485  // barrier operation.
1486  auto IsBarrierRemoveable = [&Kernel](BarrierInfo *StartBI,
1487  BarrierInfo *EndBI) {
1488  assert(
1489  !StartBI->isImplicitExit() &&
1490  "Expected start barrier to be other than a kernel exit barrier");
1491  assert(
1492  !EndBI->isImplicitEntry() &&
1493  "Expected end barrier to be other than a kernel entry barrier");
1494  // If StartBI's instruction is null then this is the implicit
1495  // kernel entry barrier, so iterate from the first instruction in the
1496  // entry block.
1497  Instruction *I = (StartBI->isImplicitEntry())
1498  ? &Kernel->getEntryBlock().front()
1499  : StartBI->getInstruction()->getNextNode();
1500  assert(I && "Expected non-null start instruction");
1501  Instruction *E = (EndBI->isImplicitExit())
1502  ? I->getParent()->getTerminator()
1503  : EndBI->getInstruction();
1504  assert(E && "Expected non-null end instruction");
1505 
1506  for (; I != E; I = I->getNextNode()) {
1507  if (!I->mayHaveSideEffects() && !I->mayReadFromMemory())
1508  continue;
1509 
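  // Conservatively assume an access requires the barrier unless the
  // underlying object is known to be thread-private (undef, alloca,
  // constant, thread-local, or in the local/constant address space).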
1510  auto IsPotentiallyAffectedByBarrier =
1511  [](Optional<MemoryLocation> Loc) {
1512  const Value *Obj = (Loc && Loc->Ptr)
1513  ? getUnderlyingObject(Loc->Ptr)
1514  : nullptr;
1515  if (!Obj) {
1516  LLVM_DEBUG(
1517  dbgs()
1518  << "Access to unknown location requires barriers\n");
1519  return true;
1520  }
1521  if (isa<UndefValue>(Obj))
1522  return false;
1523  if (isa<AllocaInst>(Obj))
1524  return false;
1525  if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
1526  if (GV->isConstant())
1527  return false;
1528  if (GV->isThreadLocal())
1529  return false;
1530  if (GV->getAddressSpace() == (int)AddressSpace::Local)
1531  return false;
1532  if (GV->getAddressSpace() == (int)AddressSpace::Constant)
1533  return false;
1534  }
1535  LLVM_DEBUG(dbgs() << "Access to '" << *Obj
1536  << "' requires barriers\n");
1537  return true;
1538  };
1539 
1540  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) {
1541  Optional<MemoryLocation> Loc = MemoryLocation::getForDest(MI);
1542  if (IsPotentiallyAffectedByBarrier(Loc))
1543  return false;
1544  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(I)) {
1545  Optional<MemoryLocation> Loc =
1546  MemoryLocation::getForSource(MTI);
1547  if (IsPotentiallyAffectedByBarrier(Loc))
1548  return false;
1549  }
1550  continue;
1551  }
1552 
1553  if (auto *LI = dyn_cast<LoadInst>(I))
1554  if (LI->hasMetadata(LLVMContext::MD_invariant_load))
1555  continue;
1556 
1557  Optional<MemoryLocation> Loc = MemoryLocation::getOrNone(I);
1558  if (IsPotentiallyAffectedByBarrier(Loc))
1559  return false;
1560  }
1561 
1562  return true;
1563  };
1564 
1565  // Iterate barrier pairs and remove an explicit barrier if analysis
1566  // deems it removable.
1567  for (auto *It = BarriersInBlock.begin(),
1568  *End = BarriersInBlock.end() - 1;
1569  It != End; ++It) {
1570 
1571  BarrierInfo *StartBI = It;
1572  BarrierInfo *EndBI = (It + 1);
1573 
1574  // Cannot remove when both are implicit barriers, continue.
1575  if (StartBI->isImplicit() && EndBI->isImplicit())
1576  continue;
1577 
1578  if (!IsBarrierRemoveable(StartBI, EndBI))
1579  continue;
1580 
1581  assert(!(StartBI->isImplicit() && EndBI->isImplicit()) &&
1582  "Expected at least one explicit barrier to remove.");
1583 
1584  // Remove an explicit barrier, check first, then second.
1585  if (!StartBI->isImplicit()) {
1586  LLVM_DEBUG(dbgs() << "Remove start barrier "
1587  << *StartBI->getInstruction() << "\n");
1588  BarriersToBeDeleted.insert(StartBI->getInstruction());
1589  } else {
1590  LLVM_DEBUG(dbgs() << "Remove end barrier "
1591  << *EndBI->getInstruction() << "\n");
1592  BarriersToBeDeleted.insert(EndBI->getInstruction());
1593  }
1594  }
1595 
1596  if (BarriersToBeDeleted.empty())
1597  continue;
1598 
1599  Changed = true;
1600  for (Instruction *I : BarriersToBeDeleted) {
1601  ++NumBarriersEliminated;
1602  auto Remark = [&](OptimizationRemark OR) {
1603  return OR << "Redundant barrier eliminated.";
1604  };
1605 
1606  if (EnableVerboseRemarks)
1607  emitRemark<OptimizationRemark>(I, "OMP190", Remark);
1608  I->eraseFromParent();
1609  }
1610  }
1611  }
1612 
1613  return Changed;
1614  }
1615 
1616  void analysisGlobalization() {
1617  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1618 
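  // Emit a remark for every remaining __kmpc_alloc_shared call; each one
  // indicates data that is still globalized on the GPU.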
1619  auto CheckGlobalization = [&](Use &U, Function &Decl) {
1620  if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1621  auto Remark = [&](OptimizationRemarkMissed ORM) {
1622  return ORM
1623  << "Found thread data sharing on the GPU. "
1624  << "Expect degraded performance due to data globalization.";
1625  };
1626  emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1627  }
1628 
1629  return false;
1630  };
1631 
1632  RFI.foreachUse(SCC, CheckGlobalization);
1633  }
1634 
1635  /// Maps the values stored in the offload arrays passed as arguments to
1636  /// \p RuntimeCall into the offload arrays in \p OAs.
1637  bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1638  MutableArrayRef<OffloadArray> OAs) {
1639  assert(OAs.size() == 3 && "Need space for three offload arrays!");
1640 
1641  // A runtime call that involves memory offloading looks something like:
1642  // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1643  // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1644  // ...)
1645  // So, the idea is to access the allocas that allocate space for these
1646  // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1647  // Therefore:
1648  // i8** %offload_baseptrs.
1649  Value *BasePtrsArg =
1650  RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1651  // i8** %offload_ptrs.
1652  Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1653  // i8** %offload_sizes.
1654  Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1655 
1656  // Get values stored in **offload_baseptrs.
1657  auto *V = getUnderlyingObject(BasePtrsArg);
1658  if (!isa<AllocaInst>(V))
1659  return false;
1660  auto *BasePtrsArray = cast<AllocaInst>(V);
1661  if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1662  return false;
1663 
1664  // Get values stored in **offload_ptrs.
1665  V = getUnderlyingObject(PtrsArg);
1666  if (!isa<AllocaInst>(V))
1667  return false;
1668  auto *PtrsArray = cast<AllocaInst>(V);
1669  if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1670  return false;
1671 
1672  // Get values stored in **offload_sizes.
1673  V = getUnderlyingObject(SizesArg);
1674  // If it's a [constant] global array don't analyze it.
1675  if (isa<GlobalValue>(V))
1676  return isa<Constant>(V);
1677  if (!isa<AllocaInst>(V))
1678  return false;
1679 
1680  auto *SizesArray = cast<AllocaInst>(V);
1681  if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1682  return false;
1683 
1684  return true;
1685  }
1686 
1687  /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1688  /// For now this is a way to test that the function getValuesInOffloadArrays
1689  /// is working properly.
1690  /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1691  void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1692  assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1693 
1694  LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1695  std::string ValuesStr;
1696  raw_string_ostream Printer(ValuesStr);
1697  std::string Separator = " --- ";
1698 
1699  for (auto *BP : OAs[0].StoredValues) {
1700  BP->print(Printer);
1701  Printer << Separator;
1702  }
1703  LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1704  ValuesStr.clear();
1705 
1706  for (auto *P : OAs[1].StoredValues) {
1707  P->print(Printer);
1708  Printer << Separator;
1709  }
1710  LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1711  ValuesStr.clear();
1712 
1713  for (auto *S : OAs[2].StoredValues) {
1714  S->print(Printer);
1715  Printer << Separator;
1716  }
1717  LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1718  }
1719 
1720  /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
1721  /// can be moved. Returns nullptr if the movement is not possible or not worth it.
1722  Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1723  // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1724  // Make it traverse the CFG.
1725 
1726  Instruction *CurrentI = &RuntimeCall;
1727  bool IsWorthIt = false;
1728  while ((CurrentI = CurrentI->getNextNode())) {
1729 
1730  // TODO: Once we detect the regions to be offloaded we should use the
1731  // alias analysis manager to check if CurrentI may modify one of
1732  // the offloaded regions.
1733  if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1734  if (IsWorthIt)
1735  return CurrentI;
1736 
1737  return nullptr;
1738  }
1739 
1740  // FIXME: For now, moving the call over anything without side effects
1741  // is considered worth it.
1742  IsWorthIt = true;
1743  }
1744 
1745  // Return end of BasicBlock.
1746  return RuntimeCall.getParent()->getTerminator();
1747  }
1748 
1749  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1750  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1751  Instruction &WaitMovementPoint) {
1752  // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1753  // the function. It stores information about the async transfer, allowing us
1754  // to wait on it later.
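  // Rough sketch (not verbatim IR) of the rewrite performed below:
  //   %handle = alloca %struct.__tgt_async_info
  //   call void @__tgt_target_data_begin_mapper_issue(<original args>, %handle)
  //   ...                              ; independent code the wait is moved over
  //   call void @__tgt_target_data_begin_wait(%device_id, %handle)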
1755  auto &IRBuilder = OMPInfoCache.OMPBuilder;
1756  auto *F = RuntimeCall.getCaller();
1757  Instruction *FirstInst = &(F->getEntryBlock().front());
1758  AllocaInst *Handle = new AllocaInst(
1759  IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1760 
1761  // Add "issue" runtime call declaration:
1762  // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1763  // i8**, i8**, i64*, i64*)
1764  FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1765  M, OMPRTL___tgt_target_data_begin_mapper_issue);
1766 
1767  // Change RuntimeCall call site for its asynchronous version.
1768  SmallVector<Value *, 16> Args;
1769  for (auto &Arg : RuntimeCall.args())
1770  Args.push_back(Arg.get());
1771  Args.push_back(Handle);
1772 
1773  CallInst *IssueCallsite =
1774  CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1775  OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
1776  RuntimeCall.eraseFromParent();
1777 
1778  // Add "wait" runtime call declaration:
1779  // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1780  FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1781  M, OMPRTL___tgt_target_data_begin_mapper_wait);
1782 
1783  Value *WaitParams[2] = {
1784  IssueCallsite->getArgOperand(
1785  OffloadArray::DeviceIDArgNum), // device_id.
1786  Handle // handle to wait on.
1787  };
1788  CallInst *WaitCallsite = CallInst::Create(
1789  WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1790  OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
1791 
1792  return true;
1793  }
1794 
1795  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1796  bool GlobalOnly, bool &SingleChoice) {
1797  if (CurrentIdent == NextIdent)
1798  return CurrentIdent;
1799 
1800  // TODO: Figure out how to actually combine multiple debug locations. For
1801  // now we just keep an existing one if there is a single choice.
1802  if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1803  SingleChoice = !CurrentIdent;
1804  return NextIdent;
1805  }
1806  return nullptr;
1807  }
1808 
1809  /// Return a `struct ident_t*` value that represents the ones used in the
1810  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1811  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1812  /// return value, we create one from scratch. We also do not yet combine
1813  /// information, e.g., the source locations; see combinedIdentStruct.
1814  Value *
1815  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1816  Function &F, bool GlobalOnly) {
1817  bool SingleChoice = true;
1818  Value *Ident = nullptr;
1819  auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1820  CallInst *CI = getCallIfRegularCall(U, &RFI);
1821  if (!CI || &F != &Caller)
1822  return false;
1823  Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1824  /* GlobalOnly */ true, SingleChoice);
1825  return false;
1826  };
1827  RFI.foreachUse(SCC, CombineIdentStruct);
1828 
1829  if (!Ident || !SingleChoice) {
1830  // The IRBuilder uses the insertion block to get to the module; this is
1831  // unfortunate but we work around it for now.
1832  if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1833  OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1834  &F.getEntryBlock(), F.getEntryBlock().begin()));
1835  // Create a fallback location if none was found.
1836  // TODO: Use the debug locations of the calls instead.
1837  uint32_t SrcLocStrSize;
1838  Constant *Loc =
1839  OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
1840  Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
1841  }
1842  return Ident;
1843  }
1844 
1845  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1846  /// \p ReplVal if given.
1847  bool deduplicateRuntimeCalls(Function &F,
1848  OMPInformationCache::RuntimeFunctionInfo &RFI,
1849  Value *ReplVal = nullptr) {
1850  auto *UV = RFI.getUseVector(F);
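  // Nothing to deduplicate unless there are at least two occurrences: the
  // existing uses plus, if provided, the replacement value itself.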
1851  if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1852  return false;
1853 
1854  LLVM_DEBUG(
1855  dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1856  << (ReplVal ? " with an existing value\n" : "\n") << "\n");
1857 
1858  assert((!ReplVal || (isa<Argument>(ReplVal) &&
1859  cast<Argument>(ReplVal)->getParent() == &F)) &&
1860  "Unexpected replacement value!");
1861 
1862  // TODO: Use dominance to find a good position instead.
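  // A call can only be moved if, apart from a leading `ident_t *` argument,
  // none of its arguments are instructions (their definitions might not
  // dominate the new insertion point).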
1863  auto CanBeMoved = [this](CallBase &CB) {
1864  unsigned NumArgs = CB.arg_size();
1865  if (NumArgs == 0)
1866  return true;
1867  if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1868  return false;
1869  for (unsigned U = 1; U < NumArgs; ++U)
1870  if (isa<Instruction>(CB.getArgOperand(U)))
1871  return false;
1872  return true;
1873  };
1874 
1875  if (!ReplVal) {
1876  for (Use *U : *UV)
1877  if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1878  if (!CanBeMoved(*CI))
1879  continue;
1880 
1881  // If the function is a kernel, dedup will move
1882  // the runtime call right after the kernel init callsite. Otherwise,
1883  // it will move it to the beginning of the caller function.
1884  if (isKernel(F)) {
1885  auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1886  auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1887 
1888  if (KernelInitUV->empty())
1889  continue;
1890 
1891  assert(KernelInitUV->size() == 1 &&
1892  "Expected a single __kmpc_target_init in kernel\n");
1893 
1894  CallInst *KernelInitCI =
1895  getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1896  assert(KernelInitCI &&
1897  "Expected a call to __kmpc_target_init in kernel\n");
1898 
1899  CI->moveAfter(KernelInitCI);
1900  } else
1901  CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1902  ReplVal = CI;
1903  break;
1904  }
1905  if (!ReplVal)
1906  return false;
1907  }
1908 
1909  // If we use a call as a replacement value we need to make sure the ident is
1910  // valid at the new location. For now we just pick a global one, either
1911  // existing and used by one of the calls, or created from scratch.
1912  if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1913  if (!CI->arg_empty() &&
1914  CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1915  Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1916  /* GlobalOnly */ true);
1917  CI->setArgOperand(0, Ident);
1918  }
1919  }
1920 
1921  bool Changed = false;
1922  auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1923  CallInst *CI = getCallIfRegularCall(U, &RFI);
1924  if (!CI || CI == ReplVal || &F != &Caller)
1925  return false;
1926  assert(CI->getCaller() == &F && "Unexpected call!");
1927 
1928  auto Remark = [&](OptimizationRemark OR) {
1929  return OR << "OpenMP runtime call "
1930  << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1931  };
1932  if (CI->getDebugLoc())
1933  emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1934  else
1935  emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1936 
1937  CGUpdater.removeCallSite(*CI);
1938  CI->replaceAllUsesWith(ReplVal);
1939  CI->eraseFromParent();
1940  ++NumOpenMPRuntimeCallsDeduplicated;
1941  Changed = true;
1942  return true;
1943  };
1944  RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1945 
1946  return Changed;
1947  }
1948 
1949  /// Collect arguments that represent the global thread id in \p GTIdArgs.
1950  void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1951  // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1952  // initialization. We could define an AbstractAttribute instead and
1953  // run the Attributor here once it can be run as an SCC pass.
1954 
1955  // Helper to check the argument \p ArgNo at all call sites of \p F for
1956  // a GTId.
1957  auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1958  if (!F.hasLocalLinkage())
1959  return false;
1960  for (Use &U : F.uses()) {
1961  if (CallInst *CI = getCallIfRegularCall(U)) {
1962  Value *ArgOp = CI->getArgOperand(ArgNo);
1963  if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1964  getCallIfRegularCall(
1965  *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1966  continue;
1967  }
1968  return false;
1969  }
1970  return true;
1971  };
1972 
1973  // Helper to identify uses of a GTId as GTId arguments.
1974  auto AddUserArgs = [&](Value &GTId) {
1975  for (Use &U : GTId.uses())
1976  if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1977  if (CI->isArgOperand(&U))
1978  if (Function *Callee = CI->getCalledFunction())
1979  if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1980  GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1981  };
1982 
1983  // The argument users of __kmpc_global_thread_num calls are GTIds.
1984  OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1985  OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1986 
1987  GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1988  if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1989  AddUserArgs(*CI);
1990  return false;
1991  });
1992 
1993  // Transitively search for more arguments by looking at the users of the
1994  // ones we know already. During the search the GTIdArgs vector is extended
1995  // so we cannot cache the size nor can we use a range based for.
1996  for (unsigned U = 0; U < GTIdArgs.size(); ++U)
1997  AddUserArgs(*GTIdArgs[U]);
1998  }
1999 
2000  /// Kernel (=GPU) optimizations and utility functions
2001  ///
2002  ///{{
2003 
2004  /// Check if \p F is a kernel, hence entry point for target offloading.
2005  bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
2006 
2007  /// Cache to remember the unique kernel for a function.
2008  DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
2009 
2010  /// Find the unique kernel that will execute \p F, if any.
2011  Kernel getUniqueKernelFor(Function &F);
2012 
2013  /// Find the unique kernel that will execute \p I, if any.
2014  Kernel getUniqueKernelFor(Instruction &I) {
2015  return getUniqueKernelFor(*I.getFunction());
2016  }
2017 
2018  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
2019  /// the cases where we can avoid taking the address of a function.
2020  bool rewriteDeviceCodeStateMachine();
2021 
2022  ///
2023  ///}}
2024 
2025  /// Emit a remark generically
2026  ///
2027  /// This template function can be used to generically emit a remark. The
2028  /// RemarkKind should be one of the following:
2029  /// - OptimizationRemark to indicate a successful optimization attempt
2030  /// - OptimizationRemarkMissed to report a failed optimization attempt
2031  /// - OptimizationRemarkAnalysis to provide additional information about an
2032  /// optimization attempt
2033  ///
2034  /// The remark is built using a callback function provided by the caller that
2035  /// takes a RemarkKind as input and returns a RemarkKind.
2036  template <typename RemarkKind, typename RemarkCallBack>
2037  void emitRemark(Instruction *I, StringRef RemarkName,
2038  RemarkCallBack &&RemarkCB) const {
2039  Function *F = I->getParent()->getParent();
2040  auto &ORE = OREGetter(F);
2041 
2042  if (RemarkName.startswith("OMP"))
2043  ORE.emit([&]() {
2044  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
2045  << " [" << RemarkName << "]";
2046  });
2047  else
2048  ORE.emit(
2049  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
2050  }
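  // Typical usage (sketch, mirroring the calls elsewhere in this file):
  //   auto Remark = [&](OptimizationRemark OR) { return OR << "..."; };
  //   emitRemark<OptimizationRemark>(I, "OMP170", Remark);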
2051 
2052  /// Emit a remark on a function.
2053  template <typename RemarkKind, typename RemarkCallBack>
2054  void emitRemark(Function *F, StringRef RemarkName,
2055  RemarkCallBack &&RemarkCB) const {
2056  auto &ORE = OREGetter(F);
2057 
2058  if (RemarkName.startswith("OMP"))
2059  ORE.emit([&]() {
2060  return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
2061  << " [" << RemarkName << "]";
2062  });
2063  else
2064  ORE.emit(
2065  [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
2066  }
2067 
2068  /// RAII struct to temporarily change an RTL function's linkage to external.
2069  /// This prevents it from being mistakenly removed by other optimizations.
2070  struct ExternalizationRAII {
2071  ExternalizationRAII(OMPInformationCache &OMPInfoCache,
2072  RuntimeFunction RFKind)
2073  : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
2074  if (!Declaration)
2075  return;
2076 
2077  LinkageType = Declaration->getLinkage();
2078  Declaration->setLinkage(GlobalValue::ExternalLinkage);
2079  }
2080 
2081  ~ExternalizationRAII() {
2082  if (!Declaration)
2083  return;
2084 
2085  Declaration->setLinkage(LinkageType);
2086  }
2087 
2088  Function *Declaration;
2089  GlobalValue::LinkageTypes LinkageType;
2090  };
2091 
2092  /// The underlying module.
2093  Module &M;
2094 
2095  /// The SCC we are operating on.
2096  SmallVectorImpl<Function *> &SCC;
2097 
2098  /// Callback to update the call graph, the first argument is a removed call,
2099  /// the second an optional replacement call.
2100  CallGraphUpdater &CGUpdater;
2101 
2102  /// Callback to get an OptimizationRemarkEmitter from a Function *
2103  OptimizationRemarkGetter OREGetter;
2104 
2105  /// OpenMP-specific information cache. Also used for Attributor runs.
2106  OMPInformationCache &OMPInfoCache;
2107 
2108  /// Attributor instance.
2109  Attributor &A;
2110 
2111  /// Helper function to run Attributor on SCC.
2112  bool runAttributor(bool IsModulePass) {
2113  if (SCC.empty())
2114  return false;
2115 
2116  // Temporarily make these functions have external linkage so the Attributor
2117  // doesn't remove them before we try to look them up later.
2118  ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
2119  ExternalizationRAII EndParallel(OMPInfoCache,
2120  OMPRTL___kmpc_kernel_end_parallel);
2121  ExternalizationRAII BarrierSPMD(OMPInfoCache,
2122  OMPRTL___kmpc_barrier_simple_spmd);
2123  ExternalizationRAII BarrierGeneric(OMPInfoCache,
2124  OMPRTL___kmpc_barrier_simple_generic);
2125  ExternalizationRAII ThreadId(OMPInfoCache,
2126  OMPRTL___kmpc_get_hardware_thread_id_in_block);
2127  ExternalizationRAII NumThreads(
2128  OMPInfoCache, OMPRTL___kmpc_get_hardware_num_threads_in_block);
2129  ExternalizationRAII WarpSize(OMPInfoCache, OMPRTL___kmpc_get_warp_size);
2130 
2131  registerAAs(IsModulePass);
2132 
2133  ChangeStatus Changed = A.run();
2134 
2135  LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
2136  << " functions, result: " << Changed << ".\n");
2137 
2138  return Changed == ChangeStatus::CHANGED;
2139  }
2140 
2141  void registerFoldRuntimeCall(RuntimeFunction RF);
2142 
2143  /// Populate the Attributor with abstract attribute opportunities in the
2144  /// function.
2145  void registerAAs(bool IsModulePass);
2146 };
2147 
2148 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
2149  if (!OMPInfoCache.ModuleSlice.count(&F))
2150  return nullptr;
2151 
2152  // Use a scope to keep the lifetime of the CachedKernel short.
2153  {
2154  Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
2155  if (CachedKernel)
2156  return *CachedKernel;
2157 
2158  // TODO: We should use an AA to create an (optimistic and callback
2159  // call-aware) call graph. For now we stick to simple patterns that
2160  // are less powerful, basically the worst fixpoint.
2161  if (isKernel(F)) {
2162  CachedKernel = Kernel(&F);
2163  return *CachedKernel;
2164  }
2165 
2166  CachedKernel = nullptr;
2167  if (!F.hasLocalLinkage()) {
2168 
2169  // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
2170  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2171  return ORA << "Potentially unknown OpenMP target region caller.";
2172  };
2173  emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
2174 
2175  return nullptr;
2176  }
2177  }
2178 
2179  auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
2180  if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
2181  // Allow use in equality comparisons.
2182  if (Cmp->isEquality())
2183  return getUniqueKernelFor(*Cmp);
2184  return nullptr;
2185  }
2186  if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
2187  // Allow direct calls.
2188  if (CB->isCallee(&U))
2189  return getUniqueKernelFor(*CB);
2190 
2191  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2192  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2193  // Allow the use in __kmpc_parallel_51 calls.
2194  if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
2195  return getUniqueKernelFor(*CB);
2196  return nullptr;
2197  }
2198  // Disallow every other use.
2199  return nullptr;
2200  };
2201 
2202  // TODO: In the future we want to track more than just a unique kernel.
2203  SmallPtrSet<Kernel, 2> PotentialKernels;
2204  OMPInformationCache::foreachUse(F, [&](const Use &U) {
2205  PotentialKernels.insert(GetUniqueKernelForUse(U));
2206  });
2207 
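  // A single kernel in the set means every use is traceable to that kernel;
  // a nullptr entry (an untraceable use) or multiple kernels means there is
  // no unique kernel executing F.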
2208  Kernel K = nullptr;
2209  if (PotentialKernels.size() == 1)
2210  K = *PotentialKernels.begin();
2211 
2212  // Cache the result.
2213  UniqueKernelMap[&F] = K;
2214 
2215  return K;
2216 }
2217 
2218 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
2219  OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
2220  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
2221 
2222  bool Changed = false;
2223  if (!KernelParallelRFI)
2224  return Changed;
2225 
2226  // If we have disabled state machine changes, exit early.
2227  if (DisableOpenMPOptStateMachineRewrite)
2228  return Changed;
2229 
2230  for (Function *F : SCC) {
2231 
2232  // Check if the function is used in a __kmpc_parallel_51 call at
2233  // all.
2234  bool UnknownUse = false;
2235  bool KernelParallelUse = false;
2236  unsigned NumDirectCalls = 0;
2237 
2238  SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
2239  OMPInformationCache::foreachUse(*F, [&](Use &U) {
2240  if (auto *CB = dyn_cast<CallBase>(U.getUser()))
2241  if (CB->isCallee(&U)) {
2242  ++NumDirectCalls;
2243  return;
2244  }
2245 
2246  if (isa<ICmpInst>(U.getUser())) {
2247  ToBeReplacedStateMachineUses.push_back(&U);
2248  return;
2249  }
2250 
2251  // Find wrapper functions that represent parallel kernels.
2252  CallInst *CI =
2253  OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
2254  const unsigned int WrapperFunctionArgNo = 6;
2255  if (!KernelParallelUse && CI &&
2256  CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
2257  KernelParallelUse = true;
2258  ToBeReplacedStateMachineUses.push_back(&U);
2259  return;
2260  }
2261  UnknownUse = true;
2262  });
2263 
2264  // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2265  // use.
2266  if (!KernelParallelUse)
2267  continue;
2268 
2269  // If this ever hits, we should investigate.
2270  // TODO: Checking the number of uses is not a necessary restriction and
2271  // should be lifted.
2272  if (UnknownUse || NumDirectCalls != 1 ||
2273  ToBeReplacedStateMachineUses.size() > 2) {
2274  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2275  return ORA << "Parallel region is used in "
2276  << (UnknownUse ? "unknown" : "unexpected")
2277  << " ways. Will not attempt to rewrite the state machine.";
2278  };
2279  emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2280  continue;
2281  }
2282 
2283  // Even if we have __kmpc_parallel_51 calls, we (for now) give
2284  // up if the function is not called from a unique kernel.
2285  Kernel K = getUniqueKernelFor(*F);
2286  if (!K) {
2287  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2288  return ORA << "Parallel region is not called from a unique kernel. "
2289  "Will not attempt to rewrite the state machine.";
2290  };
2291  emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2292  continue;
2293  }
2294 
2295  // We now know F is a parallel body function called only from the kernel K.
2296  // We also identified the state machine uses in which we replace the
2297  // function pointer by a new global symbol for identification purposes. This
2298  // ensures only direct calls to the function are left.
2299 
2300  Module &M = *F->getParent();
2301  Type *Int8Ty = Type::getInt8Ty(M.getContext());
2302 
2303  auto *ID = new GlobalVariable(
2304  M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2305  UndefValue::get(Int8Ty), F->getName() + ".ID");
2306 
2307  for (Use *U : ToBeReplacedStateMachineUses)
2308  U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2309  ID, U->get()->getType()));
2310 
2311  ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2312 
2313  Changed = true;
2314  }
2315 
2316  return Changed;
2317 }
2318 
2319 /// Abstract Attribute for tracking ICV values.
2320 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2321  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2322  AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2323 
2324  void initialize(Attributor &A) override {
2325  Function *F = getAnchorScope();
2326  if (!F || !A.isFunctionIPOAmendable(*F))
2327  indicatePessimisticFixpoint();
2328  }
2329 
2330  /// Returns true if value is assumed to be tracked.
2331  bool isAssumedTracked() const { return getAssumed(); }
2332 
2333  /// Returns true if value is known to be tracked.
2334  bool isKnownTracked() const { return getAssumed(); }
2335 
2336  /// Create an abstract attribute view for the position \p IRP.
2337  static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2338 
2339  /// Return the value with which \p I can be replaced for specific \p ICV.
2340  virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2341  const Instruction *I,
2342  Attributor &A) const {
2343  return None;
2344  }
2345 
2346  /// Return an assumed unique ICV value if a single candidate is found. If
2347  /// there cannot be one, return a nullptr. If it is not clear yet, return the
2348  /// Optional::NoneType.
2349  virtual Optional<Value *>
2350  getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2351 
2352  // Currently only nthreads is being tracked.
2353  // This array will only grow with time.
2354  InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2355 
2356  /// See AbstractAttribute::getName()
2357  const std::string getName() const override { return "AAICVTracker"; }
2358 
2359  /// See AbstractAttribute::getIdAddr()
2360  const char *getIdAddr() const override { return &ID; }
2361 
2362  /// This function should return true if the type of the \p AA is AAICVTracker
2363  static bool classof(const AbstractAttribute *AA) {
2364  return (AA->getIdAddr() == &ID);
2365  }
2366 
2367  static const char ID;
2368 };
2369 
2370 struct AAICVTrackerFunction : public AAICVTracker {
2371  AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2372  : AAICVTracker(IRP, A) {}
2373 
2374  // FIXME: come up with better string.
2375  const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2376 
2377  // FIXME: come up with some stats.
2378  void trackStatistics() const override {}
2379 
2380  /// We don't manifest anything for this AA.
2381  ChangeStatus manifest(Attributor &A) override {
2382  return ChangeStatus::UNCHANGED;
2383  }
2384 
2385  // Map of ICV to their values at specific program point.
2387  InternalControlVar::ICV___last>
2388  ICVReplacementValuesMap;
2389 
2390  ChangeStatus updateImpl(Attributor &A) override {
2391  ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2392 
2393  Function *F = getAnchorScope();
2394 
2395  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2396 
2397  for (InternalControlVar ICV : TrackableICVs) {
2398  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2399 
2400  auto &ValuesMap = ICVReplacementValuesMap[ICV];
2401  auto TrackValues = [&](Use &U, Function &) {
2402  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2403  if (!CI)
2404  return false;
2405 
2406  // FIXME: handle setters with more than one argument.
2407  /// Track new value.
2408  if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2409  HasChanged = ChangeStatus::CHANGED;
2410 
2411  return false;
2412  };
2413 
2414  auto CallCheck = [&](Instruction &I) {
2415  Optional<Value *> ReplVal = getValueForCall(A, I, ICV);
2416  if (ReplVal.hasValue() &&
2417  ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2418  HasChanged = ChangeStatus::CHANGED;
2419 
2420  return true;
2421  };
2422 
2423  // Track all changes of an ICV.
2424  SetterRFI.foreachUse(TrackValues, F);
2425 
2426  bool UsedAssumedInformation = false;
2427  A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2428  UsedAssumedInformation,
2429  /* CheckBBLivenessOnly */ true);
2430 
2431  /// TODO: Figure out a way to avoid adding entry in
2432  /// ICVReplacementValuesMap
2433  Instruction *Entry = &F->getEntryBlock().front();
2434  if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2435  ValuesMap.insert(std::make_pair(Entry, nullptr));
2436  }
2437 
2438  return HasChanged;
2439  }
2440 
2441  /// Helper to check if \p I is a call and get the value for it if it is
2442  /// unique.
2443  Optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
2444  InternalControlVar &ICV) const {
2445 
2446  const auto *CB = dyn_cast<CallBase>(&I);
2447  if (!CB || CB->hasFnAttr("no_openmp") ||
2448  CB->hasFnAttr("no_openmp_routines"))
2449  return None;
2450 
2451  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2452  auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2453  auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2454  Function *CalledFunction = CB->getCalledFunction();
2455 
2456  // Indirect call, assume ICV changes.
2457  if (CalledFunction == nullptr)
2458  return nullptr;
2459  if (CalledFunction == GetterRFI.Declaration)
2460  return None;
2461  if (CalledFunction == SetterRFI.Declaration) {
2462  if (ICVReplacementValuesMap[ICV].count(&I))
2463  return ICVReplacementValuesMap[ICV].lookup(&I);
2464 
2465  return nullptr;
2466  }
2467 
2468  // Since we don't know, assume it changes the ICV.
2469  if (CalledFunction->isDeclaration())
2470  return nullptr;
2471 
2472  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2473  *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2474 
2475  if (ICVTrackingAA.isAssumedTracked()) {
2476  Optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
2477  if (!URV || (*URV && AA::isValidAtPosition(**URV, I, OMPInfoCache)))
2478  return URV;
2479  }
2480 
2481  // If we don't know, assume it changes.
2482  return nullptr;
2483  }
2484 
2485  // We don't check for a unique value for a function, so return None.
2486  Optional<Value *>
2487  getUniqueReplacementValue(InternalControlVar ICV) const override {
2488  return None;
2489  }
2490 
2491  /// Return the value with which \p I can be replaced for specific \p ICV.
2492  Optional<Value *> getReplacementValue(InternalControlVar ICV,
2493  const Instruction *I,
2494  Attributor &A) const override {
2495  const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2496  if (ValuesMap.count(I))
2497  return ValuesMap.lookup(I);
2498 
2501  Worklist.push_back(I);
2502 
2503  Optional<Value *> ReplVal;
2504 
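  // Walk backwards from \p I through the CFG (following the predecessors'
  // terminators) and look for the closest instruction known to determine the
  // ICV value; conflicting values along different paths yield nullptr.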
2505  while (!Worklist.empty()) {
2506  const Instruction *CurrInst = Worklist.pop_back_val();
2507  if (!Visited.insert(CurrInst).second)
2508  continue;
2509 
2510  const BasicBlock *CurrBB = CurrInst->getParent();
2511 
2512  // Go up and look for all potential setters/calls that might change the
2513  // ICV.
2514  while ((CurrInst = CurrInst->getPrevNode())) {
2515  if (ValuesMap.count(CurrInst)) {
2516  Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2517  // Unknown value, track new.
2518  if (!ReplVal.hasValue()) {
2519  ReplVal = NewReplVal;
2520  break;
2521  }
2522 
2523  // If we found a new value, we can't know the ICV value anymore.
2524  if (NewReplVal.hasValue())
2525  if (ReplVal != NewReplVal)
2526  return nullptr;
2527 
2528  break;
2529  }
2530 
2531  Optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
2532  if (!NewReplVal.hasValue())
2533  continue;
2534 
2535  // Unknown value, track new.
2536  if (!ReplVal.hasValue()) {
2537  ReplVal = NewReplVal;
2538  break;
2539  }
2540 
2541  // NewReplVal is known to have a value at this point; if it differs from
2542  // the previously tracked one, we can't know the ICV value anymore.
2543  if (ReplVal != NewReplVal)
2544  return nullptr;
2545  }
2546 
2547  // If we are in the same BB and we have a value, we are done.
2548  if (CurrBB == I->getParent() && ReplVal.hasValue())
2549  return ReplVal;
2550 
2551  // Go through all predecessors and add terminators for analysis.
2552  for (const BasicBlock *Pred : predecessors(CurrBB))
2553  if (const Instruction *Terminator = Pred->getTerminator())
2554  Worklist.push_back(Terminator);
2555  }
2556 
2557  return ReplVal;
2558  }
2559 };
2560 
2561 struct AAICVTrackerFunctionReturned : AAICVTracker {
2562  AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2563  : AAICVTracker(IRP, A) {}
2564 
2565  // FIXME: come up with better string.
2566  const std::string getAsStr() const override {
2567  return "ICVTrackerFunctionReturned";
2568  }
2569 
2570  // FIXME: come up with some stats.
2571  void trackStatistics() const override {}
2572 
2573  /// We don't manifest anything for this AA.
2574  ChangeStatus manifest(Attributor &A) override {
2575  return ChangeStatus::UNCHANGED;
2576  }
2577 
2578  // Map of ICV to their values at specific program point.
2580  InternalControlVar::ICV___last>
2581  ICVReplacementValuesMap;
2582 
2583  /// Return the assumed unique ICV value at function return for \p ICV.
2584  Optional<Value *>
2585  getUniqueReplacementValue(InternalControlVar ICV) const override {
2586  return ICVReplacementValuesMap[ICV];
2587  }
2588 
2589  ChangeStatus updateImpl(Attributor &A) override {
2590  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2591  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2592  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2593 
2594  if (!ICVTrackingAA.isAssumedTracked())
2595  return indicatePessimisticFixpoint();
2596 
2597  for (InternalControlVar ICV : TrackableICVs) {
2598  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2599  Optional<Value *> UniqueICVValue;
2600 
2601  auto CheckReturnInst = [&](Instruction &I) {
2602  Optional<Value *> NewReplVal =
2603  ICVTrackingAA.getReplacementValue(ICV, &I, A);
2604 
2605  // If we found a second ICV value there is no unique returned value.
2606  if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2607  return false;
2608 
2609  UniqueICVValue = NewReplVal;
2610 
2611  return true;
2612  };
2613 
2614  bool UsedAssumedInformation = false;
2615  if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2616  UsedAssumedInformation,
2617  /* CheckBBLivenessOnly */ true))
2618  UniqueICVValue = nullptr;
2619 
2620  if (UniqueICVValue == ReplVal)
2621  continue;
2622 
2623  ReplVal = UniqueICVValue;
2624  Changed = ChangeStatus::CHANGED;
2625  }
2626 
2627  return Changed;
2628  }
2629 };
2630 
2631 struct AAICVTrackerCallSite : AAICVTracker {
2632  AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2633  : AAICVTracker(IRP, A) {}
2634 
2635  void initialize(Attributor &A) override {
2636  Function *F = getAnchorScope();
2637  if (!F || !A.isFunctionIPOAmendable(*F))
2638  indicatePessimisticFixpoint();
2639 
2640  // We only initialize this AA for getters, so we need to know which ICV it
2641  // gets.
2642  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2643  for (InternalControlVar ICV : TrackableICVs) {
2644  auto ICVInfo = OMPInfoCache.ICVs[ICV];
2645  auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2646  if (Getter.Declaration == getAssociatedFunction()) {
2647  AssociatedICV = ICVInfo.Kind;
2648  return;
2649  }
2650  }
2651 
2652  /// Unknown ICV.
2653  indicatePessimisticFixpoint();
2654  }
2655 
2656  ChangeStatus manifest(Attributor &A) override {
2657  if (!ReplVal.hasValue() || !ReplVal.getValue())
2658  return ChangeStatus::UNCHANGED;
2659 
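  // A unique ICV value is known for this getter call: forward it to all users
  // of the call and erase the call.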
2660  A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2661  A.deleteAfterManifest(*getCtxI());
2662 
2663  return ChangeStatus::CHANGED;
2664  }
2665 
2666  // FIXME: come up with better string.
2667  const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2668 
2669  // FIXME: come up with some stats.
2670  void trackStatistics() const override {}
2671 
2672  InternalControlVar AssociatedICV;
2673  Optional<Value *> ReplVal;
2674 
2675  ChangeStatus updateImpl(Attributor &A) override {
2676  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2677  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2678 
2679  // We don't have any information, so we assume it changes the ICV.
2680  if (!ICVTrackingAA.isAssumedTracked())
2681  return indicatePessimisticFixpoint();
2682 
2683  Optional<Value *> NewReplVal =
2684  ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2685 
2686  if (ReplVal == NewReplVal)
2687  return ChangeStatus::UNCHANGED;
2688 
2689  ReplVal = NewReplVal;
2690  return ChangeStatus::CHANGED;
2691  }
2692 
2693  // Return the value with which associated value can be replaced for specific
2694  // \p ICV.
2695  Optional<Value *>
2696  getUniqueReplacementValue(InternalControlVar ICV) const override {
2697  return ReplVal;
2698  }
2699 };
2700 
2701 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2702  AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2703  : AAICVTracker(IRP, A) {}
2704 
2705  // FIXME: come up with better string.
2706  const std::string getAsStr() const override {
2707  return "ICVTrackerCallSiteReturned";
2708  }
2709 
2710  // FIXME: come up with some stats.
2711  void trackStatistics() const override {}
2712 
2713  /// We don't manifest anything for this AA.
2714  ChangeStatus manifest(Attributor &A) override {
2715  return ChangeStatus::UNCHANGED;
2716  }
2717 
2718  // Map of ICV to their values at specific program point.
2720  InternalControlVar::ICV___last>
2721  ICVReplacementValuesMap;
2722 
2723  /// Return the value with which associated value can be replaced for specific
2724  /// \p ICV.
2725  Optional<Value *>
2726  getUniqueReplacementValue(InternalControlVar ICV) const override {
2727  return ICVReplacementValuesMap[ICV];
2728  }
2729 
2730  ChangeStatus updateImpl(Attributor &A) override {
2731  ChangeStatus Changed = ChangeStatus::UNCHANGED;
2732  const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2733  *this, IRPosition::returned(*getAssociatedFunction()),
2734  DepClassTy::REQUIRED);
2735 
2736  // We don't have any information, so we assume it changes the ICV.
2737  if (!ICVTrackingAA.isAssumedTracked())
2738  return indicatePessimisticFixpoint();
2739 
2740  for (InternalControlVar ICV : TrackableICVs) {
2741  Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2742  Optional<Value *> NewReplVal =
2743  ICVTrackingAA.getUniqueReplacementValue(ICV);
2744 
2745  if (ReplVal == NewReplVal)
2746  continue;
2747 
2748  ReplVal = NewReplVal;
2749  Changed = ChangeStatus::CHANGED;
2750  }
2751  return Changed;
2752  }
2753 };
2754 
2755 struct AAExecutionDomainFunction : public AAExecutionDomain {
2756  AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2757  : AAExecutionDomain(IRP, A) {}
2758 
2759  const std::string getAsStr() const override {
2760  return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2761  "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2762  }
2763 
2764  /// See AbstractAttribute::trackStatistics().
2765  void trackStatistics() const override {}
2766 
2767  void initialize(Attributor &A) override {
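  // Start optimistically: assume every block is executed only by the initial
  // thread; updateImpl removes blocks that are proven to be reached by others.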
2768  Function *F = getAnchorScope();
2769  for (const auto &BB : *F)
2770  SingleThreadedBBs.insert(&BB);
2771  NumBBs = SingleThreadedBBs.size();
2772  }
2773 
2774  ChangeStatus manifest(Attributor &A) override {
2775  LLVM_DEBUG({
2776  for (const BasicBlock *BB : SingleThreadedBBs)
2777  dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2778  << BB->getName() << " is executed by a single thread.\n";
2779  });
2780  return ChangeStatus::UNCHANGED;
2781  }
2782 
2783  ChangeStatus updateImpl(Attributor &A) override;
2784 
2785  /// Check if an instruction is executed by a single thread.
2786  bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2787  return isExecutedByInitialThreadOnly(*I.getParent());
2788  }
2789 
2790  bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2791  return isValidState() && SingleThreadedBBs.contains(&BB);
2792  }
2793 
2794  /// Set of basic blocks that are executed by a single thread.
2795  SmallSetVector<const BasicBlock *, 16> SingleThreadedBBs;
2796 
2797  /// Total number of basic blocks in this function.
2798  long unsigned NumBBs = 0;
2799 };
2800 
2801 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2802  Function *F = getAnchorScope();
2803  ReversePostOrderTraversal<Function *> RPOT(F);
2804  auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2805 
2806  bool AllCallSitesKnown;
2807  auto PredForCallSite = [&](AbstractCallSite ACS) {
2808  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2809  *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2810  DepClassTy::REQUIRED);
2811  return ACS.isDirectCall() &&
2812  ExecutionDomainAA.isExecutedByInitialThreadOnly(
2813  *ACS.getInstruction());
2814  };
2815 
2816  if (!A.checkForAllCallSites(PredForCallSite, *this,
2817  /* RequiresAllCallSites */ true,
2818  AllCallSitesKnown))
2819  SingleThreadedBBs.remove(&F->getEntryBlock());
2820 
2821  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2822  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2823 
2824  // Check if the edge into the successor block contains a condition that only
2825  // lets the main thread execute it.
2826  auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2827  if (!Edge || !Edge->isConditional())
2828  return false;
2829  if (Edge->getSuccessor(0) != SuccessorBB)
2830  return false;
2831 
2832  auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2833  if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2834  return false;
2835 
2836  ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2837  if (!C)
2838  return false;
2839 
2840  // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2841  if (C->isAllOnesValue()) {
2842  auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2843  CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2844  if (!CB)
2845  return false;
2846  const int InitModeArgNo = 1;
2847  auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2848  return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2849  }
2850 
2851  if (C->isZero()) {
2852  // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2853  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2854  if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2855  return true;
2856 
2857  // Match: 0 == llvm.amdgcn.workitem.id.x()
2858  if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2859  if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2860  return true;
2861  }
2862 
2863  return false;
2864  };
2865 
2866  // Merge all the predecessor states into the current basic block. A basic
2867  // block is executed by a single thread if all of its predecessors are.
2868  auto MergePredecessorStates = [&](BasicBlock *BB) {
2869  if (pred_empty(BB))
2870  return SingleThreadedBBs.contains(BB);
2871 
2872  bool IsInitialThread = true;
2873  for (BasicBlock *PredBB : predecessors(BB)) {
2874  if (!IsInitialThreadOnly(dyn_cast<BranchInst>(PredBB->getTerminator()),
2875  BB))
2876  IsInitialThread &= SingleThreadedBBs.contains(PredBB);
2877  }
2878 
2879  return IsInitialThread;
2880  };
2881 
2882  for (auto *BB : RPOT) {
2883  if (!MergePredecessorStates(BB))
2884  SingleThreadedBBs.remove(BB);
2885  }
2886 
2887  return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2888  ? ChangeStatus::UNCHANGED
2889  : ChangeStatus::CHANGED;
2890 }
2891 
2892 /// Try to replace memory allocation calls called by a single thread with a
2893 /// static buffer of shared memory.
2894 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2895  using Base = StateWrapper<BooleanState, AbstractAttribute>;
2896  AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2897 
2898  /// Create an abstract attribute view for the position \p IRP.
2899  static AAHeapToShared &createForPosition(const IRPosition &IRP,
2900  Attributor &A);
2901 
2902  /// Returns true if HeapToShared conversion is assumed to be possible.
2903  virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2904 
2905  /// Returns true if HeapToShared conversion is assumed and the CB is a
2906  /// callsite to a free operation to be removed.
2907  virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2908 
2909  /// See AbstractAttribute::getName().
2910  const std::string getName() const override { return "AAHeapToShared"; }
2911 
2912  /// See AbstractAttribute::getIdAddr().
2913  const char *getIdAddr() const override { return &ID; }
2914 
2915  /// This function should return true if the type of the \p AA is
2916  /// AAHeapToShared.
2917  static bool classof(const AbstractAttribute *AA) {
2918  return (AA->getIdAddr() == &ID);
2919  }
2920 
2921  /// Unique ID (due to the unique address)
2922  static const char ID;
2923 };
2924 
2925 struct AAHeapToSharedFunction : public AAHeapToShared {
2926  AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2927  : AAHeapToShared(IRP, A) {}
2928 
2929  const std::string getAsStr() const override {
2930  return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2931  " malloc calls eligible.";
2932  }
2933 
2934  /// See AbstractAttribute::trackStatistics().
2935  void trackStatistics() const override {}
2936 
2937  /// This function finds free calls that will be removed by the
2938  /// HeapToShared transformation.
2939  void findPotentialRemovedFreeCalls(Attributor &A) {
2940  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2941  auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2942 
2943  PotentialRemovedFreeCalls.clear();
2944  // Update free call users of found malloc calls.
2945  for (CallBase *CB : MallocCalls) {
2946  SmallVector<CallBase *, 4> FreeCalls;
2947  for (auto *U : CB->users()) {
2948  CallBase *C = dyn_cast<CallBase>(U);
2949  if (C && C->getCalledFunction() == FreeRFI.Declaration)
2950  FreeCalls.push_back(C);
2951  }
2952 
2953  if (FreeCalls.size() != 1)
2954  continue;
2955 
2956  PotentialRemovedFreeCalls.insert(FreeCalls.front());
2957  }
2958  }
2959 
2960  void initialize(Attributor &A) override {
2961  if (DisableOpenMPOptDeglobalization) {
2962  indicatePessimisticFixpoint();
2963  return;
2964  }
2965 
2966  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2967  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2968 
2969  Attributor::SimplifictionCallbackTy SCB =
2970  [](const IRPosition &, const AbstractAttribute *,
2971  bool &) -> Optional<Value *> { return nullptr; };
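  // The callback unconditionally answers "do not change the value in the IR"
  // (nullptr), which keeps the Attributor from simplifying these call sites
  // while HeapToShared is still deciding whether to replace them.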
2972  for (User *U : RFI.Declaration->users())
2973  if (CallBase *CB = dyn_cast<CallBase>(U)) {
2974  MallocCalls.insert(CB);
2975  A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
2976  SCB);
2977  }
2978 
2979  findPotentialRemovedFreeCalls(A);
2980  }
2981 
2982  bool isAssumedHeapToShared(CallBase &CB) const override {
2983  return isValidState() && MallocCalls.count(&CB);
2984  }
2985 
2986  bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2987  return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2988  }
2989 
2990  ChangeStatus manifest(Attributor &A) override {
2991  if (MallocCalls.empty())
2992  return ChangeStatus::UNCHANGED;
2993 
2994  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2995  auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2996 
2997  Function *F = getAnchorScope();
2998  auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2999  DepClassTy::OPTIONAL);
3000 
3001  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3002  for (CallBase *CB : MallocCalls) {
3003  // Skip replacing this if HeapToStack has already claimed it.
3004  if (HS && HS->isAssumedHeapToStack(*CB))
3005  continue;
3006 
3007  // Find the unique free call to remove it.
3008  SmallVector<CallBase *, 4> FreeCalls;
3009  for (auto *U : CB->users()) {
3010  CallBase *C = dyn_cast<CallBase>(U);
3011  if (C && C->getCalledFunction() == FreeCall.Declaration)
3012  FreeCalls.push_back(C);
3013  }
3014  if (FreeCalls.size() != 1)
3015  continue;
3016 
3017  auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
3018 
3019  if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
3020  LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
3021  << " with shared memory."
3022  << " Shared memory usage is limited to "
3023  << SharedMemoryLimit << " bytes\n");
3024  continue;
3025  }
3026 
3027  LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
3028  << " with " << AllocSize->getZExtValue()
3029  << " bytes of shared memory\n");
3030 
3031  // Create a new shared memory buffer of the same size as the allocation
3032  // and replace all the uses of the original allocation with it.
3033  Module *M = CB->getModule();
3034  Type *Int8Ty = Type::getInt8Ty(M->getContext());
3035  Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
3036  auto *SharedMem = new GlobalVariable(
3037  *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
3038  UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
3039  GlobalValue::NotThreadLocal,
3040  static_cast<unsigned>(AddressSpace::Shared));
3041  auto *NewBuffer =
3042  ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
3043 
3044  auto Remark = [&](OptimizationRemark OR) {
3045  return OR << "Replaced globalized variable with "
3046  << ore::NV("SharedMemory", AllocSize->getZExtValue())
3047  << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
3048  << "of shared memory.";
3049  };
3050  A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
3051 
3052  MaybeAlign Alignment = CB->getRetAlign();
3053  assert(Alignment &&
3054  "HeapToShared on allocation without alignment attribute");
3055  SharedMem->setAlignment(MaybeAlign(Alignment));
3056 
3057  A.changeValueAfterManifest(*CB, *NewBuffer);
3058  A.deleteAfterManifest(*CB);
3059  A.deleteAfterManifest(*FreeCalls.front());
3060 
3061  SharedMemoryUsed += AllocSize->getZExtValue();
3062  NumBytesMovedToSharedMemory = SharedMemoryUsed;
3063  Changed = ChangeStatus::CHANGED;
3064  }
3065 
3066  return Changed;
3067  }
3068 
3069  ChangeStatus updateImpl(Attributor &A) override {
3070  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3071  auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3072  Function *F = getAnchorScope();
3073 
3074  auto NumMallocCalls = MallocCalls.size();
3075 
3076  // Only consider malloc calls executed by a single thread and with a constant allocation size.
3077  for (User *U : RFI.Declaration->users()) {
3078  const auto &ED = A.getAAFor<AAExecutionDomain>(
3079  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
3080  if (CallBase *CB = dyn_cast<CallBase>(U))
3081  if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
3082  !ED.isExecutedByInitialThreadOnly(*CB))
3083  MallocCalls.remove(CB);
3084  }
3085 
3086  findPotentialRemovedFreeCalls(A);
3087 
3088  if (NumMallocCalls != MallocCalls.size())
3089  return ChangeStatus::CHANGED;
3090 
3091  return ChangeStatus::UNCHANGED;
3092  }
3093 
3094  /// Collection of all malloc calls in a function.
3095  SmallSetVector<CallBase *, 4> MallocCalls;
3096  /// Collection of potentially removed free calls in a function.
3097  SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
3098  /// The total amount of shared memory that has been used for HeapToShared.
3099  unsigned SharedMemoryUsed = 0;
3100 };
3101 
3102 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
3103  using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
3104  AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3105 
3106  /// Statistics are tracked as part of manifest for now.
3107  void trackStatistics() const override {}
3108 
3109  /// See AbstractAttribute::getAsStr()
3110  const std::string getAsStr() const override {
3111  if (!isValidState())
3112  return "<invalid>";
3113  return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
3114  : "generic") +
3115  std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
3116  : "") +
3117  std::string(" #PRs: ") +
3118  (ReachedKnownParallelRegions.isValidState()
3119  ? std::to_string(ReachedKnownParallelRegions.size())
3120  : "<invalid>") +
3121  ", #Unknown PRs: " +
3122  (ReachedUnknownParallelRegions.isValidState()
3123  ? std::to_string(ReachedUnknownParallelRegions.size())
3124  : "<invalid>") +
3125  ", #Reaching Kernels: " +
3126  (ReachingKernelEntries.isValidState()
3127  ? std::to_string(ReachingKernelEntries.size())
3128  : "<invalid>");
3129  }
3130 
3131  /// Create an abstract attribute view for the position \p IRP.
3132  static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
3133 
3134  /// See AbstractAttribute::getName()
3135  const std::string getName() const override { return "AAKernelInfo"; }
3136 
3137  /// See AbstractAttribute::getIdAddr()
3138  const char *getIdAddr() const override { return &ID; }
3139 
3140  /// This function should return true if the type of the \p AA is AAKernelInfo
3141  static bool classof(const AbstractAttribute *AA) {
3142  return (AA->getIdAddr() == &ID);
3143  }
3144 
3145  static const char ID;
3146 };
3147 
3148 /// The function kernel info abstract attribute, basically, what can we say
3149 /// about a function with regards to the KernelInfoState.
3150 struct AAKernelInfoFunction : AAKernelInfo {
3151  AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
3152  : AAKernelInfo(IRP, A) {}
3153 
3154  SmallPtrSet<Instruction *, 4> GuardedInstructions;
3155 
3156  SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
3157  return GuardedInstructions;
3158  }
3159 
3160  /// See AbstractAttribute::initialize(...).
3161  void initialize(Attributor &A) override {
3162  // This is a high-level transform that might change the constant arguments
3163  // of the init and deinit calls. We need to tell the Attributor about this
3164  // to avoid other parts using the current constant value for simplification.
3165  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3166 
3167  Function *Fn = getAnchorScope();
3168 
3169  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
3170  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
3171  OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
3172  OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
3173 
3174  // For kernels we perform more initialization work, first we find the init
3175  // and deinit calls.
3176  auto StoreCallBase = [](Use &U,
3177  OMPInformationCache::RuntimeFunctionInfo &RFI,
3178  CallBase *&Storage) {
3179  CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
3180  assert(CB &&
3181  "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
3182  assert(!Storage &&
3183  "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
3184  Storage = CB;
3185  return false;
3186  };
3187  InitRFI.foreachUse(
3188  [&](Use &U, Function &) {
3189  StoreCallBase(U, InitRFI, KernelInitCB);
3190  return false;
3191  },
3192  Fn);
3193  DeinitRFI.foreachUse(
3194  [&](Use &U, Function &) {
3195  StoreCallBase(U, DeinitRFI, KernelDeinitCB);
3196  return false;
3197  },
3198  Fn);
3199 
3200  // Ignore kernels without initializers such as global constructors.
3201  if (!KernelInitCB || !KernelDeinitCB)
3202  return;
3203 
3204  // Add itself to the reaching kernel and set IsKernelEntry.
3205  ReachingKernelEntries.insert(Fn);
3206  IsKernelEntry = true;
3207 
3208  // For kernels we might need to initialize/finalize the IsSPMD state and
3209  // we need to register a simplification callback so that the Attributor
3210  // knows the constant arguments to __kmpc_target_init and
3211  // __kmpc_target_deinit might actually change.
3212 
3213  Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
3214  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3215  bool &UsedAssumedInformation) -> Optional<Value *> {
3216  // IRP represents the "use generic state machine" argument of an
3217  // __kmpc_target_init call. We will answer this one with the internal
3218  // state. As long as we are not in an invalid state, we will create a
3219  // custom state machine so the value should be a `i1 false`. If we are
3220  // in an invalid state, we won't change the value that is in the IR.
3221  if (!ReachedKnownParallelRegions.isValidState())
3222  return nullptr;
3223  // If we have disabled state machine rewrites, don't make a custom one.
3224  if (DisableOpenMPOptStateMachineRewrite)
3225  return nullptr;
3226  if (AA)
3227  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3228  UsedAssumedInformation = !isAtFixpoint();
3229  auto *FalseVal =
3230  ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
3231  return FalseVal;
3232  };
3233 
3234  Attributor::SimplifictionCallbackTy ModeSimplifyCB =
3235  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3236  bool &UsedAssumedInformation) -> Optional<Value *> {
3237  // IRP represents the "SPMDCompatibilityTracker" argument of an
3238  // __kmpc_target_init or
3239  // __kmpc_target_deinit call. We will answer this one with the internal
3240  // state.
3241  if (!SPMDCompatibilityTracker.isValidState())
3242  return nullptr;
3243  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3244  if (AA)
3245  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3246  UsedAssumedInformation = true;
3247  } else {
3248  UsedAssumedInformation = false;
3249  }
3250  auto *Val = ConstantInt::getSigned(
3251  IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
3252  SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
3253  : OMP_TGT_EXEC_MODE_GENERIC);
3254  return Val;
3255  };
3256 
3257  Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
3258  [&](const IRPosition &IRP, const AbstractAttribute *AA,
3259  bool &UsedAssumedInformation) -> Optional<Value *> {
3260  // IRP represents the "RequiresFullRuntime" argument of an
3261  // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
3262  // one with the internal state of the SPMDCompatibilityTracker, so if
3263  // generic then true, if SPMD then false.
3264  if (!SPMDCompatibilityTracker.isValidState())
3265  return nullptr;
3266  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3267  if (AA)
3268  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3269  UsedAssumedInformation = true;
3270  } else {
3271  UsedAssumedInformation = false;
3272  }
3273  auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
3274  !SPMDCompatibilityTracker.isAssumed());
3275  return Val;
3276  };
3277 
3278  constexpr const int InitModeArgNo = 1;
3279  constexpr const int DeinitModeArgNo = 1;
3280  constexpr const int InitUseStateMachineArgNo = 2;
3281  constexpr const int InitRequiresFullRuntimeArgNo = 3;
3282  constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
3283  A.registerSimplificationCallback(
3284  IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
3285  StateMachineSimplifyCB);
3286  A.registerSimplificationCallback(
3287  IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
3288  ModeSimplifyCB);
3289  A.registerSimplificationCallback(
3290  IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3291  ModeSimplifyCB);
3292  A.registerSimplificationCallback(
3293  IRPosition::callsite_argument(*KernelInitCB,
3294  InitRequiresFullRuntimeArgNo),
3295  IsGenericModeSimplifyCB);
3296  A.registerSimplificationCallback(
3297  IRPosition::callsite_argument(*KernelDeinitCB,
3298  DeinitRequiresFullRuntimeArgNo),
3299  IsGenericModeSimplifyCB);
3300 
3301  // Check if we know we are in SPMD-mode already.
3302  ConstantInt *ModeArg =
3303  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3304  if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3305  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3306  // This is a generic region but SPMDization is disabled so stop tracking.
3307  else if (DisableOpenMPOptSPMDization)
3308  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3309  }
3310 
3311  /// Sanitize the string \p S such that it is a suitable global symbol name.
3312  static std::string sanitizeForGlobalName(std::string S) {
3313  std::replace_if(
3314  S.begin(), S.end(),
3315  [](const char C) {
3316  return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||
3317  (C >= '0' && C <= '9') || C == '_');
3318  },
3319  '.');
3320  return S;
3321  }
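 // A minimal illustration (not part of the original pass logic): every
 // character outside [A-Za-z0-9_] is replaced by '.', so a value named
 // "x y-z" gives sanitizeForGlobalName("x y-z") == "x.y.z" when forming the
 // shared-memory global names used below.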
3322 
3323  /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3324  /// finished now.
3325  ChangeStatus manifest(Attributor &A) override {
3326  // If we are not looking at a kernel with __kmpc_target_init and
3327  // __kmpc_target_deinit call we cannot actually manifest the information.
3328  if (!KernelInitCB || !KernelDeinitCB)
3329  return ChangeStatus::UNCHANGED;
3330 
3331  // If we can, we change the execution mode to SPMD-mode; otherwise we build
3332  // a custom state machine.
3333  ChangeStatus Changed = ChangeStatus::UNCHANGED;
3334  if (!changeToSPMDMode(A, Changed))
3335  return buildCustomStateMachine(A);
3336 
3337  return Changed;
3338  }
3339 
3340  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
3341  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3342 
3343  if (!SPMDCompatibilityTracker.isAssumed()) {
3344  for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3345  if (!NonCompatibleI)
3346  continue;
3347 
3348  // Skip diagnostics on calls to known OpenMP runtime functions for now.
3349  if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3350  if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3351  continue;
3352 
3353  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3354  ORA << "Value has potential side effects preventing SPMD-mode "
3355  "execution";
3356  if (isa<CallBase>(NonCompatibleI)) {
3357  ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3358  "the called function to override";
3359  }
3360  return ORA << ".";
3361  };
3362  A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3363  Remark);
3364 
3365  LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3366  << *NonCompatibleI << "\n");
3367  }
3368 
3369  return false;
3370  }
3371 
3372  // Get the actual kernel, could be the caller of the anchor scope if we have
3373  // a debug wrapper.
3374  Function *Kernel = getAnchorScope();
3375  if (Kernel->hasLocalLinkage()) {
3376  assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
3377  auto *CB = cast<CallBase>(Kernel->user_back());
3378  Kernel = CB->getCaller();
3379  }
3380  assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
3381 
3382  // Check if the kernel is already in SPMD mode, if so, return success.
3383  GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3384  (Kernel->getName() + "_exec_mode").str());
3385  assert(ExecMode && "Kernel without exec mode?");
3386  assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3387 
3388  // Set the global exec mode flag to indicate SPMD-Generic mode.
3389  assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3390  "ExecMode is not an integer!");
3391  const int8_t ExecModeVal =
3392  cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3393  if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
3394  return true;
3395 
3396  // We will now unconditionally modify the IR, indicate a change.
3397  Changed = ChangeStatus::CHANGED;
3398 
3399  auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3400  Instruction *RegionEndI) {
3401  LoopInfo *LI = nullptr;
3402  DominatorTree *DT = nullptr;
3403  MemorySSAUpdater *MSU = nullptr;
3404  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3405 
3406  BasicBlock *ParentBB = RegionStartI->getParent();
3407  Function *Fn = ParentBB->getParent();
3408  Module &M = *Fn->getParent();
3409 
3410  // Create all the blocks and logic.
3411  // ParentBB:
3412  // goto RegionCheckTidBB
3413  // RegionCheckTidBB:
3414  // Tid = __kmpc_hardware_thread_id()
3415  // if (Tid != 0)
3416  // goto RegionBarrierBB
3417  // RegionStartBB:
3418  // <execute instructions guarded>
3419  // goto RegionEndBB
3420  // RegionEndBB:
3421  // <store escaping values to shared mem>
3422  // goto RegionBarrierBB
3423  // RegionBarrierBB:
3424  // __kmpc_simple_barrier_spmd()
3425  // // second barrier is omitted if lacking escaping values.
3426  // <load escaping values from shared mem>
3427  // __kmpc_simple_barrier_spmd()
3428  // goto RegionExitBB
3429  // RegionExitBB:
3430  // <execute rest of instructions>
3431 
3432  BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3433  DT, LI, MSU, "region.guarded.end");
3434  BasicBlock *RegionBarrierBB =
3435  SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3436  MSU, "region.barrier");
3437  BasicBlock *RegionExitBB =
3438  SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3439  DT, LI, MSU, "region.exit");
3440  BasicBlock *RegionStartBB =
3441  SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3442 
3443  assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3444  "Expected a different CFG");
3445 
3446  BasicBlock *RegionCheckTidBB = SplitBlock(
3447  ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3448 
3449  // Register basic blocks with the Attributor.
3450  A.registerManifestAddedBasicBlock(*RegionEndBB);
3451  A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3452  A.registerManifestAddedBasicBlock(*RegionExitBB);
3453  A.registerManifestAddedBasicBlock(*RegionStartBB);
3454  A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3455 
3456  bool HasBroadcastValues = false;
3457  // Find escaping outputs from the guarded region to outside users and
3458  // broadcast their values to them.
3459  for (Instruction &I : *RegionStartBB) {
3460  SmallPtrSet<Instruction *, 4> OutsideUsers;
3461  for (User *Usr : I.users()) {
3462  Instruction &UsrI = *cast<Instruction>(Usr);
3463  if (UsrI.getParent() != RegionStartBB)
3464  OutsideUsers.insert(&UsrI);
3465  }
3466 
3467  if (OutsideUsers.empty())
3468  continue;
3469 
3470  HasBroadcastValues = true;
3471 
3472  // Emit a global variable in shared memory to store the broadcasted
3473  // value.
3474  auto *SharedMem = new GlobalVariable(
3475  M, I.getType(), /* IsConstant */ false,
3476  GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3477  sanitizeForGlobalName(
3478  (I.getName() + ".guarded.output.alloc").str()),
3479  nullptr, GlobalValue::NotThreadLocal,
3480  static_cast<unsigned>(AddressSpace::Shared));
3481 
3482  // Emit a store instruction to update the value.
3483  new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3484 
3485  LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3486  I.getName() + ".guarded.output.load",
3487  RegionBarrierBB->getTerminator());
3488 
3489  // Emit a load instruction and replace uses of the output value.
3490  for (Instruction *UsrI : OutsideUsers)
3491  UsrI->replaceUsesOfWith(&I, LoadI);
3492  }
3493 
3494  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3495 
3496  // Go to tid check BB in ParentBB.
3497  const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3498  ParentBB->getTerminator()->eraseFromParent();
3499  OpenMPIRBuilder::LocationDescription Loc(
3500  InsertPointTy(ParentBB, ParentBB->end()), DL);
3501  OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3502  uint32_t SrcLocStrSize;
3503  auto *SrcLocStr =
3504  OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3505  Value *Ident =
3506  OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3507  BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3508 
3509  // Add check for Tid in RegionCheckTidBB
3510  RegionCheckTidBB->getTerminator()->eraseFromParent();
3511  OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3512  InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3513  OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3514  FunctionCallee HardwareTidFn =
3515  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3516  M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3517  CallInst *Tid =
3518  OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3519  Tid->setDebugLoc(DL);
3520  OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
3521  Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3522  OMPInfoCache.OMPBuilder.Builder
3523  .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3524  ->setDebugLoc(DL);
3525 
3526  // First barrier for synchronization, ensures main thread has updated
3527  // values.
3528  FunctionCallee BarrierFn =
3529  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3530  M, OMPRTL___kmpc_barrier_simple_spmd);
3531  OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3532  RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3533  CallInst *Barrier =
3534  OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
3535  Barrier->setDebugLoc(DL);
3536  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3537 
3538  // Second barrier ensures workers have read broadcast values.
3539  if (HasBroadcastValues) {
3540  CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
3541  RegionBarrierBB->getTerminator());
3542  Barrier->setDebugLoc(DL);
3543  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3544  }
3545  };
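 // A source-level sketch of what ends up guarded (hypothetical user code, for
 // illustration only): in a generic-mode kernel such as
 //
 //   #pragma omp target teams
 //   {
 //     Counter += 1;                  // side effect outside a parallel region
 //     #pragma omp parallel for
 //     for (int i = 0; i < N; ++i) { ... }
 //   }
 //
 // the increment must be executed by the main thread only once the kernel runs
 // in SPMD mode, so CreateGuardedRegion wraps it in the thread-id check and
 // barriers sketched above; `Counter` and `N` are made-up names.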
3546 
3547  auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3548  SmallPtrSet<BasicBlock *, 8> Visited;
3549  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3550  BasicBlock *BB = GuardedI->getParent();
3551  if (!Visited.insert(BB).second)
3552  continue;
3553 
3554  SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3555  Instruction *LastEffect = nullptr;
3556  BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3557  while (++IP != IPEnd) {
3558  if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3559  continue;
3560  Instruction *I = &*IP;
3561  if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3562  continue;
3563  if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3564  LastEffect = nullptr;
3565  continue;
3566  }
3567  if (LastEffect)
3568  Reorders.push_back({I, LastEffect});
3569  LastEffect = &*IP;
3570  }
3571  for (auto &Reorder : Reorders)
3572  Reorder.first->moveBefore(Reorder.second);
3573  }
3574 
3575  SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3576 
3577  for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3578  BasicBlock *BB = GuardedI->getParent();
3579  auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3580  IRPosition::function(*GuardedI->getFunction()), nullptr,
3581  DepClassTy::NONE);
3582  assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3583  auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3584  // Continue if instruction is already guarded.
3585  if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3586  continue;
3587 
3588  Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3589  for (Instruction &I : *BB) {
3590  // If instruction I needs to be guarded update the guarded region
3591  // bounds.
3592  if (SPMDCompatibilityTracker.contains(&I)) {
3593  CalleeAAFunction.getGuardedInstructions().insert(&I);
3594  if (GuardedRegionStart)
3595  GuardedRegionEnd = &I;
3596  else
3597  GuardedRegionStart = GuardedRegionEnd = &I;
3598 
3599  continue;
3600  }
3601 
3602  // Instruction I does not need guarding, store
3603  // any region found and reset bounds.
3604  if (GuardedRegionStart) {
3605  GuardedRegions.push_back(
3606  std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3607  GuardedRegionStart = nullptr;
3608  GuardedRegionEnd = nullptr;
3609  }
3610  }
3611  }
3612 
3613  for (auto &GR : GuardedRegions)
3614  CreateGuardedRegion(GR.first, GR.second);
3615 
3616  // Adjust the global exec mode flag that tells the runtime what mode this
3617  // kernel is executed in.
3618  assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3619  "Initially non-SPMD kernel has SPMD exec mode!");
3620  ExecMode->setInitializer(
3621  ConstantInt::get(ExecMode->getInitializer()->getType(),
3622  ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
3623 
3624  // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3625  const int InitModeArgNo = 1;
3626  const int DeinitModeArgNo = 1;
3627  const int InitUseStateMachineArgNo = 2;
3628  const int InitRequiresFullRuntimeArgNo = 3;
3629  const int DeinitRequiresFullRuntimeArgNo = 2;
3630 
3631  auto &Ctx = getAnchorValue().getContext();
3632  A.changeUseAfterManifest(
3633  KernelInitCB->getArgOperandUse(InitModeArgNo),
3634  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3635  OMP_TGT_EXEC_MODE_SPMD));
3636  A.changeUseAfterManifest(
3637  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3638  *ConstantInt::getBool(Ctx, false));
3639  A.changeUseAfterManifest(
3640  KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3641  *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3642  OMP_TGT_EXEC_MODE_SPMD));
3643  A.changeUseAfterManifest(
3644  KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3645  *ConstantInt::getBool(Ctx, false));
3646  A.changeUseAfterManifest(
3647  KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3648  *ConstantInt::getBool(Ctx, false));
3649 
3650  ++NumOpenMPTargetRegionKernelsSPMD;
3651 
3652  auto Remark = [&](OptimizationRemark OR) {
3653  return OR << "Transformed generic-mode kernel to SPMD-mode.";
3654  };
3655  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3656  return true;
3657  };
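 // A minimal sketch of the manifested change (argument positions follow the
 // constants above; the exact IR spelling is an assumption): a generic entry
 //   %tid = call i32 @__kmpc_target_init(%ident, i8 1 /* GENERIC */,
 //                                       i1 true /* UseStateMachine */,
 //                                       i1 true /* RequiresFullRuntime */)
 // becomes
 //   %tid = call i32 @__kmpc_target_init(%ident, i8 2 /* SPMD */,
 //                                       i1 false, i1 false)
 // and the kernel's "<name>_exec_mode" global is OR'ed with
 // OMP_TGT_EXEC_MODE_GENERIC_SPMD.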
3658 
3659  ChangeStatus buildCustomStateMachine(Attributor &A) {
3660  // If we have disabled state machine rewrites, don't make a custom one.
3661  if (DisableOpenMPOptStateMachineRewrite)
3662  return ChangeStatus::UNCHANGED;
3663 
3664  // Don't rewrite the state machine if we are not in a valid state.
3665  if (!ReachedKnownParallelRegions.isValidState())
3666  return ChangeStatus::UNCHANGED;
3667 
3668  const int InitModeArgNo = 1;
3669  const int InitUseStateMachineArgNo = 2;
3670 
3671  // Check if the current configuration is non-SPMD and uses the generic
3672  // state machine. If we already have SPMD mode or a custom state machine we
3673  // do not need to go any further. If either argument is anything but a
3674  // constant, something is weird and we give up.
3675  ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3676  KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3677  ConstantInt *Mode =
3678  dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3679 
3680  // If we are stuck with generic mode, try to create a custom device (=GPU)
3681  // state machine which is specialized for the parallel regions that are
3682  // reachable by the kernel.
3683  if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3684  (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3685  return ChangeStatus::UNCHANGED;
3686 
3687  // If not SPMD mode, indicate we use a custom state machine now.
3688  auto &Ctx = getAnchorValue().getContext();
3689  auto *FalseVal = ConstantInt::getBool(Ctx, false);
3690  A.changeUseAfterManifest(
3691  KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3692 
3693  // If we don't actually need a state machine we are done here. This can
3694  // happen if there simply are no parallel regions. In the resulting kernel
3695  // all worker threads will simply exit right away, leaving the main thread
3696  // to do the work alone.
3697  if (!mayContainParallelRegion()) {
3698  ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3699 
3700  auto Remark = [&](OptimizationRemark OR) {
3701  return OR << "Removing unused state machine from generic-mode kernel.";
3702  };
3703  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3704 
3705  return ChangeStatus::CHANGED;
3706  }
3707 
3708  // Keep track in the statistics of our new shiny custom state machine.
3709  if (ReachedUnknownParallelRegions.empty()) {
3710  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3711 
3712  auto Remark = [&](OptimizationRemark OR) {
3713  return OR << "Rewriting generic-mode kernel with a customized state "
3714  "machine.";
3715  };
3716  A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3717  } else {
3718  ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3719 
3720  auto Remark = [&](OptimizationRemarkAnalysis OR) {
3721  return OR << "Generic-mode kernel is executed with a customized state "
3722  "machine that requires a fallback.";
3723  };
3724  A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3725 
3726  // Tell the user why we ended up with a fallback.
3727  for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3728  if (!UnknownParallelRegionCB)
3729  continue;
3730  auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3731  return ORA << "Call may contain unknown parallel regions. Use "
3732  << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3733  "override.";
3734  };
3735  A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3736  "OMP133", Remark);
3737  }
3738  }
3739 
3740  // Create all the blocks:
3741  //
3742  // InitCB = __kmpc_target_init(...)
3743  // BlockHwSize =
3744  // __kmpc_get_hardware_num_threads_in_block();
3745  // WarpSize = __kmpc_get_warp_size();
3746  // BlockSize = BlockHwSize - WarpSize;
3747  // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
3748  // if (IsWorker) {
3749  // if (InitCB >= BlockSize) return;
3750  // SMBeginBB: __kmpc_barrier_simple_generic(...);
3751  // void *WorkFn;
3752  // bool Active = __kmpc_kernel_parallel(&WorkFn);
3753  // if (!WorkFn) return;
3754  // SMIsActiveCheckBB: if (Active) {
3755  // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
3756  // ParFn0(...);
3757  // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
3758  // ParFn1(...);
3759  // ...
3760  // SMIfCascadeCurrentBB: else
3761  // ((WorkFnTy*)WorkFn)(...);
3762  // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
3763  // }
3764  // SMDoneBB: __kmpc_barrier_simple_generic(...);
3765  // goto SMBeginBB;
3766  // }
3767  // UserCodeEntryBB: // user code
3768  // __kmpc_target_deinit(...)
3769  //
3770  Function *Kernel = getAssociatedFunction();
3771  assert(Kernel && "Expected an associated function!");
3772 
3773  BasicBlock *InitBB = KernelInitCB->getParent();
3774  BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3775  KernelInitCB->getNextNode(), "thread.user_code.check");
3776  BasicBlock *IsWorkerCheckBB =
3777  BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
3778  BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3779  Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3780  BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3781  Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3782  BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3783  Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3784  BasicBlock *StateMachineIfCascadeCurrentBB =
3785  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3786  Kernel, UserCodeEntryBB);
3787  BasicBlock *StateMachineEndParallelBB =
3788  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3789  Kernel, UserCodeEntryBB);
3790  BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3791  Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3792  A.registerManifestAddedBasicBlock(*InitBB);
3793  A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3794  A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
3795  A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3796  A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3797  A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3798  A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3799  A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3800  A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3801 
3802  const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3803  ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3804  InitBB->getTerminator()->eraseFromParent();
3805 
3806  Instruction *IsWorker =
3807  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3808  ConstantInt::get(KernelInitCB->getType(), -1),
3809  "thread.is_worker", InitBB);
3810  IsWorker->setDebugLoc(DLoc);
3811  BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
3812 
3813  Module &M = *Kernel->getParent();
3814  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3815  FunctionCallee BlockHwSizeFn =
3816  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3817  M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
3818  FunctionCallee WarpSizeFn =
3819  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3820  M, OMPRTL___kmpc_get_warp_size);
3821  CallInst *BlockHwSize =
3822  CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
3823  OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
3824  BlockHwSize->setDebugLoc(DLoc);
3825  CallInst *WarpSize =
3826  CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
3827  OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
3828  WarpSize->setDebugLoc(DLoc);
3829  Instruction *BlockSize = BinaryOperator::CreateSub(
3830  BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
3831  BlockSize->setDebugLoc(DLoc);
3832  Instruction *IsMainOrWorker = ICmpInst::Create(
3833  ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
3834  "thread.is_main_or_worker", IsWorkerCheckBB);
3835  IsMainOrWorker->setDebugLoc(DLoc);
3836  BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
3837  IsMainOrWorker, IsWorkerCheckBB);
3838 
3839  // Create local storage for the work function pointer.
3840  const DataLayout &DL = M.getDataLayout();
3841  Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3842  Instruction *WorkFnAI =
3843  new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3844  "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3845  WorkFnAI->setDebugLoc(DLoc);
3846 
3847  OMPInfoCache.OMPBuilder.updateToLocation(
3848  OpenMPIRBuilder::LocationDescription(
3849  IRBuilder<>::InsertPoint(StateMachineBeginBB,
3850  StateMachineBeginBB->end()),
3851  DLoc));
3852 
3853  Value *Ident = KernelInitCB->getArgOperand(0);
3854  Value *GTid = KernelInitCB;
3855 
3856  FunctionCallee BarrierFn =
3857  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3858  M, OMPRTL___kmpc_barrier_simple_generic);
3859  CallInst *Barrier =
3860  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
3861  OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
3862  Barrier->setDebugLoc(DLoc);
3863 
3864  if (WorkFnAI->getType()->getPointerAddressSpace() !=
3865  (unsigned int)AddressSpace::Generic) {
3866  WorkFnAI = new AddrSpaceCastInst(
3867  WorkFnAI,
3868  PointerType::getWithSamePointeeType(
3869  cast<PointerType>(WorkFnAI->getType()),
3870  (unsigned int)AddressSpace::Generic),
3871  WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3872  WorkFnAI->setDebugLoc(DLoc);
3873  }
3874 
3875  FunctionCallee KernelParallelFn =
3876  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3877  M, OMPRTL___kmpc_kernel_parallel);
3878  CallInst *IsActiveWorker = CallInst::Create(
3879  KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3880  OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
3881  IsActiveWorker->setDebugLoc(DLoc);
3882  Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3883  StateMachineBeginBB);
3884  WorkFn->setDebugLoc(DLoc);
3885 
3886  FunctionType *ParallelRegionFnTy = FunctionType::get(
3887  Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3888  false);
3889  Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3890  WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3891  StateMachineBeginBB);
3892 
3893  Instruction *IsDone =
3894  ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3895  Constant::getNullValue(VoidPtrTy), "worker.is_done",
3896  StateMachineBeginBB);
3897  IsDone->setDebugLoc(DLoc);
3898  BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3899  IsDone, StateMachineBeginBB)
3900  ->setDebugLoc(DLoc);
3901 
3902  BranchInst::Create(StateMachineIfCascadeCurrentBB,
3903  StateMachineDoneBarrierBB, IsActiveWorker,
3904  StateMachineIsActiveCheckBB)
3905  ->setDebugLoc(DLoc);
3906 
3907  Value *ZeroArg =
3908  Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3909 
3910  // Now that we have most of the CFG skeleton it is time for the if-cascade
3911  // that checks the function pointer we got from the runtime against the
3912  // parallel regions we expect, if there are any.
3913  for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
3914  auto *ParallelRegion = ReachedKnownParallelRegions[I];
3915  BasicBlock *PRExecuteBB = BasicBlock::Create(
3916  Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3917  StateMachineEndParallelBB);
3918  CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3919  ->setDebugLoc(DLoc);
3920  BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3921  ->setDebugLoc(DLoc);
3922 
3923  BasicBlock *PRNextBB =
3924  BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3925  Kernel, StateMachineEndParallelBB);
3926 
3927  // Check if we need to compare the pointer at all or if we can just
3928  // call the parallel region function.
3929  Value *IsPR;
3930  if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
3931  Instruction *CmpI = ICmpInst::Create(
3932  ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3933  "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3934  CmpI->setDebugLoc(DLoc);
3935  IsPR = CmpI;
3936  } else {
3937  IsPR = ConstantInt::getTrue(Ctx);
3938  }
3939 
3940  BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3941  StateMachineIfCascadeCurrentBB)
3942  ->setDebugLoc(DLoc);
3943  StateMachineIfCascadeCurrentBB = PRNextBB;
3944  }
3945 
3946  // At the end of the if-cascade we place the indirect function pointer call
3947  // in case we might need it, that is if there can be parallel regions we
3948  // have not handled in the if-cascade above.
3949  if (!ReachedUnknownParallelRegions.empty()) {
3950  StateMachineIfCascadeCurrentBB->setName(
3951  "worker_state_machine.parallel_region.fallback.execute");
3952  CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3953  StateMachineIfCascadeCurrentBB)
3954  ->setDebugLoc(DLoc);
3955  }
3956  BranchInst::Create(StateMachineEndParallelBB,
3957  StateMachineIfCascadeCurrentBB)
3958  ->setDebugLoc(DLoc);
3959 
3960  FunctionCallee EndParallelFn =
3961  OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3962  M, OMPRTL___kmpc_kernel_end_parallel);
3963  CallInst *EndParallel =
3964  CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
3965  OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
3966  EndParallel->setDebugLoc(DLoc);
3967  BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3968  ->setDebugLoc(DLoc);
3969 
3970  CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3971  ->setDebugLoc(DLoc);
3972  BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3973  ->setDebugLoc(DLoc);
3974 
3975  return ChangeStatus::CHANGED;
3976  }
3977 
3978  /// Fixpoint iteration update function. Will be called every time a dependence
3979  /// changed its state (and in the beginning).
3980  ChangeStatus updateImpl(Attributor &A) override {
3981  KernelInfoState StateBefore = getState();
3982 
3983  // Callback to check a read/write instruction.
3984  auto CheckRWInst = [&](Instruction &I) {
3985  // We handle calls later.
3986  if (isa<CallBase>(I))
3987  return true;
3988  // We only care about write effects.
3989  if (!I.mayWriteToMemory())
3990  return true;
3991  if (auto *SI = dyn_cast<StoreInst>(&I)) {
3992  SmallVector<const Value *> Objects;
3993  getUnderlyingObjects(SI->getPointerOperand(), Objects);
3994  if (llvm::all_of(Objects,
3995  [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3996  return true;
3997  // Check for AAHeapToStack moved objects which must not be guarded.
3998  auto &HS = A.getAAFor<AAHeapToStack>(
3999  *this, IRPosition::function(*I.getFunction()),
4000  DepClassTy::OPTIONAL);
4001  if (llvm::all_of(Objects, [&HS](const Value *Obj) {
4002  auto *CB = dyn_cast<CallBase>(Obj);
4003  if (!CB)
4004  return false;
4005  return HS.isAssumedHeapToStack(*CB);
4006  })) {
4007  return true;
4008  }
4009  }
4010 
4011  // Insert instruction that needs guarding.
4012  SPMDCompatibilityTracker.insert(&I);
4013  return true;
4014  };
4015 
4016  bool UsedAssumedInformationInCheckRWInst = false;
4017  if (!SPMDCompatibilityTracker.isAtFixpoint())
4018  if (!A.checkForAllReadWriteInstructions(
4019  CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
4020  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4021 
4022  bool UsedAssumedInformationFromReachingKernels = false;
4023  if (!IsKernelEntry) {
4024  updateParallelLevels(A);
4025 
4026  bool AllReachingKernelsKnown = true;
4027  updateReachingKernelEntries(A, AllReachingKernelsKnown);
4028  UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
4029 
4030  if (!ParallelLevels.isValidState())
4031  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4032  else if (!ReachingKernelEntries.isValidState())
4033  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4034  else if (!SPMDCompatibilityTracker.empty()) {
4035  // Check if all reaching kernels agree on the mode as we can otherwise
4036  // not guard instructions. We might not be sure about the mode so we
4037  // cannot fix the internal SPMD-zation state either.
4038  int SPMD = 0, Generic = 0;
4039  for (auto *Kernel : ReachingKernelEntries) {
4040  auto &CBAA = A.getAAFor<AAKernelInfo>(
4041  *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
4042  if (CBAA.SPMDCompatibilityTracker.isValidState() &&
4043  CBAA.SPMDCompatibilityTracker.isAssumed())
4044  ++SPMD;
4045  else
4046  ++Generic;
4047  if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
4048  UsedAssumedInformationFromReachingKernels = true;
4049  }
4050  if (SPMD != 0 && Generic != 0)
4051  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4052  }
4053  }
4054 
4055  // Callback to check a call instruction.
4056  bool AllParallelRegionStatesWereFixed = true;
4057  bool AllSPMDStatesWereFixed = true;
4058  auto CheckCallInst = [&](Instruction &I) {
4059  auto &CB = cast<CallBase>(I);
4060  auto &CBAA = A.getAAFor<AAKernelInfo>(
4061  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4062  getState() ^= CBAA.getState();
4063  AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
4064  AllParallelRegionStatesWereFixed &=
4065  CBAA.ReachedKnownParallelRegions.isAtFixpoint();
4066  AllParallelRegionStatesWereFixed &=
4067  CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
4068  return true;
4069  };
4070 
4071  bool UsedAssumedInformationInCheckCallInst = false;
4072  if (!A.checkForAllCallLikeInstructions(
4073  CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
4074  LLVM_DEBUG(dbgs() << TAG
4075  << "Failed to visit all call-like instructions!\n";);
4076  return indicatePessimisticFixpoint();
4077  }
4078 
4079  // If we haven't used any assumed information for the reached parallel
4080  // region states we can fix it.
4081  if (!UsedAssumedInformationInCheckCallInst &&
4082  AllParallelRegionStatesWereFixed) {
4083  ReachedKnownParallelRegions.indicateOptimisticFixpoint();
4084  ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
4085  }
4086 
4087  // If we are sure there are no parallel regions in the kernel we do not
4088  // want SPMD mode.
4089  if (IsKernelEntry && ReachedUnknownParallelRegions.isAtFixpoint() &&
4090  ReachedKnownParallelRegions.isAtFixpoint() &&
4091  ReachedUnknownParallelRegions.isValidState() &&
4092  ReachedKnownParallelRegions.isValidState() &&
4093  !mayContainParallelRegion())
4094  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4095 
4096  // If we haven't used any assumed information for the SPMD state we can fix
4097  // it.
4098  if (!UsedAssumedInformationInCheckRWInst &&
4099  !UsedAssumedInformationInCheckCallInst &&
4100  !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
4101  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4102 
4103  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4104  : ChangeStatus::CHANGED;
4105  }
4106 
4107 private:
4108  /// Update info regarding reaching kernels.
4109  void updateReachingKernelEntries(Attributor &A,
4110  bool &AllReachingKernelsKnown) {
4111  auto PredCallSite = [&](AbstractCallSite ACS) {
4112  Function *Caller = ACS.getInstruction()->getFunction();
4113 
4114  assert(Caller && "Caller is nullptr");
4115 
4116  auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
4117  IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
4118  if (CAA.ReachingKernelEntries.isValidState()) {
4119  ReachingKernelEntries ^= CAA.ReachingKernelEntries;
4120  return true;
4121  }
4122 
4123  // We lost track of the caller of the associated function; any kernel
4124  // could reach it now.
4125  ReachingKernelEntries.indicatePessimisticFixpoint();
4126 
4127  return true;
4128  };
4129 
4130  if (!A.checkForAllCallSites(PredCallSite, *this,
4131  true /* RequireAllCallSites */,
4132  AllReachingKernelsKnown))
4133  ReachingKernelEntries.indicatePessimisticFixpoint();
4134  }
4135 
4136  /// Update info regarding parallel levels.
4137  void updateParallelLevels(Attributor &A) {
4138  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4139  OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
4140  OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
4141 
4142  auto PredCallSite = [&](AbstractCallSite ACS) {
4143  Function *Caller = ACS.getInstruction()->getFunction();
4144 
4145  assert(Caller && "Caller is nullptr");
4146 
4147  auto &CAA =
4148  A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
4149  if (CAA.ParallelLevels.isValidState()) {
4150  // Any function that is called by `__kmpc_parallel_51` will not be
4151  // folded as the parallel level in the function is updated. To get it
4152  // right, the analysis would have to depend on the runtime implementation,
4153  // and any future change to that implementation could silently make the
4154  // analysis wrong. As a consequence, we are just conservative here.
4155  if (Caller == Parallel51RFI.Declaration) {
4156  ParallelLevels.indicatePessimisticFixpoint();
4157  return true;
4158  }
4159 
4160  ParallelLevels ^= CAA.ParallelLevels;
4161 
4162  return true;
4163  }
4164 
4165  // We lost track of the caller of the associated function; any kernel
4166  // could reach it now.
4167  ParallelLevels.indicatePessimisticFixpoint();
4168 
4169  return true;
4170  };
4171 
4172  bool AllCallSitesKnown = true;
4173  if (!A.checkForAllCallSites(PredCallSite, *this,
4174  true /* RequireAllCallSites */,
4175  AllCallSitesKnown))
4176  ParallelLevels.indicatePessimisticFixpoint();
4177  }
4178 };
4179 
4180 /// The call site kernel info abstract attribute, basically, what can we say
4181 /// about a call site with regards to the KernelInfoState. For now this simply
4182 /// forwards the information from the callee.
4183 struct AAKernelInfoCallSite : AAKernelInfo {
4184  AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
4185  : AAKernelInfo(IRP, A) {}
4186 
4187  /// See AbstractAttribute::initialize(...).
4188  void initialize(Attributor &A) override {
4189  AAKernelInfo::initialize(A);
4190 
4191  CallBase &CB = cast<CallBase>(getAssociatedValue());
4192  Function *Callee = getAssociatedFunction();
4193 
4194  auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
4195  *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
4196 
4197  // Check for SPMD-mode assumptions.
4198  if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
4199  SPMDCompatibilityTracker.indicateOptimisticFixpoint();
4200  indicateOptimisticFixpoint();
4201  }
4202 
4203  // First weed out calls we do not care about, that is readonly/readnone
4204  // calls, intrinsics, and "no_openmp" calls. None of these can reach a
4205  // parallel region or anything else we are looking for.
4206  if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
4207  indicateOptimisticFixpoint();
4208  return;
4209  }
4210 
4211  // Next we check if we know the callee. If it is a known OpenMP function
4212  // we will handle it explicitly in the switch below. If it is not, we
4213  // will use an AAKernelInfo object on the callee to gather information and
4214  // merge that into the current state. The latter happens in the updateImpl.
4215  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4216  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4217  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4218  // Unknown callees or declarations are not analyzable, we give up.
4219  if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
4220 
4221  // Unknown callees might contain parallel regions, except if they have
4222  // an appropriate assumption attached.
4223  if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
4224  AssumptionAA.hasAssumption("omp_no_parallelism")))
4225  ReachedUnknownParallelRegions.insert(&CB);
4226 
4227  // If SPMDCompatibilityTracker is not fixed, we need to give up on the
4228  // idea we can run something unknown in SPMD-mode.
4229  if (!SPMDCompatibilityTracker.isAtFixpoint()) {
4230  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4231  SPMDCompatibilityTracker.insert(&CB);
4232  }
4233 
4234  // We have updated the state for this unknown call properly, there won't
4235  // be any change so we indicate a fixpoint.
4236  indicateOptimisticFixpoint();
4237  }
4238  // If the callee is known and can be used in IPO, we will update the state
4239  // based on the callee state in updateImpl.
4240  return;
4241  }
4242 
4243  const unsigned int WrapperFunctionArgNo = 6;
4244  RuntimeFunction RF = It->getSecond();
4245  switch (RF) {
4246  // All the functions we know are compatible with SPMD mode.
4247  case OMPRTL___kmpc_is_spmd_exec_mode:
4248  case OMPRTL___kmpc_distribute_static_fini:
4249  case OMPRTL___kmpc_for_static_fini:
4250  case OMPRTL___kmpc_global_thread_num:
4251  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4252  case OMPRTL___kmpc_get_hardware_num_blocks:
4253  case OMPRTL___kmpc_single:
4254  case OMPRTL___kmpc_end_single:
4255  case OMPRTL___kmpc_master:
4256  case OMPRTL___kmpc_end_master:
4257  case OMPRTL___kmpc_barrier:
4258  case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
4259  case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
4260  case OMPRTL___kmpc_nvptx_end_reduce_nowait:
4261  break;
4262  case OMPRTL___kmpc_distribute_static_init_4:
4263  case OMPRTL___kmpc_distribute_static_init_4u:
4264  case OMPRTL___kmpc_distribute_static_init_8:
4265  case OMPRTL___kmpc_distribute_static_init_8u:
4266  case OMPRTL___kmpc_for_static_init_4:
4267  case OMPRTL___kmpc_for_static_init_4u:
4268  case OMPRTL___kmpc_for_static_init_8:
4269  case OMPRTL___kmpc_for_static_init_8u: {
4270  // Check the schedule and allow static schedule in SPMD mode.
4271  unsigned ScheduleArgOpNo = 2;
4272  auto *ScheduleTypeCI =
4273  dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
4274  unsigned ScheduleTypeVal =
4275  ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
4276  switch (OMPScheduleType(ScheduleTypeVal)) {
4277  case OMPScheduleType::UnorderedStatic:
4278  case OMPScheduleType::UnorderedStaticChunked:
4279  case OMPScheduleType::OrderedDistribute:
4280  case OMPScheduleType::OrderedDistributeChunked:
4281  break;
4282  default:
4283  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4284  SPMDCompatibilityTracker.insert(&CB);
4285  break;
4286  };
4287  } break;
4288  case OMPRTL___kmpc_target_init:
4289  KernelInitCB = &CB;
4290  break;
4291  case OMPRTL___kmpc_target_deinit:
4292  KernelDeinitCB = &CB;
4293  break;
4294  case OMPRTL___kmpc_parallel_51:
4295  if (auto *ParallelRegion = dyn_cast<Function>(
4296  CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
4297  ReachedKnownParallelRegions.insert(ParallelRegion);
4298  break;
4299  }
4300  // The condition above should usually get the parallel region function
4301  // pointer and record it. In the off chance it doesn't we assume the
4302  // worst.
4303  ReachedUnknownParallelRegions.insert(&CB);
4304  break;
4305  case OMPRTL___kmpc_omp_task:
4306  // We do not look into tasks right now, just give up.
4307  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4308  SPMDCompatibilityTracker.insert(&CB);
4309  ReachedUnknownParallelRegions.insert(&CB);
4310  break;
4311  case OMPRTL___kmpc_alloc_shared:
4312  case OMPRTL___kmpc_free_shared:
4313  // Return without setting a fixpoint, to be resolved in updateImpl.
4314  return;
4315  default:
4316  // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
4317  // generally. However, they do not hide parallel regions.
4318  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4319  SPMDCompatibilityTracker.insert(&CB);
4320  break;
4321  }
4322  // All other OpenMP runtime calls will not reach parallel regions so they
4323  // can be safely ignored for now. Since it is a known OpenMP runtime call we
4324  // have now modeled all effects and there is no need for any update.
4325  indicateOptimisticFixpoint();
4326  }
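 // For illustration only (hypothetical user code): annotating a callee with
 //   __attribute__((assume("ompx_spmd_amenable")))
 //   void use_in_spmd_mode(int *p);
 // makes calls to it hit the optimistic fixpoint above instead of blocking
 // SPMD-ization, which is exactly what the OMP121 remark in changeToSPMDMode
 // suggests.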
4327 
4328  ChangeStatus updateImpl(Attributor &A) override {
4329  // TODO: Once we have call site specific value information we can provide
4330  // call site specific liveness information and then it makes
4331  // sense to specialize attributes for call sites arguments instead of
4332  // redirecting requests to the callee argument.
4333  Function *F = getAssociatedFunction();
4334 
4335  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4336  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
4337 
4338  // If F is not a runtime function, propagate the AAKernelInfo of the callee.
4339  if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
4340  const IRPosition &FnPos = IRPosition::function(*F);
4341  auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
4342  if (getState() == FnAA.getState())
4343  return ChangeStatus::UNCHANGED;
4344  getState() = FnAA.getState();
4345  return ChangeStatus::CHANGED;
4346  }
4347 
4348  // F is a runtime function that allocates or frees memory, check
4349  // AAHeapToStack and AAHeapToShared.
4350  KernelInfoState StateBefore = getState();
4351  assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
4352  It->getSecond() == OMPRTL___kmpc_free_shared) &&
4353  "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
4354 
4355  CallBase &CB = cast<CallBase>(getAssociatedValue());
4356 
4357  auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
4358  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4359  auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
4360  *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
4361 
4362  RuntimeFunction RF = It->getSecond();
4363 
4364  switch (RF) {
4365  // If neither HeapToStack nor HeapToShared assume the call is removed,
4366  // assume SPMD incompatibility.
4367  case OMPRTL___kmpc_alloc_shared:
4368  if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
4369  !HeapToSharedAA.isAssumedHeapToShared(CB))
4370  SPMDCompatibilityTracker.insert(&CB);
4371  break;
4372  case OMPRTL___kmpc_free_shared:
4373  if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
4374  !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
4375  SPMDCompatibilityTracker.insert(&CB);
4376  break;
4377  default:
4378  SPMDCompatibilityTracker.indicatePessimisticFixpoint();
4379  SPMDCompatibilityTracker.insert(&CB);
4380  }
4381 
4382  return StateBefore == getState() ? ChangeStatus::UNCHANGED
4383  : ChangeStatus::CHANGED;
4384  }
4385 };
4386 
4387 struct AAFoldRuntimeCall
4388  : public StateWrapper<BooleanState, AbstractAttribute> {
4389  using Base = StateWrapper<BooleanState, AbstractAttribute>;
4390 
4391  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
4392 
4393  /// Statistics are tracked as part of manifest for now.
4394  void trackStatistics() const override {}
4395 
4396  /// Create an abstract attribute view for the position \p IRP.
4397  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
4398  Attributor &A);
4399 
4400  /// See AbstractAttribute::getName()
4401  const std::string getName() const override { return "AAFoldRuntimeCall"; }
4402 
4403  /// See AbstractAttribute::getIdAddr()
4404  const char *getIdAddr() const override { return &ID; }
4405 
4406  /// This function should return true if the type of the \p AA is
4407  /// AAFoldRuntimeCall
4408  static bool classof(const AbstractAttribute *AA) {
4409  return (AA->getIdAddr() == &ID);
4410  }
4411 
4412  static const char ID;
4413 };
4414 
4415 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
4416  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
4417  : AAFoldRuntimeCall(IRP, A) {}
4418 
4419  /// See AbstractAttribute::getAsStr()
4420  const std::string getAsStr() const override {
4421  if (!isValidState())
4422  return "<invalid>";
4423 
4424  std::string Str("simplified value: ");
4425 
4426  if (!SimplifiedValue.hasValue())
4427  return Str + std::string("none");
4428 
4429  if (!SimplifiedValue.getValue())
4430  return Str + std::string("nullptr");
4431 
4432  if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4433  return Str + std::to_string(CI->getSExtValue());
4434 
4435  return Str + std::string("unknown");
4436  }
4437 
4438  void initialize(Attributor &A) override {
4439  if (DisableOpenMPOptFolding)
4440  indicatePessimisticFixpoint();
4441 
4442  Function *Callee = getAssociatedFunction();
4443 
4444  auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4445  const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4446  assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4447  "Expected a known OpenMP runtime function");
4448 
4449  RFKind = It->getSecond();
4450 
4451  CallBase &CB = cast<CallBase>(getAssociatedValue());
4452  A.registerSimplificationCallback(
4453  IRPosition::callsite_returned(CB),
4454  [&](const IRPosition &IRP, const AbstractAttribute *AA,
4455  bool &UsedAssumedInformation) -> Optional<Value *> {
4456  assert((isValidState() || (SimplifiedValue.hasValue() &&
4457  SimplifiedValue.getValue() == nullptr)) &&
4458  "Unexpected invalid state!");
4459 
4460  if (!isAtFixpoint()) {
4461  UsedAssumedInformation = true;
4462  if (AA)
4463  A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4464  }
4465  return SimplifiedValue;
4466  });
4467  }
4468 
4469  ChangeStatus updateImpl(Attributor &A) override {
4470  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4471  switch (RFKind) {
4472  case OMPRTL___kmpc_is_spmd_exec_mode:
4473  Changed |= foldIsSPMDExecMode(A);
4474  break;
4475  case OMPRTL___kmpc_is_generic_main_thread_id:
4476  Changed |= foldIsGenericMainThread(A);
4477  break;
4478  case OMPRTL___kmpc_parallel_level:
4479  Changed |= foldParallelLevel(A);
4480  break;
4481  case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4482  Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
4483  break;
4484  case OMPRTL___kmpc_get_hardware_num_blocks:
4485  Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
4486  break;
4487  default:
4488  llvm_unreachable("Unhandled OpenMP runtime function!");
4489  }
4490 
4491  return Changed;
4492  }
4493 
4494  ChangeStatus manifest(Attributor &A) override {
4495  ChangeStatus Changed = ChangeStatus::UNCHANGED;
4496 
4497  if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4498  Instruction &I = *getCtxI();
4499  A.changeValueAfterManifest(I, **SimplifiedValue);
4500  A.deleteAfterManifest(I);
4501 
4502  CallBase *CB = dyn_cast<CallBase>(&I);
4503  auto Remark = [&](OptimizationRemark OR) {
4504  if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4505  return OR << "Replacing OpenMP runtime call "
4506  << CB->getCalledFunction()->getName() << " with "
4507  << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4508  return OR << "Replacing OpenMP runtime call "
4509  << CB->getCalledFunction()->getName() << ".";
4510  };
4511 
4512  if (CB && EnableVerboseRemarks)
4513  A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4514 
4515  LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4516  << **SimplifiedValue << "\n");
4517 
4518  Changed = ChangeStatus::CHANGED;
4519  }
4520 
4521  return Changed;
4522  }
4523 
4524  ChangeStatus indicatePessimisticFixpoint() override {
4525  SimplifiedValue = nullptr;
4526  return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4527  }
4528 
4529 private:
4530  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4531  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4532  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4533 
4534  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4535  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4536  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4537  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4538 
4539  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4540  return indicatePessimisticFixpoint();
4541 
4542  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4543  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4544  DepClassTy::REQUIRED);
4545 
4546  if (!AA.isValidState()) {
4547  SimplifiedValue = nullptr;
4548  return indicatePessimisticFixpoint();
4549  }
4550 
4551  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4552  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4553  ++KnownSPMDCount;
4554  else
4555  ++AssumedSPMDCount;
4556  } else {
4557  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4558  ++KnownNonSPMDCount;
4559  else
4560  ++AssumedNonSPMDCount;
4561  }
4562  }
4563 
4564  if ((AssumedSPMDCount + KnownSPMDCount) &&
4565  (AssumedNonSPMDCount + KnownNonSPMDCount))
4566  return indicatePessimisticFixpoint();
4567 
4568  auto &Ctx = getAnchorValue().getContext();
4569  if (KnownSPMDCount || AssumedSPMDCount) {
4570  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4571  "Expected only SPMD kernels!");
4572  // All reaching kernels are in SPMD mode. Update all function calls to
4573  // __kmpc_is_spmd_exec_mode to 1.
4574  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4575  } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4576  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4577  "Expected only non-SPMD kernels!");
4578  // All reaching kernels are in non-SPMD mode. Update all function
4579  // calls to __kmpc_is_spmd_exec_mode to 0.
4580  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4581  } else {
4582  // We have empty reaching kernels, therefore we cannot tell if the
4583  // associated call site can be folded. At this moment, SimplifiedValue
4584  // must be none.
4585  assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4586  }
4587 
4588  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4589  : ChangeStatus::CHANGED;
4590  }
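 // A minimal sketch of the fold (IR spelling assumed): if every kernel that
 // can reach the caller is (known or assumed) SPMD, a call
 //   %m = call i8 @__kmpc_is_spmd_exec_mode()
 // simplifies to `i8 1`; if all reaching kernels are generic it becomes
 // `i8 0`; a mix of modes keeps the call and takes the pessimistic fixpoint.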
4591 
4592  /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4593  ChangeStatus foldIsGenericMainThread(Attributor &A) {
4594  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4595 
4596  CallBase &CB = cast<CallBase>(getAssociatedValue());
4597  Function *F = CB.getFunction();
4598  const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4599  *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4600 
4601  if (!ExecutionDomainAA.isValidState())
4602  return indicatePessimisticFixpoint();
4603 
4604  auto &Ctx = getAnchorValue().getContext();
4605  if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4606  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4607  else
4608  return indicatePessimisticFixpoint();
4609 
4610  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4611  : ChangeStatus::CHANGED;
4612  }
4613 
4614  /// Fold __kmpc_parallel_level into a constant if possible.
4615  ChangeStatus foldParallelLevel(Attributor &A) {
4616  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4617 
4618  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4619  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4620 
4621  if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4622  return indicatePessimisticFixpoint();
4623 
4624  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4625  return indicatePessimisticFixpoint();
4626 
4627  if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4628  assert(!SimplifiedValue.hasValue() &&
4629  "SimplifiedValue should keep none at this point");
4630  return ChangeStatus::UNCHANGED;
4631  }
4632 
4633  unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4634  unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4635  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4636  auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4637  DepClassTy::REQUIRED);
4638  if (!AA.SPMDCompatibilityTracker.isValidState())
4639  return indicatePessimisticFixpoint();
4640 
4641  if (AA.SPMDCompatibilityTracker.isAssumed()) {
4642  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4643  ++KnownSPMDCount;
4644  else
4645  ++AssumedSPMDCount;
4646  } else {
4647  if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4648  ++KnownNonSPMDCount;
4649  else
4650  ++AssumedNonSPMDCount;
4651  }
4652  }
4653 
4654  if ((AssumedSPMDCount + KnownSPMDCount) &&
4655  (AssumedNonSPMDCount + KnownNonSPMDCount))
4656  return indicatePessimisticFixpoint();
4657 
4658  auto &Ctx = getAnchorValue().getContext();
4659  // If the caller can only be reached by SPMD kernel entries, the parallel
4660  // level is 1. Similarly, if the caller can only be reached by non-SPMD
4661  // kernel entries, it is 0.
4662  if (AssumedSPMDCount || KnownSPMDCount) {
4663  assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4664  "Expected only SPMD kernels!");
4665  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4666  } else {
4667  assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4668  "Expected only non-SPMD kernels!");
4669  SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4670  }
4671  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4672  : ChangeStatus::CHANGED;
4673  }
4674 
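  // Illustrative sketch (assumed, not from the original file): if every kernel
  // that can reach this function is SPMD, a call like
  //   %lvl = call i8 @__kmpc_parallel_level(...)
  // is assumed to fold to `i8 1`; if all reaching kernels are generic-mode it
  // folds to `i8 0`; mixed reachability gives up via the pessimistic fixpoint.
  // The runtime call's argument list is elided here and only sketched.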
4675  ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4676  // Specialize only if all the calls agree with the attribute constant value
4677  int32_t CurrentAttrValue = -1;
4678  Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4679 
4680  auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4681  *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4682 
4683  if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4684  return indicatePessimisticFixpoint();
4685 
4686  // Iterate over the kernels that reach this function
4687  for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4688  int32_t NextAttrVal = -1;
4689  if (K->hasFnAttribute(Attr))
4690  NextAttrVal =
4691  std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4692 
4693  if (NextAttrVal == -1 ||
4694  (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4695  return indicatePessimisticFixpoint();
4696  CurrentAttrValue = NextAttrVal;
4697  }
4698 
4699  if (CurrentAttrValue != -1) {
4700  auto &Ctx = getAnchorValue().getContext();
4701  SimplifiedValue =
4702  ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4703  }
4704  return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4705  : ChangeStatus::CHANGED;
4706  }
4707 
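  // Illustrative sketch (assumed, not from the original file): if every
  // reaching kernel carries the same integer-valued string attribute, e.g.
  //   attributes #0 = { "omp_target_thread_limit"="128" }
  // then the associated query (here assumed to be
  // __kmpc_get_hardware_num_threads_in_block) folds to `i32 128`. The
  // attribute name and the value 128 are assumptions for illustration.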
4708  /// An optional value the associated value is assumed to fold to. That is, we
4709  /// assume the associated value (which is a call) can be replaced by this
4710  /// simplified value.
4711  Optional<Value *> SimplifiedValue;
4712 
4713  /// The runtime function kind of the callee of the associated call site.
4714  RuntimeFunction RFKind;
4715 };
4716 
4717 } // namespace
4718 
4719 /// Register folding callsite
4720 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4721  auto &RFI = OMPInfoCache.RFIs[RF];
4722  RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4723  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4724  if (!CI)
4725  return false;
4726  A.getOrCreateAAFor<AAFoldRuntimeCall>(
4727  IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4728  DepClassTy::NONE, /* ForceUpdate */ false,
4729  /* UpdateAfterInit */ false);
4730  return false;
4731  });
4732 }
4733 
4734 void OpenMPOpt::registerAAs(bool IsModulePass) {
4735  if (SCC.empty())
4736  return;
4737 
4738  if (IsModulePass) {
4739  // Ensure we create the AAKernelInfo AAs first and without triggering an
4740  // update. This will make sure we register all value simplification
4741  // callbacks before any other AA has the chance to create an AAValueSimplify
4742  // or similar.
4743  auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
4744  A.getOrCreateAAFor<AAKernelInfo>(
4745  IRPosition::function(Kernel), /* QueryingAA */ nullptr,
4746  DepClassTy::NONE, /* ForceUpdate */ false,
4747  /* UpdateAfterInit */ false);
4748  return false;
4749  };
4750  OMPInformationCache::RuntimeFunctionInfo &InitRFI =
4751  OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
4752  InitRFI.foreachUse(SCC, CreateKernelInfoCB);
4753 
4754  registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4755  registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4756  registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4757  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4758  registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4759  }
4760 
4761  // Create CallSite AA for all Getters.
4762  for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4763  auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4764 
4765  auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4766 
4767  auto CreateAA = [&](Use &U, Function &Caller) {
4768  CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4769  if (!CI)
4770  return false;
4771 
4772  auto &CB = cast<CallBase>(*CI);
4773 
4774  IRPosition CBPos = IRPosition::callsite_function(CB);
4775  A.getOrCreateAAFor<AAICVTracker>(CBPos);
4776  return false;
4777  };
4778 
4779  GetterRFI.foreachUse(SCC, CreateAA);
4780  }
4781  auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4782  auto CreateAA = [&](Use &U, Function &F) {
4783  A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4784  return false;
4785  };
4786  if (!DisableOpenMPOptDeglobalization)
4787  GlobalizationRFI.foreachUse(SCC, CreateAA);
4788 
4789  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4790  // every function if there is a device kernel.
4791  if (!isOpenMPDevice(M))
4792  return;
4793 
4794  for (auto *F : SCC) {
4795  if (F->isDeclaration())
4796  continue;
4797 
4798  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4799  if (!DisableOpenMPOptDeglobalization)
4800  A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4801 
4802  for (auto &I : instructions(*F)) {
4803  if (auto *LI = dyn_cast<LoadInst>(&I)) {
4804  bool UsedAssumedInformation = false;
4805  A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4806  UsedAssumedInformation);
4807  } else if (auto *SI = dyn_cast<StoreInst>(&I)) {
4808  A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
4809  }
4810  }
4811  }
4812 }
4813 
4814 const char AAICVTracker::ID = 0;
4815 const char AAKernelInfo::ID = 0;
4816 const char AAExecutionDomain::ID = 0;
4817 const char AAHeapToShared::ID = 0;
4818 const char AAFoldRuntimeCall::ID = 0;
4819 
4820 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4821  Attributor &A) {
4822  AAICVTracker *AA = nullptr;
4823  switch (IRP.getPositionKind()) {
4824  case IRPosition::IRP_INVALID:
4825  case IRPosition::IRP_FLOAT:
4826  case IRPosition::IRP_ARGUMENT:
4827  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4828  llvm_unreachable("ICVTracker can only be created for function position!");
4829  case IRPosition::IRP_RETURNED:
4830  AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4831  break;
4832  case IRPosition::IRP_CALL_SITE_RETURNED:
4833  AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4834  break;
4835  case IRPosition::IRP_CALL_SITE:
4836  AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4837  break;
4838  case IRPosition::IRP_FUNCTION:
4839  AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4840  break;
4841  }
4842 
4843  return *AA;
4844 }
4845 
4846 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4847  Attributor &A) {
4848  AAExecutionDomainFunction *AA = nullptr;
4849  switch (IRP.getPositionKind()) {
4850  case IRPosition::IRP_INVALID:
4851  case IRPosition::IRP_FLOAT:
4852  case IRPosition::IRP_ARGUMENT:
4853  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4854  case IRPosition::IRP_RETURNED:
4855  case IRPosition::IRP_CALL_SITE_RETURNED:
4856  case IRPosition::IRP_CALL_SITE:
4857  llvm_unreachable(
4858  "AAExecutionDomain can only be created for function position!");
4859  case IRPosition::IRP_FUNCTION:
4860  AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4861  break;
4862  }
4863 
4864  return *AA;
4865 }
4866 
4867 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4868  Attributor &A) {
4869  AAHeapToSharedFunction *AA = nullptr;
4870  switch (IRP.getPositionKind()) {
4871  case IRPosition::IRP_INVALID:
4872  case IRPosition::IRP_FLOAT:
4873  case IRPosition::IRP_ARGUMENT:
4874  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4875  case IRPosition::IRP_RETURNED:
4876  case IRPosition::IRP_CALL_SITE_RETURNED:
4877  case IRPosition::IRP_CALL_SITE:
4878  llvm_unreachable(
4879  "AAHeapToShared can only be created for function position!");
4880  case IRPosition::IRP_FUNCTION:
4881  AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4882  break;
4883  }
4884 
4885  return *AA;
4886 }
4887 
4888 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4889  Attributor &A) {
4890  AAKernelInfo *AA = nullptr;
4891  switch (IRP.getPositionKind()) {
4892  case IRPosition::IRP_INVALID:
4893  case IRPosition::IRP_FLOAT:
4894  case IRPosition::IRP_ARGUMENT:
4895  case IRPosition::IRP_RETURNED:
4896  case IRPosition::IRP_CALL_SITE_RETURNED:
4897  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4898  llvm_unreachable("KernelInfo can only be created for function position!");
4899  case IRPosition::IRP_CALL_SITE:
4900  AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4901  break;
4902  case IRPosition::IRP_FUNCTION:
4903  AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4904  break;
4905  }
4906 
4907  return *AA;
4908 }
4909 
4910 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4911  Attributor &A) {
4912  AAFoldRuntimeCall *AA = nullptr;
4913  switch (IRP.getPositionKind()) {
4914  case IRPosition::IRP_INVALID:
4915  case IRPosition::IRP_FLOAT:
4916  case IRPosition::IRP_ARGUMENT:
4917  case IRPosition::IRP_RETURNED:
4918  case IRPosition::IRP_FUNCTION:
4919  case IRPosition::IRP_CALL_SITE:
4920  case IRPosition::IRP_CALL_SITE_ARGUMENT:
4921  llvm_unreachable("KernelInfo can only be created for call site position!");
4922  case IRPosition::IRP_CALL_SITE_RETURNED:
4923  AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4924  break;
4925  }
4926 
4927  return *AA;
4928 }
4929 
4930 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4931  if (!containsOpenMP(M))
4932  return PreservedAnalyses::all();
4933  if (DisableOpenMPOptimizations)
4934  return PreservedAnalyses::all();
4935 
4936  FunctionAnalysisManager &FAM =
4937  AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4938  KernelSet Kernels = getDeviceKernels(M);
4939 
4940  if (PrintModuleBeforeOptimizations)
4941  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
4942 
4943  auto IsCalled = [&](Function &F) {
4944  if (Kernels.contains(&F))
4945  return true;
4946  for (const User *U : F.users())
4947  if (!isa<BlockAddress>(U))
4948  return true;
4949  return false;
4950  };
4951 
4952  auto EmitRemark = [&](Function &F) {
4953  auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4954  ORE.emit([&]() {
4955  OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4956  return ORA << "Could not internalize function. "
4957  << "Some optimizations may not be possible. [OMP140]";
4958  });
4959  };
4960 
4961  // Create internal copies of each function if this is a kernel Module. This
4962  // allows interprocedural passes to see every call edge.
4963  DenseMap<Function *, Function *> InternalizedMap;
4964  if (isOpenMPDevice(M)) {
4965  SmallPtrSet<Function *, 16> InternalizeFns;
4966  for (Function &F : M)
4967  if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4968  !DisableInternalization) {
4969  if (Attributor::isInternalizable(F)) {
4970  InternalizeFns.insert(&F);
4971  } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4972  EmitRemark(F);
4973  }
4974  }
4975 
4976  Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4977  }
4978 
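  // Illustrative sketch (assumed, not from the original file): internalization
  // gives the Attributor a module-local copy of an externally visible device
  // function so every call edge becomes analyzable, roughly
  //   define void @helper(...)                       ; original definition
  //   define internal void @helper.internalized(...) ; copy used by local calls
  // The name @helper and the ".internalized" suffix are assumptions here;
  // InternalizedMap records the original-to-copy mapping consulted below.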
4979  // Look at every function in the Module unless it was internalized.
4980  SmallVector<Function *, 16> SCC;
4981  for (Function &F : M)
4982  if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4983  SCC.push_back(&F);
4984 
4985  if (SCC.empty())
4986  return PreservedAnalyses::all();
4987 
4988  AnalysisGetter AG(FAM);
4989 
4990  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4991  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4992  };
4993 
4994  BumpPtrAllocator Allocator;
4995  CallGraphUpdater CGUpdater;
4996 
4997  SetVector<Function *> Functions(SCC.begin(), SCC.end());
4998  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4999 
5000  unsigned MaxFixpointIterations =
5001  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5002 
5003  AttributorConfig AC(CGUpdater);
5004  AC.DefaultInitializeLiveInternals = false;
5005  AC.RewriteSignatures = false;
5006  AC.MaxFixpointIterations = MaxFixpointIterations;
5007  AC.OREGetter = OREGetter;
5008  AC.PassName = DEBUG_TYPE;
5009 
5010  Attributor A(Functions, InfoCache, AC);
5011 
5012  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5013  bool Changed = OMPOpt.run(true);
5014 
5015  // Optionally inline device functions for potentially better performance.
5016  if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
5017  for (Function &F : M)
5018  if (!F.isDeclaration() && !Kernels.contains(&F) &&
5019  !F.hasFnAttribute(Attribute::NoInline))
5020  F.addFnAttr(Attribute::AlwaysInline);
5021 
5022  if (PrintModuleAfterOptimizations)
5023  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
5024 
5025  if (Changed)
5026  return PreservedAnalyses::none();
5027 
5028  return PreservedAnalyses::all();
5029 }
5030 
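  // Usage note (assumed, not from the original file): besides being scheduled
  // by the default device compilation pipelines, this module pass can be
  // exercised directly, e.g.
  //   opt -passes=openmp-opt device.bc -S -o device.opt.ll
  // where "openmp-opt" is assumed to be the registered new-PM pass name.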
5031 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
5032  CGSCCAnalysisManager &AM,
5033  LazyCallGraph &CG,
5034  CGSCCUpdateResult &UR) {
5035  if (!containsOpenMP(*C.begin()->getFunction().getParent()))
5036  return PreservedAnalyses::all();
5037  if (DisableOpenMPOptimizations)
5038  return PreservedAnalyses::all();
5039 
5040  SmallVector<Function *, 16> SCC;
5041  // If there are kernels in the module, we have to run on all SCC's.
5042  for (LazyCallGraph::Node &N : C) {
5043  Function *Fn = &N.getFunction();
5044  SCC.push_back(Fn);
5045  }
5046 
5047  if (SCC.empty())
5048  return PreservedAnalyses::all();
5049 
5050  Module &M = *C.begin()->getFunction().getParent();
5051 
5052  if (PrintModuleBeforeOptimizations)
5053  LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
5054 
5055  KernelSet Kernels = getDeviceKernels(M);
5056 
5057  FunctionAnalysisManager &FAM =
5058  AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
5059 
5060  AnalysisGetter AG(FAM);
5061 
5062  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
5063  return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
5064  };
5065 
5066  BumpPtrAllocator Allocator;
5067  CallGraphUpdater CGUpdater;
5068  CGUpdater.initialize(CG, C, AM, UR);
5069 
5070  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5071  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
5072  /*CGSCC*/ Functions, Kernels);
5073 
5074  unsigned MaxFixpointIterations =
5075  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5076 
5077  AttributorConfig AC(CGUpdater);
5078  AC.DefaultInitializeLiveInternals = false;
5079  AC.IsModulePass = false;
5080  AC.RewriteSignatures = false;
5081  AC.MaxFixpointIterations = MaxFixpointIterations;
5082  AC.OREGetter = OREGetter;
5083  AC.PassName = DEBUG_TYPE;
5084 
5085  Attributor A(Functions, InfoCache, AC);
5086 
5087  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5088  bool Changed = OMPOpt.run(false);
5089 
5090  if (PrintModuleAfterOptimizations)
5091  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5092 
5093  if (Changed)
5094  return PreservedAnalyses::none();
5095 
5096  return PreservedAnalyses::all();
5097 }
5098 
5099 namespace {
5100 
5101 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
5102  CallGraphUpdater CGUpdater;
5103  static char ID;
5104 
5105  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
5106  initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
5107  }
5108 
5109  void getAnalysisUsage(AnalysisUsage &AU) const override {
5110  CallGraphSCCPass::getAnalysisUsage(AU);
5111  }
5112 
5113  bool runOnSCC(CallGraphSCC &CGSCC) override {
5114  if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
5115  return false;
5116  if (DisableOpenMPOptimizations || skipSCC(CGSCC))
5117  return false;
5118 
5119  SmallVector<Function *, 16> SCC;
5120  // If there are kernels in the module, we have to run on all SCC's.
5121  for (CallGraphNode *CGN : CGSCC) {
5122  Function *Fn = CGN->getFunction();
5123  if (!Fn || Fn->isDeclaration())
5124  continue;
5125  SCC.push_back(Fn);
5126  }
5127 
5128  if (SCC.empty())
5129  return false;
5130 
5131  Module &M = CGSCC.getCallGraph().getModule();
5132  KernelSet Kernels = getDeviceKernels(M);
5133 
5134  CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
5135  CGUpdater.initialize(CG, CGSCC);
5136 
5137  // Maintain a map of functions to avoid rebuilding the ORE
5138  DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
5139  auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
5140  std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
5141  if (!ORE)
5142  ORE = std::make_unique<OptimizationRemarkEmitter>(F);
5143  return *ORE;
5144  };
5145 
5146  AnalysisGetter AG;
5147  SetVector<Function *> Functions(SCC.begin(), SCC.end());
5148  BumpPtrAllocator Allocator;
5149  OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
5150  Allocator,
5151  /*CGSCC*/ Functions, Kernels);
5152 
5153  unsigned MaxFixpointIterations =
5154  (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
5155 
5156  AttributorConfig AC(CGUpdater);
5157  AC.DefaultInitializeLiveInternals = false;
5158  AC.IsModulePass = false;
5159  AC.RewriteSignatures = false;
5160  AC.MaxFixpointIterations = MaxFixpointIterations;
5161  AC.OREGetter = OREGetter;
5162  AC.PassName = DEBUG_TYPE;
5163 
5164  Attributor A(Functions, InfoCache, AC);
5165 
5166  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
5167  bool Result = OMPOpt.run(false);
5168 
5169  if (PrintModuleAfterOptimizations)
5170  LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
5171 
5172  return Result;
5173  }
5174 
5175  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
5176 };
5177 
5178 } // end anonymous namespace
5179 
5180 KernelSet llvm::omp::getDeviceKernels(Module &M) {
5181  // TODO: Create a more cross-platform way of determining device kernels.
5182  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
5183  KernelSet Kernels;
5184 
5185  if (!MD)
5186  return Kernels;
5187 
5188  for (auto *Op : MD->operands()) {
5189  if (Op->getNumOperands() < 2)
5190  continue;
5191  MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
5192  if (!KindID || KindID->getString() != "kernel")
5193  continue;
5194 
5195  Function *KernelFn =
5196  mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
5197  if (!KernelFn)
5198  continue;
5199 
5200  ++NumOpenMPTargetRegionKernels;
5201 
5202  Kernels.insert(KernelFn);
5203  }
5204 
5205  return Kernels;
5206 }
5207 
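  // Illustrative sketch (assumed, not from the original file) of the metadata
  // shape getDeviceKernels walks; the kernel name is made up:
  //   !nvvm.annotations = !{!0}
  //   !0 = !{void ()* @__omp_offloading_xx_main_l1, !"kernel", i32 1}
  // Any function paired with the "kernel" kind string is added to Kernels.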
5208 bool llvm::omp::containsOpenMP(Module &M) {
5209  Metadata *MD = M.getModuleFlag("openmp");
5210  if (!MD)
5211  return false;
5212 
5213  return true;
5214 }
5215 
5216 bool llvm::omp::isOpenMPDevice(Module &M) {
5217  Metadata *MD = M.getModuleFlag("openmp-device");
5218  if (!MD)
5219  return false;
5220 
5221  return true;
5222 }
5223 
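  // Illustrative sketch (assumed, not from the original file): the two
  // predicates above key off module flags emitted by the OpenMP frontend,
  // roughly
  //   !llvm.module.flags = !{!0, !1}
  //   !0 = !{i32 7, !"openmp", i32 50}
  //   !1 = !{i32 7, !"openmp-device", i32 50}
  // Host-only modules would carry just the "openmp" flag; the concrete flag
  // values shown are assumptions, since only the flag's presence is checked.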
5224 char OpenMPOptCGSCCLegacyPass::ID = 0;
5225 
5226 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5227  "OpenMP specific optimizations", false, false)
5228 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
5229 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
5230  "OpenMP specific optimizations", false, false)
5231 
5232 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
5233  return new OpenMPOptCGSCCLegacyPass();
5234 }